├── llava ├── serve │ ├── __init__.py │ ├── examples │ │ ├── waterview.jpg │ │ └── extreme_ironing.jpg │ ├── register_worker.py │ ├── test_message.py │ └── cli_graph.py ├── __init__.py ├── mol_utils.py ├── model │ ├── __init__.py │ ├── language_model │ │ ├── mpt │ │ │ ├── custom_embedding.py │ │ │ ├── adapt_tokenizer.py │ │ │ ├── blocks.py │ │ │ ├── norm.py │ │ │ └── meta_init_context.py │ │ ├── llava_graph_llama.py │ │ ├── llava_llama.py │ │ └── llava_mpt.py │ ├── utils.py │ ├── consolidate.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ └── moleculeSTM_gnn_model.py │ ├── multimodal_projector │ │ └── builder.py │ ├── apply_delta.py │ └── make_delta.py ├── constants.py ├── train │ ├── train_mem.py │ ├── llava_trainer.py │ └── llama_flash_attn_monkey_patch.py ├── eval │ ├── molecule_metrics │ │ ├── property_metrics.py │ │ ├── fingerprint_metrics.py │ │ ├── mol_translation_selfies.py │ │ └── MoleculeNet_classification.py │ ├── eval_molcap.py │ └── run_llava.py ├── datasets │ ├── property_pred_dataset.py │ ├── retrosynthesis_dataset.py │ ├── collators.py │ ├── MoleculeNet_classification_dataset.py │ ├── forward_pred_dataset.py │ ├── __init__.py │ ├── reagent_pred_dataset.py │ ├── lazy_supervised_dataset.py │ └── smiles2graph.py ├── mm_utils.py └── utils.py ├── assets ├── chebi-20_data │ └── test.txt └── static │ ├── teaser.png │ └── overview.png ├── docs └── static │ ├── images │ ├── xl.png │ ├── user.png │ ├── teaser.png │ ├── overview.png │ ├── dataset-size.png │ └── example │ │ ├── molcap.png │ │ ├── Beta-Amyrin.png │ │ ├── retrosynthesis.png │ │ ├── reagent_prediction.png │ │ └── forward_reaction_prediction.png │ ├── js │ ├── index.js │ └── bulma-slider.min.js │ └── css │ ├── index.css │ └── bulma-carousel.min.css ├── .vscode └── settings.json ├── cli.sh ├── requirements.txt ├── scripts ├── zero2.json ├── mlp │ ├── eval_molcap.sh │ ├── pretrain_mlp.sh │ └── finetune_lora_molcap_mlp.sh ├── merge_lora_weights.py ├── zero3.json ├── eval │ └── molcap.sh ├── zero3_offload.json ├── 13B │ ├── pretrain_13B.sh │ └── finetue_lora_molcap_13B.sh ├── pretrain.sh ├── finetune_lora_MoleculeNet.sh ├── finetune_lora_molcap.sh ├── all │ └── finetune_lora_all.sh ├── freezeLLM │ └── finetune_lora_molcap.sh ├── finetune_lora_property_pred.sh ├── finetune_lora_reagent_pred.sh ├── finetune_lora_retrosynthesis.sh └── finetune_lora_forward_pred.sh ├── pyproject.toml ├── .gitignore ├── Evaluation.md └── README.md /llava/serve/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/chebi-20_data/test.txt: -------------------------------------------------------------------------------- 1 | /cto_labs/AIDD/DATA/MolT5/ChEBI-20_data/test.txt -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM, LlavaGraphLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /assets/static/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/assets/static/teaser.png -------------------------------------------------------------------------------- /docs/static/images/xl.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/xl.png -------------------------------------------------------------------------------- /assets/static/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/assets/static/overview.png -------------------------------------------------------------------------------- /docs/static/images/user.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/user.png -------------------------------------------------------------------------------- /docs/static/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/teaser.png -------------------------------------------------------------------------------- /docs/static/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/overview.png -------------------------------------------------------------------------------- /docs/static/images/dataset-size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/dataset-size.png -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /docs/static/images/example/molcap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/example/molcap.png -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /docs/static/images/example/Beta-Amyrin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/example/Beta-Amyrin.png -------------------------------------------------------------------------------- /docs/static/images/example/retrosynthesis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/example/retrosynthesis.png -------------------------------------------------------------------------------- /docs/static/images/example/reagent_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/example/reagent_prediction.png -------------------------------------------------------------------------------- 
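The directory tree above shows that the top-level llava package (llava/__init__.py) re-exports LlavaLlamaForCausalLM and LlavaGraphLlamaForCausalLM. A minimal, hypothetical sketch of loading one of these models with Hugging Face transformers follows; it mirrors the from_pretrained pattern used in llava/model/apply_delta.py and llava/model/consolidate.py later in this listing, and the checkpoint path is only the example value from cli.sh, not a file shipped with the repository:

# Hypothetical usage sketch (not part of the repository); assumes a fully merged
# checkpoint directory is available locally at the placeholder path below.
import torch
from transformers import AutoTokenizer
from llava import LlavaGraphLlamaForCausalLM  # re-exported in llava/__init__.py

ckpt = "checkpoints/Graph-LLaVA/llava-moleculestm-vicuna-v1-3-7b-pretrain"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = LlavaGraphLlamaForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.float16, low_cpu_mem_usage=True
)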
/docs/static/images/example/forward_reaction_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/InstructMol/publish/docs/static/images/example/forward_reaction_prediction.png -------------------------------------------------------------------------------- /llava/mol_utils.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | 3 | def check_smiles_validity(smiles:str)->bool: 4 | # check if valid smiles 5 | m = Chem.MolFromSmiles(smiles,sanitize=False) 6 | if m is None: 7 | return False 8 | return True 9 | 10 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "livePreview.defaultPreviewPath": "/docs/index.html", 3 | "python-envs.defaultEnvManager": "ms-python.python:conda", 4 | "python-envs.defaultPackageManager": "ms-python.python:conda", 5 | "python-envs.pythonProjects": [] 6 | } -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | from .language_model.llava_graph_llama import LlavaGraphLlamaForCausalLM, LlavaGraphLlamaConfig -------------------------------------------------------------------------------- /cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_PATH="" # NOTE: Insert path to model here.(e.g., checkpoints/Graph-LLaVA/llava-moleculestm-vicuna-v1-3-7b-pretrain) 3 | 4 | python -m llava.serve.cli_graph \ 5 | --model-path $MODEL_PATH \ 6 | --model-base checkpoints/vicuna-v1-3-7b \ 7 | --graph-checkpoint-path checkpoints/graphmvp.pth \ 8 | --debug 9 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train_drug import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepspeed==0.9.5 2 | einops==0.7.0 3 | fastapi==0.104.1 4 | flash_attn==2.2.3.post2 5 | gradio==4.7.1 6 | huggingface_hub==0.17.2 7 | nltk==3.8.1 8 | numpy==1.26.2 9 | ogb==1.3.6 10 | openai==0.28.0 11 | packaging==23.2 12 | peft==0.5.0 13 | Pillow==10.1.0 14 | pydantic==1.10.12 15 | python_Levenshtein==0.23.0 16 | rdkit==2023.3.3 17 | Requests==2.31.0 18 | rouge_score==0.1.2 19 | scikit_learn==1.3.0 20 | selfies==2.1.1 21 | shortuuid==1.0.11 22 | torch==1.12.1+cu116 23 | torch_geometric==2.3.1 24 | torch_scatter==2.1.0+pt112cu116 25 | tqdm==4.66.1 26 | transformers==4.33.2 27 | uvicorn==0.24.0.post1 28 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/mlp/eval_molcap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GRAPH_TOWER="moleculestm" 4 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 5 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 6 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 7 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 8 | else 9 | echo "Not supported graph tower" 10 | fi 11 | 12 | MODEL_PATH=checkpoints/Graph-LLaVA-mlp/molcap-llava-moleculestm-vicuna-v1-3-7b-finetune_lora 13 | EPOCH=20 14 | OUT_FILE=eval_result/mlp/$GRAPH_TOWER-chebi20-molcap-lora-${EPOCH}ep.jsonl 15 | 16 | python -m llava.eval.model_molcap \ 17 | --model-path $MODEL_PATH \ 18 | --in-file assets/chebi-20_data/test.txt \ 19 | --answers-file $OUT_FILE \ 20 | --graph-checkpoint-path $INIT_CHECKPOINT_GNN \ 21 | --model-base checkpoints/vicuna-v1-3-7b \ 22 | --batch_size 1 \ 23 | --debug -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers.
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/eval/molcap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GRAPH_TOWER="moleculestm" 4 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 5 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 6 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 7 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 8 | else 9 | echo "Not supported graph tower" 10 | fi 11 | 12 | MODEL_PATH=checkpoints/llava-moleculestm-vicuna-v1-3-7b-finetune_lora 13 | EPOCH=20 14 | OUT_FILE=eval_result/$GRAPH_TOWER-chebi20-molcap-lora-${EPOCH}ep.jsonl 15 | 16 | python -m llava.eval.model_molcap \ 17 | --model-path $MODEL_PATH \ 18 | --in-file assets/chebi-20_data/test.txt \ 19 | --answers-file $OUT_FILE \ 
20 | --graph-checkpoint-path $INIT_CHECKPOINT_GNN \ 21 | --model-base checkpoints/vicuna-v1-3-7b \ 22 | --batch_size 4 \ 23 | --debug 24 | 25 | # # evaluation 26 | # python -m llava.eval.eval_molcap \ 27 | # --molcap_result_file $OUT_FILE \ 28 | # --text2mol_bert_path checkpoints/scibert_scivocab_uncased -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "InstructMol" 7 | version = "1.0.1" 8 | description = "InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy", 17 | "requests", "sentencepiece", "tokenizers>=0.12.1", 18 | "torchvision", "uvicorn", "wandb", 19 | "shortuuid", "httpx==0.24.0", 20 | "deepspeed==0.9.5", 21 | "peft==0.4.0", 22 | "transformers==4.31.0", 23 | "accelerate==0.21.0", 24 | 
"bitsandbytes==0.41.0", 25 | "scikit-learn==1.2.2", 26 | "sentencepiece==0.1.99", 27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 28 | "gradio_client==0.2.9" 29 | ] 30 | 31 | [tool.setuptools.packages.find] 32 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 33 | 34 | [tool.wheel] 35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 36 | -------------------------------------------------------------------------------- /llava/eval/molecule_metrics/property_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from sklearn.metrics import mean_absolute_error 4 | from typing import List 5 | 6 | def compute_mae(eval_result_file:str, except_idxs:List[int]=[]): 7 | with open(eval_result_file) as f: 8 | results = json.load(f) 9 | gts = [] 10 | preds = [] 11 | for i, result in enumerate(results): 12 | if i in except_idxs: 13 | continue 14 | pred = result['pred_self'] 15 | gt = result['gt_self'] 16 | gts.append(float(gt)) 17 | preds.append(float(pred)) 18 | return mean_absolute_error(gts, preds) 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--eval_result_file", type=str, required=True) 23 | args = parser.parse_args() 24 | # read except_idxs 25 | with open('/cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/property_overlap.txt', 'r') as f: 26 | except_idxs = [int(line.split('\t')[0]) for line in f.readlines()] 27 | mae = compute_mae(args.eval_result_file, except_idxs) 28 | print(mae) 29 | 30 | 31 | """ 32 | # property_pred 33 | TASK=property_pred 34 | EPOCH=5 35 | GRAPH_TOWER=moleculestm 36 | python -m llava.eval.molecule_metrics.property_metrics \ 37 | --eval_result_file=eval_result/$GRAPH_TOWER-$TASK-${EPOCH}ep.jsonl 38 | """ -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- 
/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | from .gnn_graphmvp import GraphMVP 3 | from .moleculeSTM_gnn_model import GNN_graphpred, GNN 4 | 5 | 6 | def build_vision_tower(vision_tower_cfg, **kwargs): 7 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 8 | if vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | 13 | 14 | def build_graph_tower(graph_tower_cfg, **kwargs): 15 | graph_tower = getattr(graph_tower_cfg, 'mm_graph_tower', getattr(graph_tower_cfg, 'graph_tower', None)) 16 | if graph_tower.startswith("graphmvp"): 17 | return GraphMVP(config=graph_tower_cfg) 18 | elif graph_tower.startswith("moleculestm"): 19 | # actually, 'graph_tower_cfg' is identical to 'model_args' 20 | molecule_node_model = GNN( 21 | num_layer=graph_tower_cfg.gin_num_layers, 22 | emb_dim=graph_tower_cfg.gin_hidden_dim, 23 | JK='last', # default to 'last' 24 | drop_ratio=graph_tower_cfg.drop_ratio, 25 | gnn_type='gin', # default to 'gin' 26 | ) 27 | return GNN_graphpred( 28 | emb_dim=graph_tower_cfg.gin_hidden_dim, 29 | graph_pooling=graph_tower_cfg.graph_pooling, 30 | molecule_node_model=molecule_node_model, 31 | init_checkpoint=graph_tower_cfg.init_checkpoint, 32 | ) 33 | 34 | raise ValueError(f'Unknown graph tower: {graph_tower}') -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_xmodal_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, 
PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /scripts/13B/pretrain_13B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | MODEL_VERSION=vicuna-v1-3-13b 5 | # MODEL_VERSION=llama-2-7b-chat 6 | 7 | ########### DO NOT CHANGE ########### 8 | ########### USE THIS FOR BOTH ########### 9 | PROMPT_VERSION=plain 10 | ########### DO NOT CHANGE ########### 11 | 12 | GRAPH_TOWER="moleculestm" 13 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 14 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 15 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 16 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 17 | else 18 | echo "Not supported graph tower" 19 | fi 20 | 21 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA-13B" 22 | DATA_PATH="" # NOTE: Insert path to data here.(e.g., pubchemsft_desc/train.pkl) 23 | 24 | deepspeed llava/train/train_mem.py \ 25 | --deepspeed scripts/zero2.json \ 26 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 27 | --version $PROMPT_VERSION \ 28 | --data_path $DATA_PATH \ 29 | --graph_tower $GRAPH_TOWER \ 30 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 31 | --tune_mm_mlp_adapter True \ 32 | --mm_use_im_start_end False \ 33 | --mm_use_im_patch_token False \ 34 | --bf16 True \ 35 | --output_dir $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain \ 36 | --num_train_epochs 3 \ 37 | --per_device_train_batch_size 12 \ 38 | --per_device_eval_batch_size 4 \ 39 | --gradient_accumulation_steps 1 \ 40 | --evaluation_strategy "no" \ 41 | --save_strategy "steps" \ 42 | --save_steps 2000 \ 43 | --save_total_limit 1 \ 44 | --learning_rate 2e-3 \ 45 | --weight_decay 0.
\ 46 | --warmup_ratio 0.03 \ 47 | --lr_scheduler_type "cosine" \ 48 | --logging_steps 1 \ 49 | --tf32 True \ 50 | --model_max_length 2048 \ 51 | --gradient_checkpointing True \ 52 | --dataloader_num_workers 4 \ 53 | --lazy_preprocess True \ 54 | --report_to none -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | MODEL_VERSION=vicuna-v1-3-7b 5 | 6 | ########### DO NOT CHANGE ########### 7 | ########### USE THIS FOR BOTH ########### 8 | PROMPT_VERSION=plain 9 | ########### DO NOT CHANGE ########### 10 | 11 | GRAPH_TOWER="moleculestm" 12 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 13 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 14 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 15 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 16 | else 17 | echo "Not supported graph tower" 18 | fi 19 | 20 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 21 | DATA_PATH="/home/public_space/zhangxiaohong/public_user/PubChemSFT/train.pkl" # Path to the PubChem dataset 22 | 23 | deepspeed --include=localhost:4,5,6,7 llava/train/train_mem.py \ 24 | --deepspeed scripts/zero2.json \ 25 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 26 | --version $PROMPT_VERSION \ 27 | --data_path $DATA_PATH \ 28 | --graph_tower $GRAPH_TOWER \ 29 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 30 | --tune_mm_mlp_adapter True \ 31 | --mm_use_im_start_end False \ 32 | --mm_use_im_patch_token False \ 33 | --bf16 True \ 34 | --output_dir $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain \ 35 | --num_train_epochs 5 \ 36 | --per_device_train_batch_size 32 \ 37 | --per_device_eval_batch_size 4 \ 38 | --gradient_accumulation_steps 1 \ 39 | --evaluation_strategy "no" \ 40 | --save_strategy "steps" \ 41 | --save_steps 4000 \ 42 | --save_total_limit 1 \ 43 | --learning_rate 5e-5 \ 44 | --weight_decay 0. 
\ 45 | --warmup_ratio 0.03 \ 46 | --lr_scheduler_type "cosine" \ 47 | --logging_steps 1 \ 48 | --tf32 True \ 49 | --model_max_length 2048 \ 50 | --gradient_checkpointing True \ 51 | --dataloader_num_workers 4 \ 52 | --lazy_preprocess True \ 53 | --report_to tensorboard -------------------------------------------------------------------------------- /scripts/mlp/pretrain_mlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | MODEL_VERSION=vicuna-v1-3-7b 5 | # MODEL_VERSION=llama-2-7b-chat 6 | 7 | ########### DO NOT CHANGE ########### 8 | ########### USE THIS FOR BOTH ########### 9 | PROMPT_VERSION=plain 10 | ########### DO NOT CHANGE ########### 11 | 12 | GRAPH_TOWER="moleculestm" 13 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 14 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 15 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 16 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 17 | else 18 | echo "Not supported graph tower" 19 | fi 20 | 21 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA-mlp" 22 | DATA_PATH="" # NOTE: Insert path to data here.(e.g., pubchemsft_desc/train.pkl) 23 | 24 | deepspeed llava/train/train_mem.py \ 25 | --deepspeed scripts/zero2.json \ 26 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 27 | --version $PROMPT_VERSION \ 28 | --data_path $DATA_PATH \ 29 | --graph_tower $GRAPH_TOWER \ 30 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 31 | --mm_projector_type mlp2x_gelu \ 32 | --tune_mm_mlp_adapter True \ 33 | --mm_use_im_start_end False \ 34 | --mm_use_im_patch_token False \ 35 | --bf16 True \ 36 | --output_dir $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain \ 37 | --num_train_epochs 3 \ 38 | --per_device_train_batch_size 16 \ 39 | --per_device_eval_batch_size 4 \ 40 | --gradient_accumulation_steps 1 \ 41 | --evaluation_strategy "no" \ 42 | --save_strategy "steps" \ 43 | --save_steps 2000 \ 44 | --save_total_limit 1 \ 45 | --learning_rate 2e-3 \ 46 | --weight_decay 0. 
\ 47 | --warmup_ratio 0.03 \ 48 | --lr_scheduler_type "cosine" \ 49 | --logging_steps 1 \ 50 | --tf32 True \ 51 | --model_max_length 2048 \ 52 | --gradient_checkpointing True \ 53 | --dataloader_num_workers 4 \ 54 | --lazy_preprocess True \ 55 | --report_to none 56 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /scripts/finetune_lora_MoleculeNet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="MoleculeNet" 26 | 27 | deepspeed llava/train/train_mem.py \ 
28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /cto_labs/AIDD/DATA/MoleculeNet \ 33 | --data_type $TASK \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 20 \ 42 | --per_device_train_batch_size 32 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 1 \ 48 | --learning_rate 4e-5 \ 49 | --weight_decay 0. \ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /scripts/finetune_lora_molcap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="molcap" 26 | 27 | deepspeed llava/train/train_mem.py \ 28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /cto_labs/AIDD/DATA/MolT5/ChEBI-20_data/train.pkl \ 33 | --graph_tower $GRAPH_TOWER \ 34 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 35 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/checkpoint-48000/mm_projector.bin \ 36 | --mm_use_im_start_end False \ 37 | --mm_use_im_patch_token False \ 38 | --bf16 True \ 39 | --output_dir $CHECKPOINT_FOLDER_PREFIX/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 40 | --num_train_epochs 50 \ 41 | --per_device_train_batch_size 16 \ 42 | --per_device_eval_batch_size 4 \ 43 | --gradient_accumulation_steps 1 \ 44 | --evaluation_strategy "no" \ 45 | --save_strategy "epoch" \ 46 | --save_total_limit 10 \ 47 | --learning_rate 8e-5 \ 48 | --weight_decay 0. 
\ 49 | --warmup_ratio 0.03 \ 50 | --lr_scheduler_type "cosine" \ 51 | --logging_steps 1 \ 52 | --tf32 True \ 53 | --model_max_length 2048 \ 54 | --gradient_checkpointing True \ 55 | --lazy_preprocess True \ 56 | --dataloader_num_workers 4 \ 57 | --report_to none 58 | -------------------------------------------------------------------------------- /scripts/all/finetune_lora_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="all" 26 | 27 | # 8 GPUs 28 | deepspeed llava/train/train_mem.py \ 29 | --deepspeed scripts/zero2.json \ 30 | --lora_enable True \ 31 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 32 | --version $PROMPT_VERSION \ 33 | --data_path "" \ 34 | --data_type $TASK \ 35 | --graph_tower $GRAPH_TOWER \ 36 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 37 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/checkpoint-48000/mm_projector.bin \ 38 | --mm_use_im_start_end False \ 39 | --mm_use_im_patch_token False \ 40 | --bf16 True \ 41 | --output_dir $CHECKPOINT_FOLDER_PREFIX/graph-text-molgen/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 42 | --num_train_epochs 10 \ 43 | --per_device_train_batch_size 32 \ 44 | --per_device_eval_batch_size 4 \ 45 | --gradient_accumulation_steps 1 \ 46 | --evaluation_strategy "no" \ 47 | --save_strategy "epoch" \ 48 | --save_total_limit 10 \ 49 | --learning_rate 2e-4 \ 50 | --weight_decay 0. 
\ 51 | --warmup_ratio 0.03 \ 52 | --lr_scheduler_type "cosine" \ 53 | --logging_steps 1 \ 54 | --tf32 True \ 55 | --model_max_length 2048 \ 56 | --gradient_checkpointing True \ 57 | --lazy_preprocess True \ 58 | --dataloader_num_workers 16 \ 59 | --report_to none 60 | -------------------------------------------------------------------------------- /scripts/13B/finetue_lora_molcap_13B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-13b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA-13B" 25 | DATA_PATH="" # NOTE: Insert path to data here.(e.g., ChEBI-20_data/train.pkl) 26 | TASK="molcap" 27 | 28 | deepspeed llava/train/train_mem.py \ 29 | --deepspeed scripts/zero2.json \ 30 | --lora_enable True \ 31 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 32 | --version $PROMPT_VERSION \ 33 | --data_path $DATA_PATH \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 20 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 10 \ 48 | --learning_rate 8e-5 \ 49 | --weight_decay 0.
\ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /scripts/freezeLLM/finetune_lora_molcap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA-freezeLLM" 25 | DATA_PATH="" # NOTE: Insert path to data here.(e.g., ChEBI-20_data/train.pkl) 26 | TASK="molcap" 27 | 28 | deepspeed llava/train/train_mem.py \ 29 | --deepspeed scripts/zero2.json \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path $DATA_PATH \ 33 | --graph_tower $GRAPH_TOWER \ 34 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 35 | --tune_mm_mlp_adapter True \ 36 | --pretrain_mm_mlp_adapter ./checkpoints/Graph-LLaVA/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 20 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 5 \ 48 | --learning_rate 8e-5 \ 49 | --weight_decay 0. 
\ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /scripts/mlp/finetune_lora_molcap_mlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA-mlp" 25 | DATA_PATH="" # NOTE: Insert path to data here.(e.g., ChEBI-20_data/train.pkl) 26 | TASK="molcap" 27 | 28 | deepspeed llava/train/train_mem.py \ 29 | --deepspeed scripts/zero2.json \ 30 | --lora_enable True \ 31 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 32 | --version $PROMPT_VERSION \ 33 | --data_path $DATA_PATH \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --mm_projector_type mlp2x_gelu \ 37 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/mm_projector.bin \ 38 | --mm_use_im_start_end False \ 39 | --mm_use_im_patch_token False \ 40 | --bf16 True \ 41 | --output_dir $CHECKPOINT_FOLDER_PREFIX/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 42 | --num_train_epochs 20 \ 43 | --per_device_train_batch_size 16 \ 44 | --per_device_eval_batch_size 4 \ 45 | --gradient_accumulation_steps 1 \ 46 | --evaluation_strategy "no" \ 47 | --save_strategy "epoch" \ 48 | --save_total_limit 10 \ 49 | --learning_rate 4e-5 \ 50 | --weight_decay 0. 
\ 51 | --warmup_ratio 0.03 \ 52 | --lr_scheduler_type "cosine" \ 53 | --logging_steps 1 \ 54 | --tf32 True \ 55 | --model_max_length 2048 \ 56 | --gradient_checkpointing True \ 57 | --lazy_preprocess True \ 58 | --dataloader_num_workers 4 \ 59 | --report_to none 60 | -------------------------------------------------------------------------------- /scripts/finetune_lora_property_pred.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="property_pred" 26 | 27 | deepspeed llava/train/train_mem.py \ 28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/property_prediction_train.json \ 33 | --data_type $TASK \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/graph-text-molgen/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora-add_seqs \ 41 | --num_train_epochs 5 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 1 \ 48 | --learning_rate 2e-5 \ 49 | --weight_decay 0. 
\ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /scripts/finetune_lora_reagent_pred.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="reagent_pred" 26 | 27 | 
deepspeed llava/train/train_mem.py \ 28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /shared_space/caohe/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/reagent_prediction_train.json \ 33 | --data_type reagent_pred \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/checkpoint-48000/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/graph-text-molgen/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 5 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 1 \ 48 | --learning_rate 2e-5 \ 49 | --weight_decay 0. \ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /scripts/finetune_lora_retrosynthesis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="retrosynthesis" 26 | 27 | deepspeed llava/train/train_mem.py \ 28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /shared_space/caohe/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/retrosynthesis_train.json \ 33 | --data_type retrosynthesis \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/checkpoint-48000/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/graph-text-molgen/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 5 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 1 \ 48 | --learning_rate 2e-5 \ 49 | --weight_decay 0. 
\ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /scripts/finetune_lora_forward_pred.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | PROMPT_VERSION=v1 7 | MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | GRAPH_TOWER="moleculestm" 16 | if [ "$GRAPH_TOWER" == "graphmvp" ]; then 17 | INIT_CHECKPOINT_GNN="./checkpoints/graphmvp.pth" 18 | elif [ "$GRAPH_TOWER" == "moleculestm" ]; then 19 | INIT_CHECKPOINT_GNN="./checkpoints/MoleculeSTM/molecule_model.pth" 20 | else 21 | echo "Not supported graph tower" 22 | fi 23 | 24 | CHECKPOINT_FOLDER_PREFIX="./checkpoints/Graph-LLaVA" 25 | TASK="forward_pred" 26 | 27 | deepspeed llava/train/train_mem.py \ 28 | --deepspeed scripts/zero2.json \ 29 | --lora_enable True \ 30 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 31 | --version $PROMPT_VERSION \ 32 | --data_path /shared_space/caohe/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/forward_reaction_prediction_train.json \ 33 | --data_type forward_pred \ 34 | --graph_tower $GRAPH_TOWER \ 35 | --init_checkpoint $INIT_CHECKPOINT_GNN \ 36 | --pretrain_mm_mlp_adapter $CHECKPOINT_FOLDER_PREFIX/llava-$GRAPH_TOWER-$MODEL_VERSION-pretrain/checkpoint-48000/mm_projector.bin \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --bf16 True \ 40 | --output_dir $CHECKPOINT_FOLDER_PREFIX/graph-text-molgen/$TASK-llava-$GRAPH_TOWER-$MODEL_VERSION-finetune_lora \ 41 | --num_train_epochs 5 \ 42 | --per_device_train_batch_size 16 \ 43 | --per_device_eval_batch_size 4 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "no" \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 1 \ 48 | --learning_rate 2e-5 \ 49 | --weight_decay 0. 
\ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --tf32 True \ 54 | --model_max_length 2048 \ 55 | --gradient_checkpointing True \ 56 | --lazy_preprocess True \ 57 | --dataloader_num_workers 4 \ 58 | --report_to none 59 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import Trainer 5 | from typing import Optional 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | if hasattr(param, "ds_id"): 12 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 13 | if not ignore_status: 14 | print(name, 'no ignore status') 15 | with zero.GatheredParameters([param]): 16 | param = param.data.detach().cpu().clone() 17 | else: 18 | param = param.detach().cpu().clone() 19 | return param 20 | 21 | 22 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 23 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 24 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 25 | return to_return 26 | 27 | 28 | class LLaVATrainer(Trainer): 29 | 30 | def _save_checkpoint(self, model, trial, metrics=None): 31 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 32 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 33 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 34 | 35 | run_dir = self._get_output_dir(trial=trial) 36 | output_dir = os.path.join(run_dir, checkpoint_folder) 37 | 38 | # Only save Adapter 39 | keys_to_match = ['mm_projector'] 40 | if getattr(self.args, "use_im_start_end", False): 41 | keys_to_match.extend(['embed_tokens', 'embed_in']) 42 | 43 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 44 | 45 | if self.args.local_rank == 0 or self.args.local_rank == -1: 46 | self.model.config.save_pretrained(output_dir) 47 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 48 | else: 49 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 50 | 51 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 52 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 53 | pass 54 | else: 55 | super(LLaVATrainer, self)._save(output_dir, state_dict) 56 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = 
AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | 
element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) 
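# Note: MPTBlock above follows the standard pre-norm residual layout used by MPT:
#   x = x + resid_attn_dropout(attn(norm_1(x)))
#   x = x + resid_ffn_dropout(ffn(norm_2(x)))
# A minimal usage sketch (illustrative only, not taken from this repo; it assumes the
# sibling `attention` module providing ATTN_CLASS_REGISTRY is importable and that
# attn_impl='torch' is supported, so no Triton/FlashAttention install is required):
#
#   import torch
#   cfg = dict(attn_type='multihead_attention', attn_pdrop=0.0, attn_impl='torch',
#              qk_ln=False, clip_qkv=None, softmax_scale=None, prefix_lm=False,
#              attn_uses_sequence_id=False, alibi=False, alibi_bias_max=8)
#   block = MPTBlock(d_model=256, n_heads=8, expansion_ratio=4, attn_config=cfg)
#   x = torch.randn(2, 16, 256)             # (batch, seq_len, d_model)
#   y, attn_weights, past_kv = block(x)     # y keeps the same shape as x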
-------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: 
auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | .results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 
| raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 |         return image_features 38 | 39 |     @torch.no_grad() 40 |     def forward(self, images): 41 |         if type(images) is list: 42 |             image_features = [] 43 |             for image in images: 44 |                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 |                 image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 |                 image_features.append(image_feature) 47 |         else: 48 |             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 |             image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 |         return image_features 52 | 53 |     @property 54 |     def dummy_feature(self): 55 |         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 |     @property 58 |     def dtype(self): 59 |         return self.vision_tower.dtype 60 | 61 |     @property 62 |     def device(self): 63 |         return self.vision_tower.device 64 | 65 |     @property 66 |     def config(self): 67 |         if self.is_loaded: 68 |             return self.vision_tower.config 69 |         else: 70 |             return self.cfg_only 71 | 72 |     @property 73 |     def hidden_size(self): 74 |         return self.config.hidden_size 75 | 76 |     @property 77 |     def num_patches(self): 78 |         return (self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/datasets/property_pred_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import copy 5 | import pickle 6 | from typing import Dict, Optional, Sequence, List 7 | import selfies 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from .preprocess import preprocess, preprocess_multimodal 12 | from .smiles2graph import smiles2graph 13 | 14 | 15 | class PropertyPredSupervisedGraphDataset(Dataset): 16 |     """We use MolInstruction https://huggingface.co/datasets/zjunlp/Mol-Instructions/viewer/Molecule-oriented%20Instructions/ (124K) """ 17 |     add_selfies = True 18 |     def __init__(self, 19 |                  data_path: str, 20 |                  tokenizer: transformers.PreTrainedTokenizer, 21 |                  data_args, 22 |                  ): 23 |         super(PropertyPredSupervisedGraphDataset, self).__init__() 24 |         with open(data_path, "rb") as f: 25 |             list_data_dict = json.load(f) 26 | 27 |         self.tokenizer = tokenizer 28 |         self.list_data_dict = list_data_dict 29 |         self.data_args = data_args 30 | 31 |     def selfies2smiles(self, selfies_str): 32 |         try: 33 |             smiles_str = selfies.decoder(selfies_str) 34 |         except: 35 |             smiles_str = None 36 |         return smiles_str 37 | 38 |     def __len__(self): 39 |         return len(self.list_data_dict) 40 | 41 |     def __getitem__(self, i) -> Dict[str, torch.Tensor]: 42 |         raw = self.list_data_dict[i] 43 |         instruction = raw['instruction'] 44 |         if self.add_selfies: 45 |             instruction += f" The compound SELFIES sequence is: {raw['input']}" 46 |         if random.random() < 0.5: 47 |             instruction = "<image>\n" + instruction 48 |         else: 49 |             instruction = instruction + "\n<image>" 50 | 51 |         input_selfies, target = raw['input'], str(raw['output']) 52 |         # convert input selfies to smiles for building graph 53 |         graph=smiles2graph(self.selfies2smiles(input_selfies)) 54 |         sources = dict( 55 |             conversations=[ 56 |                 {"from": "human", "value": instruction}, 57 |                 {"from": "gpt", "value": target} 58 |             ] 59 |         ) 60 | 61 |         if isinstance(i, int): 62 |             sources = [sources] 63 |         assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
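# Note on the remainder of __getitem__ (the same flow is shared by all graph-instruction
# datasets in llava/datasets): the human turn built above carries the '<image>' placeholder,
# which the `preprocess_multimodal`/`preprocess` helpers (defined in .preprocess, not shown
# in this listing) are expected to tokenize in the same way as tokenizer_image_token in
# llava/mm_utils.py, i.e. the prompt is split on '<image>' and IMAGE_TOKEN_INDEX is spliced
# between the chunks. The graph itself stays the plain dict produced by smiles2graph
# (node_feat, edge_feat, edge_index) and is attached as data_dict['graph']; at batching time
# GraphDataCollatorForSupervisedDataset in llava/datasets/collators.py converts it into a
# torch_geometric.data.Data and packs the batch under 'graphs'.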
64 | 65 |         if graph is not None: 66 |             sources = preprocess_multimodal( 67 |                 copy.deepcopy([e["conversations"] for e in sources]), 68 |                 self.data_args) 69 |         else: 70 |             sources = copy.deepcopy([e["conversations"] for e in sources]) 71 | 72 |         data_dict = preprocess( 73 |             sources, 74 |             self.tokenizer, 75 |             has_image=(graph is not None)) 76 |         if isinstance(i, int): 77 |             data_dict = dict(input_ids=data_dict["input_ids"][0], 78 |                              labels=data_dict["labels"][0]) 79 | 80 |         # graph exists in the data 81 |         if graph is not None: 82 |             data_dict['graph'] = graph 83 |         elif self.data_args.is_multimodal: 84 |             raise ValueError("Graph does not exist in the data, but the model is multimodal") 85 |         return data_dict -------------------------------------------------------------------------------- /llava/datasets/retrosynthesis_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import copy 5 | import pickle 6 | from typing import Dict, Optional, Sequence, List 7 | import selfies 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from .preprocess import preprocess, preprocess_multimodal 12 | from .smiles2graph import smiles2graph 13 | 14 | 15 | class RetrosynthesisSupervisedGraphDataset(Dataset): 16 |     """We use MolInstruction https://huggingface.co/datasets/zjunlp/Mol-Instructions/viewer/Molecule-oriented%20Instructions/ (124K) """ 17 |     add_selfies = True 18 |     def __init__(self, 19 |                  data_path: str, 20 |                  tokenizer: transformers.PreTrainedTokenizer, 21 |                  data_args, 22 |                  ): 23 |         super(RetrosynthesisSupervisedGraphDataset, self).__init__() 24 |         with open(data_path, "rb") as f: 25 |             list_data_dict = json.load(f) 26 | 27 |         self.tokenizer = tokenizer 28 |         self.list_data_dict = list_data_dict 29 |         self.data_args = data_args 30 | 31 |     def selfies2smiles(self, selfies_str): 32 |         try: 33 |             smiles_str = selfies.decoder(selfies_str) 34 |         except: 35 |             smiles_str = None 36 |         return smiles_str 37 | 38 |     def __len__(self): 39 |         return len(self.list_data_dict) 40 | 41 |     def __getitem__(self, i) -> Dict[str, torch.Tensor]: 42 |         raw = self.list_data_dict[i] 43 |         instruction = raw['instruction'] 44 |         if self.add_selfies: 45 |             instruction += f" The product is: {raw['input']}" 46 |         if random.random() < 0.5: 47 |             instruction = "<image>\n" + instruction 48 |         else: 49 |             instruction = instruction + "\n<image>" 50 | 51 |         input_selfies, output_selfies = raw['input'], raw['output'] 52 |         # convert input selfies to smiles for building graph 53 |         reactant_smiles = self.selfies2smiles(input_selfies) 54 |         graph=smiles2graph(reactant_smiles) 55 |         sources = dict( 56 |             conversations=[ 57 |                 {"from": "human", "value": instruction}, 58 |                 {"from": "gpt", "value": output_selfies} 59 |             ] 60 |         ) 61 | 62 |         if isinstance(i, int): 63 |             sources = [sources] 64 |         assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 65 | 66 |         if graph is not None: 67 |             sources = preprocess_multimodal( 68 |                 copy.deepcopy([e["conversations"] for e in sources]), 69 |                 self.data_args) 70 |         else: 71 |             sources = copy.deepcopy([e["conversations"] for e in sources]) 72 | 73 |         data_dict = preprocess( 74 |             sources, 75 |             self.tokenizer, 76 |             has_image=(graph is not None)) 77 |         if isinstance(i, int): 78 |             data_dict = dict(input_ids=data_dict["input_ids"][0], 79 |                              labels=data_dict["labels"][0]) 80 | 81 |         # graph exists in the data 82 |         if graph is not None: 83 |             data_dict['graph'] = graph 84 |         elif self.data_args.is_multimodal: 85 |             raise ValueError("Graph 
does not exist in the data, but the model is multimodal") 86 | return data_dict -------------------------------------------------------------------------------- /llava/datasets/collators.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Sequence, List 3 | 4 | import torch 5 | from torch_geometric.data import Batch, Data 6 | import transformers 7 | 8 | from llava.constants import IGNORE_INDEX 9 | 10 | @dataclass 11 | class DataCollatorForSupervisedDataset(object): 12 | """Collate examples for supervised fine-tuning.""" 13 | 14 | tokenizer: transformers.PreTrainedTokenizer 15 | 16 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 17 | input_ids, labels = tuple([instance[key] for instance in instances] 18 | for key in ("input_ids", "labels")) 19 | input_ids = torch.nn.utils.rnn.pad_sequence( 20 | input_ids, 21 | batch_first=True, 22 | padding_value=self.tokenizer.pad_token_id) 23 | labels = torch.nn.utils.rnn.pad_sequence(labels, 24 | batch_first=True, 25 | padding_value=IGNORE_INDEX) 26 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 27 | labels = labels[:, :self.tokenizer.model_max_length] 28 | batch = dict( 29 | input_ids=input_ids, 30 | labels=labels, 31 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 32 | ) 33 | 34 | if 'image' in instances[0]: 35 | images = [instance['image'] for instance in instances] 36 | if all(x is not None and x.shape == images[0].shape for x in images): 37 | batch['images'] = torch.stack(images) 38 | else: 39 | batch['images'] = images 40 | 41 | return batch 42 | 43 | 44 | @dataclass 45 | class GraphDataCollatorForSupervisedDataset(object): 46 | """Collate graph-QA examples for supervised fine-tuning.""" 47 | 48 | tokenizer: transformers.PreTrainedTokenizer 49 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 50 | input_ids, labels = tuple([instance[key] for instance in instances] 51 | for key in ("input_ids", "labels")) 52 | input_ids = torch.nn.utils.rnn.pad_sequence( 53 | input_ids, 54 | batch_first=True, 55 | padding_value=self.tokenizer.pad_token_id) 56 | labels = torch.nn.utils.rnn.pad_sequence(labels, 57 | batch_first=True, 58 | padding_value=IGNORE_INDEX) 59 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 60 | labels = labels[:, :self.tokenizer.model_max_length] 61 | batch = dict( 62 | input_ids=input_ids, 63 | labels=labels, 64 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 65 | ) 66 | 67 | if 'graph' in instances[0]: 68 | g = Batch.from_data_list([self._convert_dict_to_Data(instance["graph"]) for instance in instances]) 69 | batch['graphs'] = g 70 | return batch 71 | 72 | def _convert_dict_to_Data(self, data_dict: Dict) -> Data: 73 | return Data( 74 | x=torch.asarray(data_dict['node_feat']), 75 | edge_attr=torch.asarray(data_dict['edge_feat']), 76 | edge_index=torch.asarray(data_dict['edge_index']), 77 | ) -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | from pydantic import BaseModel, Field 5 | 6 | import torch 7 | from transformers import StoppingCriteria 8 | from llava.constants import IMAGE_TOKEN_INDEX 9 | 10 | 11 | def load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def 
process_images(images, image_processor, model_cfg): 16 |     return image_processor(images, return_tensors='pt')['pixel_values'] 17 | 18 | 19 | def tokenizer_image_token(prompt:str, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 20 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] 21 | 22 |     def insert_separator(X, sep): 23 |         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 24 | 25 |     input_ids = [] 26 |     offset = 0 27 |     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 28 |         offset = 1 29 |         input_ids.append(prompt_chunks[0][0]) 30 | 31 |     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 32 |         input_ids.extend(x[offset:]) 33 | 34 |     if return_tensors is not None: 35 |         if return_tensors == 'pt': 36 |             return torch.tensor(input_ids, dtype=torch.long) 37 |         raise ValueError(f'Unsupported tensor type: {return_tensors}') 38 |     return input_ids 39 | 40 | 41 | def get_model_name_from_path(model_path): 42 |     model_path = model_path.strip("/") 43 |     model_paths = model_path.split("/") 44 |     if model_paths[-1].startswith('checkpoint-'): 45 |         return model_paths[-2] + "_" + model_paths[-1] 46 |     else: 47 |         return model_paths[-1] 48 | 49 | 50 | 51 | 52 | class KeywordsStoppingCriteria(StoppingCriteria): 53 |     def __init__(self, keywords, tokenizer, input_ids): 54 |         self.keywords = keywords 55 |         self.keyword_ids = [] 56 |         for keyword in keywords: 57 |             cur_keyword_ids = tokenizer(keyword).input_ids 58 |             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 59 |                 cur_keyword_ids = cur_keyword_ids[1:] 60 |             self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 61 |         self.tokenizer = tokenizer 62 |         self.start_len = input_ids.shape[1] 63 | 64 |     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 65 |         bs = output_ids.shape[0] 66 |         if bs == 1: 67 |             offset = min(output_ids.shape[1] - self.start_len, 3) 68 |             self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] # cast to device 69 |             for keyword_id in self.keyword_ids: 70 |                 if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 71 |                     return True 72 |             outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 73 |             for keyword in self.keywords: 74 |                 if keyword in outputs: 75 |                     return True 76 |             return False 77 |         else: 78 |             raise NotImplementedError("Only support batch size 1 (yet)") 79 | 80 | class MM_ENCODER_CFG(BaseModel): 81 |     gin_num_layers: int = 5 82 |     gin_hidden_dim: int = 300 83 |     drop_ratio: float = 0.1 84 |     init_checkpoint: str = None 85 |     graph_pooling: str = 'mean' -------------------------------------------------------------------------------- /llava/eval/eval_molcap.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | 4 | import argparse 5 | import json 6 | import os 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | from nltk.translate.bleu_score import corpus_bleu 11 | from nltk.translate.meteor_score import meteor_score 12 | from rouge_score import rouge_scorer 13 | import torch 14 | from transformers import BertTokenizerFast 15 | 16 | 17 | def test_molcap_from_file(file, args): 18 |     tokenizer = BertTokenizerFast.from_pretrained(args.text2mol_bert_path) 19 |     output_tokens = [] 20 |     gt_tokens = [] 21 |     meteor_scores = [] 22 |     rouge_scores = [] 23 |     scorer = 
rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL']) 24 | 25 | with open(file, "r") as f: 26 | for i,log in tqdm(enumerate(json.load(f))): 27 | cid,pred,gt = log['cid'],log['text'],log['gt'] 28 | output_tokens.append(tokenizer.tokenize(pred, truncation=True, max_length=512, padding='max_length')) 29 | output_tokens[i] = list(filter(('[PAD]').__ne__, output_tokens[i])) 30 | output_tokens[i] = list(filter(('[CLS]').__ne__, output_tokens[i])) 31 | output_tokens[i] = list(filter(('[SEP]').__ne__, output_tokens[i])) 32 | 33 | gt_tokens.append(tokenizer.tokenize(gt, truncation=True, max_length=512, padding='max_length')) 34 | gt_tokens[i] = list(filter(('[PAD]').__ne__, gt_tokens[i])) 35 | gt_tokens[i] = list(filter(('[CLS]').__ne__, gt_tokens[i])) 36 | gt_tokens[i] = [list(filter(('[SEP]').__ne__, gt_tokens[i]))] 37 | 38 | meteor_scores.append(meteor_score(gt_tokens[i], output_tokens[i])) 39 | rouge_scores.append(scorer.score(gt, pred)) 40 | bleu2 = corpus_bleu(gt_tokens, output_tokens, weights=(0.5, 0.5)) 41 | bleu4 = corpus_bleu(gt_tokens, output_tokens, weights=(0.25, 0.25, 0.25, 0.25)) 42 | 43 | # extract top-10 meteor scores 44 | meteor_scores = np.array(meteor_scores) 45 | Start,K = 500,100 46 | idxes = np.argsort(meteor_scores)[::-1][Start:Start+K] 47 | cids = [log['cid'] for i,log in enumerate(json.load(open(file, "r"))) if i in idxes] 48 | cids.sort(key=lambda x: int(x)) 49 | 50 | return { 51 | "BLEU-2": bleu2, 52 | "BLEU-4": bleu4, 53 | "Meteor": np.mean(meteor_scores), 54 | "ROUGE-1": np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]), 55 | "ROUGE-2": np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]), 56 | "ROUGE-L": np.mean([rs['rougeL'].fmeasure for rs in rouge_scores]), 57 | } 58 | 59 | 60 | def add_arguments(parser): 61 | parser.add_argument("--molcap_result_file", type=str, required=True) 62 | parser.add_argument("--text2mol_bert_path", type=str, default="checkpoints/scibert_scivocab_uncased") 63 | 64 | if __name__ == "__main__": 65 | logging.basicConfig( 66 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 67 | datefmt="%m/%d/%Y %H:%M:%S", 68 | level=logging.INFO, 69 | ) 70 | 71 | parser = argparse.ArgumentParser() 72 | add_arguments(parser) 73 | args = parser.parse_args() 74 | result = test_molcap_from_file(args.molcap_result_file, args) 75 | print(result) 76 | 77 | """ 78 | python -m llava.eval.eval_molcap \ 79 | --molcap_result_file eval_result/chebi20-molcap-lora-10ep.json \ 80 | --text2mol_bert_path checkpoints/scibert_scivocab_uncased 81 | """ 82 | -------------------------------------------------------------------------------- /llava/datasets/MoleculeNet_classification_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | import pickle 5 | from typing import Dict, Optional, Sequence, List 6 | import torch 7 | from torch.utils.data import Dataset 8 | import transformers 9 | import selfies 10 | from .preprocess import preprocess, preprocess_multimodal 11 | 12 | def smiles2selfies(smiles_str): 13 | try: 14 | selfies_str = selfies.encoder(smiles_str) 15 | except: 16 | selfies_str = None 17 | return selfies_str 18 | 19 | class MoleculeNetSupervisedGraphDataset(Dataset): 20 | add_selfies = False 21 | def __init__(self, 22 | data_path: str, 23 | tokenizer: transformers.PreTrainedTokenizer, 24 | data_args, 25 | ): 26 | super(MoleculeNetSupervisedGraphDataset, self).__init__() 27 | self.dataspace = data_path 28 | self.tokenizer = tokenizer 29 | 
self.list_data_dict = self._load_pickle() 30 |         self.data_args = data_args 31 |         if self.add_selfies: 32 |             print("WARNING: Add SELFIES to the instruction") 33 | 34 |     def _load_pickle(self): 35 |         # load the "bace", "bbbp" and "hiv" datasets 36 |         split = "random" # "" for scaffold 37 |         list_data_dict = [] 38 |         for dataset in ["bace", "bbbp", "hiv"]: 39 |             with open(os.path.join(self.dataspace, dataset, "processed", f"instruct-{split}-train.pkl"), "rb") as f: 40 |                 list_data_dict += pickle.load(f) 41 |         return list_data_dict 42 | 43 |     def __len__(self): 44 |         return len(self.list_data_dict) 45 | 46 |     def __getitem__(self, i) -> Dict[str, torch.Tensor]: 47 |         raw = self.list_data_dict[i] 48 |         instruction = raw['instruction'] 49 |         if self.add_selfies: 50 |             selfies_str = smiles2selfies(raw['SMILES']) 51 |             instruction += f" The compound SELFIES sequence is: {selfies_str}" 52 |         if random.random() < 0.5: 53 |             instruction = "<image>\n" + instruction 54 |         else: 55 |             instruction = instruction + "\n<image>" 56 |         graph, target = raw['graph'], str(raw['label']) 57 |         sources = dict( 58 |             conversations=[ 59 |                 {"from": "human", "value": instruction}, 60 |                 {"from": "gpt", "value": target} 61 |             ] 62 |         ) 63 | 64 |         if isinstance(i, int): 65 |             sources = [sources] 66 |         assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 67 | 68 |         if graph is not None: 69 |             sources = preprocess_multimodal( 70 |                 copy.deepcopy([e["conversations"] for e in sources]), 71 |                 self.data_args) 72 |         else: 73 |             sources = copy.deepcopy([e["conversations"] for e in sources]) 74 | 75 |         data_dict = preprocess( 76 |             sources, 77 |             self.tokenizer, 78 |             has_image=(graph is not None)) 79 |         if isinstance(i, int): 80 |             data_dict = dict(input_ids=data_dict["input_ids"][0], 81 |                              labels=data_dict["labels"][0]) 82 | 83 |         # graph exists in the data 84 |         if graph is not None: 85 |             data_dict['graph'] = graph 86 |         elif self.data_args.is_multimodal: 87 |             raise ValueError("Graph does not exist in the data, but the model is multimodal") 88 |         return data_dict -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container 
object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /llava/datasets/forward_pred_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import copy 5 | import pickle 6 | from typing import Dict, Optional, Sequence, List 7 | import selfies 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from .preprocess import preprocess, preprocess_multimodal 12 | from .smiles2graph import smiles2graph 13 | 14 | 15 | class ForwardPredSupervisedGraphDataset(Dataset): 16 | """We use MolInstruction https://huggingface.co/datasets/zjunlp/Mol-Instructions/viewer/Molecule-oriented%20Instructions/forward_reaction_prediction (124K) """ 17 | add_selfies = True 18 | def __init__(self, 19 | data_path: str, 20 | tokenizer: transformers.PreTrainedTokenizer, 21 | data_args, 22 | ): 23 | super(ForwardPredSupervisedGraphDataset, self).__init__() 24 | with open(data_path, "rb") as f: 25 | list_data_dict = json.load(f) 26 | 27 | self.tokenizer = tokenizer 28 | self.list_data_dict = list_data_dict 29 | self.data_args = data_args 30 | 31 | def selfies2smiles(self, selfies_str): 32 | 
try: 33 |             smiles_str = selfies.decoder(selfies_str) 34 |         except: 35 |             smiles_str = None 36 |         return smiles_str 37 | 38 |     def __len__(self): 39 |         return len(self.list_data_dict) 40 | 41 |     def __getitem__(self, i) -> Dict[str, torch.Tensor]: 42 |         raw = self.list_data_dict[i] 43 |         instruction = raw['instruction'] 44 | 45 |         inputs, output_selfies = raw['input'].split('.'), raw['output'] 46 |         # input is one or more reactant SELFIES joined with '.' 47 |         # convert first input selfies to smiles for building graph 48 |         reactant_smiles = self.selfies2smiles(inputs[0]) 49 |         graph=smiles2graph(reactant_smiles) 50 |         if self.add_selfies: 51 |             instruction += " " + raw['input'] 52 |         else: 53 |             # insert the remaining reactants into the instruction 54 |             if len(inputs) > 1: 55 |                 instruction += f" The other joint reactants are: {','.join(inputs[1:])}" 56 | 57 |         if random.random() < 0.5: 58 |             instruction = "<image>\n" + instruction 59 |         else: 60 |             instruction = instruction + "\n<image>" 61 |         sources = dict( 62 |             conversations=[ 63 |                 {"from": "human", "value": instruction}, 64 |                 {"from": "gpt", "value": output_selfies} 65 |             ] 66 |         ) 67 | 68 |         if isinstance(i, int): 69 |             sources = [sources] 70 |         assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 71 | 72 |         if graph is not None: 73 |             sources = preprocess_multimodal( 74 |                 copy.deepcopy([e["conversations"] for e in sources]), 75 |                 self.data_args) 76 |         else: 77 |             sources = copy.deepcopy([e["conversations"] for e in sources]) 78 | 79 |         data_dict = preprocess( 80 |             sources, 81 |             self.tokenizer, 82 |             has_image=(graph is not None)) 83 |         if isinstance(i, int): 84 |             data_dict = dict(input_ids=data_dict["input_ids"][0], 85 |                              labels=data_dict["labels"][0]) 86 | 87 |         # graph exists in the data 88 |         if graph is not None: 89 |             data_dict['graph'] = graph 90 |         elif self.data_args.is_multimodal: 91 |             raise ValueError("Graph does not exist in the data, but the model is multimodal") 92 |         return data_dict -------------------------------------------------------------------------------- /llava/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .lazy_supervised_dataset import LazySupervisedDataset, LazySupervisedGraphDataset 2 | from .reagent_pred_dataset import ReagentPredSupervisedGraphDataset 3 | from .forward_pred_dataset import ForwardPredSupervisedGraphDataset 4 | from .retrosynthesis_dataset import RetrosynthesisSupervisedGraphDataset 5 | from .property_pred_dataset import PropertyPredSupervisedGraphDataset 6 | from .collators import DataCollatorForSupervisedDataset, GraphDataCollatorForSupervisedDataset 7 | from .MoleculeNet_classification_dataset import MoleculeNetSupervisedGraphDataset 8 | from torch.utils.data import ConcatDataset 9 | 10 | 11 | def build_dataset(tokenizer, data_args): 12 |     data_type = data_args.data_type 13 |     if data_type == "supervised": 14 |         dataset = LazySupervisedGraphDataset( 15 |             data_path=data_args.data_path, 16 |             tokenizer=tokenizer, 17 |             data_args=data_args, 18 |         ) 19 |     elif data_type == "reagent_pred": 20 |         dataset = ReagentPredSupervisedGraphDataset( 21 |             data_path=data_args.data_path, 22 |             tokenizer=tokenizer, 23 |             data_args=data_args, 24 |         ) 25 |     elif data_type == "forward_pred": 26 |         dataset = ForwardPredSupervisedGraphDataset( 27 |             data_path=data_args.data_path, 28 |             tokenizer=tokenizer, 29 |             data_args=data_args, 30 |         ) 31 |     elif data_type == "retrosynthesis": 32 |         dataset = RetrosynthesisSupervisedGraphDataset( 33 |             data_path=data_args.data_path, 34 |             tokenizer=tokenizer, 35 | 
data_args=data_args, 36 | ) 37 | elif data_type == "property_pred": 38 | dataset = PropertyPredSupervisedGraphDataset( 39 | data_path=data_args.data_path, 40 | tokenizer=tokenizer, 41 | data_args=data_args, 42 | ) 43 | elif data_type == "all": 44 | # combine molcap, reagent_pred, forward_pred, retrosynthesis, property_pred 45 | # hard code for data path 46 | molcap_data = LazySupervisedGraphDataset( 47 | data_path="/cto_labs/AIDD/DATA/MolT5/ChEBI-20_data/train.pkl", 48 | tokenizer=tokenizer, 49 | data_args=data_args, 50 | ) 51 | reagent_pred_data = ReagentPredSupervisedGraphDataset( 52 | data_path="/cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/reagent_prediction_train.json", 53 | tokenizer=tokenizer, 54 | data_args=data_args, 55 | ) 56 | forward_pred_data = ForwardPredSupervisedGraphDataset( 57 | data_path="/cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/forward_reaction_prediction_train.json", 58 | tokenizer=tokenizer, 59 | data_args=data_args, 60 | ) 61 | retrosynthesis_data = RetrosynthesisSupervisedGraphDataset( 62 | data_path="/cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/retrosynthesis_train.json", 63 | tokenizer=tokenizer, 64 | data_args=data_args, 65 | ) 66 | property_pred_data = PropertyPredSupervisedGraphDataset( 67 | data_path="/cto_labs/AIDD/DATA/Mol-Instructions/Molecule-oriented_Instructions/property_prediction_train.json", 68 | tokenizer=tokenizer, 69 | data_args=data_args, 70 | ) 71 | dataset = ConcatDataset([molcap_data, reagent_pred_data, forward_pred_data, retrosynthesis_data, property_pred_data]) 72 | elif data_type == "MoleculeNet": 73 | dataset = MoleculeNetSupervisedGraphDataset( 74 | data_path=data_args.data_path, 75 | tokenizer=tokenizer, 76 | data_args=data_args, 77 | ) 78 | else: 79 | raise NotImplementedError(f"Unknown data type: {data_type}") 80 | return dataset -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | checkpoints/ 163 | eval_result/ 164 | playground/ -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if 'llama-2' in model_name.lower(): 40 | conv_mode = "llava_llama_2" 41 | elif "v1" in model_name.lower(): 42 | conv_mode = "llava_v1" 43 | elif "mpt" in model_name.lower(): 44 | conv_mode = "mpt" 45 | else: 46 | conv_mode = "llava_v0" 47 | 48 | if args.conv_mode is not None and conv_mode != args.conv_mode: 49 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 50 | else: 51 | args.conv_mode = conv_mode 52 | 53 | conv = conv_templates[args.conv_mode].copy() 54 | conv.append_message(conv.roles[0], qs) 55 | conv.append_message(conv.roles[1], None) 56 | prompt = conv.get_prompt() 57 | 58 | image = load_image(args.image_file) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 60 | 61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 62 | 63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 64 | keywords = [stop_str] 65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor, 71 | do_sample=True, 72 | temperature=0.2, 73 | max_new_tokens=1024, 74 | use_cache=True, 75 | stopping_criteria=[stopping_criteria]) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | print(outputs) 87 | 88 | if __name__ == "__main__": 89 | 
parser = argparse.ArgumentParser() 90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 91 | parser.add_argument("--model-base", type=str, default=None) 92 | parser.add_argument("--image-file", type=str, required=True) 93 | parser.add_argument("--query", type=str, required=True) 94 | parser.add_argument("--conv-mode", type=str, default=None) 95 | args = parser.parse_args() 96 | 97 | eval_model(args) 98 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billion parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the specified device while initializing.
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Linear(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /Evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | ## Property Prediction 4 | ### Classification task 5 | ```Shell 6 | # Sampling (pass --add-selfies to use InstructMol-GS for inference) 7 | TASK="MoleculeNet" 8 | GRAPH_TOWER=moleculestm 9 | DATASET=bbbp # [bace, hiv] 10 | EPOCH=20 11 | python -m llava.eval.molecule_metrics.MoleculeNet_classification \ 12 | --dataspace /cto_labs/AIDD/DATA/MoleculeNet \ 13 | --dataset $DATASET \ 14 | --model-path checkpoints/Graph-LLaVA/$TASK-llava-$GRAPH_TOWER-vicuna-v1-3-7b-finetune_lora \ 15 | --graph-checkpoint-path checkpoints/MoleculeSTM/molecule_model.pth \ 16 | --model-base checkpoints/vicuna-v1-3-7b \ 17 | --batch_size 1 \ 18 | --add-selfies \ 19 | --debug 20 | # Evaluation 21 | python -m llava.eval.molecule_metrics.property_metrics \ 22 | --eval_result_file eval_result/$GRAPH_TOWER-$TASK-${EPOCH}ep.jsonl 23 | ``` 24 | 25 | ### Regression task 26 | 27 | Please download the regression test set from [Huggingface Mol-Instructions Dataset](https://huggingface.co/datasets/zjunlp/Mol-Instructions/blob/main/data/Molecule-oriented_Instructions.zip). 28 | ```Shell 29 | # Sampling (pass --add-selfies to use InstructMol-GS for inference) 30 | TASK=property_pred 31 | GRAPH_TOWER=moleculestm 32 | EPOCH=5 33 | python -m llava.eval.molecule_metrics.generate_sample \ 34 | --task $TASK \ 35 | --model-path LORA_MODEL_PATH \ 36 | --in-file PATH_TO_PROPERTY_PREDICTION_TEST \ 37 | --answers-file eval_result/$GRAPH_TOWER-$TASK-${EPOCH}ep.jsonl \ 38 | --graph-checkpoint-path checkpoints/$GRAPH_TOWER/molecule_model.pth \ 39 |
--model-base checkpoints/vicuna-v1-3-7b \ 40 | --batch_size 1 --temperature 0.2 --top_p 1.0 \ 41 | --add-selfies \ 42 | --debug 43 | # Evaluation 44 | python -m llava.eval.molecule_metrics.property_metrics \ 45 | --eval_result_file eval_result/$GRAPH_TOWER-$TASK-${EPOCH}ep.jsonl 46 | ``` 47 | 48 | ## Molecule Description Generation 49 | We use the [ChEBI-20 test dataset](assets/chebi-20_data/test.txt) for evaluation. 50 | ```Shell 51 | # Sampling (pass --add-selfies to use InstructMol-GS for inference) 52 | GRAPH_TOWER=moleculestm 53 | EPOCH=20 54 | OUT_FILE=eval_result/$GRAPH_TOWER-chebi20-molcap-lora-${EPOCH}ep.jsonl 55 | python -m llava.eval.model_molcap \ 56 | --model-path LORA_MODEL_PATH \ 57 | --in-file assets/chebi-20_data/test.txt \ 58 | --answers-file $OUT_FILE \ 59 | --graph-checkpoint-path $INIT_CHECKPOINT_GNN \ 60 | --model-base checkpoints/vicuna-v1-3-7b \ 61 | --batch_size 1 \ 62 | --add-selfies \ 63 | --debug 64 | # Evaluation 65 | python -m llava.eval.eval_molcap --molcap_result_file $OUT_FILE \ 66 | --text2mol_bert_path checkpoints/scibert_scivocab_uncased 67 | ``` 68 | 69 | 70 | ## Chemical Reaction Analysis 71 | We take **Forward Reaction Prediction** as an example. For the **Retrosynthesis Prediction** and **Reagent Prediction** tasks, just change `$TASK` to `retrosynthesis` and `reagent_pred`, respectively. 72 | 73 | 74 | Please download the test set from [Huggingface Mol-Instructions Dataset](https://huggingface.co/datasets/zjunlp/Mol-Instructions/blob/main/data/Molecule-oriented_Instructions.zip). 75 | ```Shell 76 | # Sampling (pass --add-selfies to use InstructMol-GS for inference) 77 | TASK=forward_pred 78 | GRAPH_TOWER=moleculestm 79 | EPOCH=5 80 | python -m llava.eval.molecule_metrics.generate_sample \ 81 | --task $TASK \ 82 | --model-path LORA_MODEL_PATH \ 83 | --in-file PATH_TO_FORWARD_REACTION_PREDICTION_TEST \ 84 | --answers-file eval_result/$GRAPH_TOWER-$TASK-${EPOCH}ep.jsonl \ 85 | --graph-checkpoint-path checkpoints/$GRAPH_TOWER/molecule_model.pth \ 86 | --model-base checkpoints/vicuna-v1-3-7b \ 87 | --batch_size 1 --temperature 0.2 --top_p 1.0 \ 88 | --add-selfies \ 89 | --debug 90 | # Evaluation 91 | ## Calculate the 'BLEU', 'exact match score', 'Levenshtein score' and 'validity' 92 | python -m llava.eval.molecule_metrics.mol_translation_selfies \ 93 | --input_file=eval_result/${GRAPH_TOWER}-${TASK}-${EPOCH}ep.jsonl 94 | ## Calculate the 'MACCS', 'RDK' and 'Morgan' similarity 95 | python -m llava.eval.molecule_metrics.fingerprint_metrics \ 96 | --input_file=eval_result/${GRAPH_TOWER}-${TASK}-${EPOCH}ep.jsonl 97 | ``` -------------------------------------------------------------------------------- /llava/datasets/reagent_pred_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import copy 5 | import pickle 6 | from typing import Dict, Optional, Sequence, List 7 | import selfies 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from .preprocess import preprocess, preprocess_multimodal 12 | from .smiles2graph import smiles2graph 13 | 14 | 15 | def construct_instruct_question(product:str): 16 | """ 17 | Construct instruct question for each graph 18 | """ 19 | question_pools = [ 20 | 'Can you suggest some possible reagents that could have been used in the following chemical reaction?', 21 | 'Give some possible reagents that could have been used in the following chemical reaction.', 22
| 'Please propose potential reagents that might have been utilized in the provided chemical reaction.', 23 | 'Please provide possible reagents based on the following chemical reaction.', 24 | ] 25 | question = random.choice(question_pools) 26 | question += f"\nThe product is {product}" 27 | return question 28 | 29 | 30 | class ReagentPredSupervisedGraphDataset(Dataset): 31 | """We use MolInstruction https://huggingface.co/datasets/zjunlp/Mol-Instructions/viewer/Molecule-oriented%20Instructions/reagent_prediction (128K) """ 32 | add_selfies = True 33 | def __init__(self, 34 | data_path: str, 35 | tokenizer: transformers.PreTrainedTokenizer, 36 | data_args, 37 | ): 38 | super(ReagentPredSupervisedGraphDataset, self).__init__() 39 | with open(data_path, "rb") as f: 40 | list_data_dict = json.load(f) 41 | 42 | self.tokenizer = tokenizer 43 | self.list_data_dict = list_data_dict 44 | self.data_args = data_args 45 | 46 | def selfies2smiles(self, selfies_str): 47 | try: 48 | smiles_str = selfies.decoder(selfies_str) 49 | except: 50 | smiles_str = None 51 | return smiles_str 52 | 53 | def __len__(self): 54 | return len(self.list_data_dict) 55 | 56 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 57 | raw = self.list_data_dict[i] 58 | input, output_selfies = raw['input'], raw['output'] 59 | # input: "reactant>>product" 60 | reactant, product = input.split(">>") 61 | # convert input selfies to smiles for building graph 62 | reactant_smiles = self.selfies2smiles(reactant) 63 | if not self.add_selfies: 64 | # insert product to the instruction end 65 | instruction = construct_instruct_question(product) 66 | else: 67 | instruction = raw['instruction'] + f" The reaction is {input}" 68 | if random.random() < 0.5: 69 | instruction = "<image>" + "\n" + instruction 70 | else: 71 | instruction = instruction + "\n" + "<image>" 72 | graph = smiles2graph(reactant_smiles) 73 | sources = dict( 74 | conversations=[ 75 | {"from": "human", "value": instruction}, 76 | {"from": "gpt", "value": output_selfies} 77 | ] 78 | ) 79 | 80 | if isinstance(i, int): 81 | sources = [sources] 82 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 83 | 84 | if graph is not None: 85 | sources = preprocess_multimodal( 86 | copy.deepcopy([e["conversations"] for e in sources]), 87 | self.data_args) 88 | else: 89 | sources = copy.deepcopy([e["conversations"] for e in sources]) 90 | 91 | data_dict = preprocess( 92 | sources, 93 | self.tokenizer, 94 | has_image=(graph is not None)) 95 | if isinstance(i, int): 96 | data_dict = dict(input_ids=data_dict["input_ids"][0], 97 | labels=data_dict["labels"][0]) 98 | 99 | # graph exists in the data 100 | if graph is not None: 101 | data_dict['graph'] = graph 102 | elif self.data_args.is_multimodal: 103 | raise ValueError("Graph does not exist in the data, but the model is multimodal") 104 | return data_dict -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 14 | except ImportError: 15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 16 | from flash_attn.bert_padding import unpad_input, pad_input 17 | 18 | 19 | def forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 25 | output_attentions: bool = False, 26 | use_cache: bool = False, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) 49 | # [bsz, q_len, nh, hd] 50 | # [bsz, nh, q_len, hd] 51 | 52 | kv_seq_len = key_states.shape[-2] 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | query_states, key_states = apply_rotary_pos_emb( 57 | query_states, key_states, cos, sin, position_ids 58 | ) 59 | # [bsz, nh, t, hd] 60 | assert not output_attentions, "output_attentions is not supported" 61 | assert not use_cache, "use_cache is not supported" 62 | 63 | # Flash attention codes from 64 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 65 | 66 | # transform the data into the format required by flash attention 67 | qkv = torch.stack( 68 | [query_states, key_states, value_states], dim=2 69 | ) # [bsz, nh, 3, q_len, hd] 70 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 71 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 72 | # the attention_mask should be the same as the key_padding_mask 73 | key_padding_mask = attention_mask 74 | 
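    # The varlen FlashAttention kernel consumes packed QKV of shape (total_tokens, 3, num_heads, head_dim)
    # together with cumulative sequence lengths (`cu_q_lens`). Without a padding mask every sequence has
    # length `q_len`, so the batch can be flattened directly; with a padding mask, `unpad_input` drops the
    # padded positions first and `pad_input` scatters the attention output back to the padded layout afterwards.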
75 | if key_padding_mask is None: 76 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 77 | max_s = q_len 78 | cu_q_lens = torch.arange( 79 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 80 | ) 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 85 | else: 86 | nheads = qkv.shape[-2] 87 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 88 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 89 | x_unpad = rearrange( 90 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 91 | ) 92 | output_unpad = flash_attn_unpadded_qkvpacked_func( 93 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 94 | ) 95 | output = rearrange( 96 | pad_input( 97 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 98 | ), 99 | "b s (h d) -> b s h d", 100 | h=nheads, 101 | ) 102 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 103 | 104 | 105 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 106 | # requires the attention mask to be the same as the key_padding_mask 107 | def _prepare_decoder_attention_mask( 108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 109 | ): 110 | # [bsz, seq_len] 111 | return attention_mask 112 | 113 | 114 | def replace_llama_attn_with_flash_attn(): 115 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 116 | if cuda_major < 8: 117 | logging.warning( 118 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 119 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 120 | ) 121 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 122 | _prepare_decoder_attention_mask 123 | ) 124 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 125 | -------------------------------------------------------------------------------- /llava/eval/molecule_metrics/fingerprint_metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code from https://github.com/blender-nlp/MolT5 3 | 4 | ```bibtex 5 | @article{edwards2022translation, 6 | title={Translation between Molecules and Natural Language}, 7 | author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng}, 8 | journal={arXiv preprint arXiv:2204.11817}, 9 | year={2022} 10 | } 11 | ``` 12 | ''' 13 | 14 | import argparse 15 | import csv 16 | import os.path as osp 17 | import numpy as np 18 | import json 19 | from rdkit import Chem 20 | from rdkit.Chem import MACCSkeys 21 | from rdkit import DataStructs 22 | from rdkit.Chem import AllChem 23 | from rdkit import RDLogger 24 | RDLogger.DisableLog('rdApp.*') 25 | import selfies as sf 26 | 27 | 28 | def sf_encode(selfies): 29 | try: 30 | smiles = sf.decoder(selfies) 31 | return smiles 32 | except Exception: 33 | return None 34 | 35 | def convert_to_canonical_smiles(smiles): 36 | if smiles is None: 37 | return None 38 | molecule = Chem.MolFromSmiles(smiles) 39 | if molecule is not None: 40 | canonical_smiles = Chem.MolToSmiles(molecule, isomericSmiles=False, canonical=True) 41 | return canonical_smiles 42 | else: 43 | return None 44 | 45 | def build_evaluate_tuple(result:dict): 46 | # pred 47 | # func = lambda x: x.rsplit(']', 1)[0] + ']' if isinstance(x, str) else x 48 | func 
= lambda x: x 49 | result["pred_smi"] = convert_to_canonical_smiles(func(sf_encode(result["pred_self"]))) 50 | # gt 51 | result["gt_smi"] = convert_to_canonical_smiles(sf_encode(result["gt_self"])) 52 | return result 53 | 54 | 55 | def evaluate(input_file, morgan_r, verbose=False): 56 | outputs = [] 57 | bad_mols = 0 58 | 59 | with open(osp.join(input_file)) as f: 60 | results = json.load(f) 61 | for i, result in enumerate(results): 62 | result = build_evaluate_tuple(result) 63 | try: 64 | gt_smi = result['gt_smi'] 65 | ot_smi = result['pred_smi'] 66 | 67 | gt_m = Chem.MolFromSmiles(gt_smi) 68 | ot_m = Chem.MolFromSmiles(ot_smi) 69 | 70 | if ot_m == None: raise ValueError('Bad SMILES') 71 | outputs.append((result['prompt'], gt_m, ot_m)) 72 | except: 73 | bad_mols += 1 74 | validity_score = len(outputs)/(len(outputs)+bad_mols) 75 | if verbose: 76 | print('validity:', validity_score) 77 | 78 | 79 | MACCS_sims = [] 80 | morgan_sims = [] 81 | RDK_sims = [] 82 | 83 | enum_list = outputs 84 | 85 | for i, (desc, gt_m, ot_m) in enumerate(enum_list): 86 | 87 | if i % 100 == 0: 88 | if verbose: print(i, 'processed.') 89 | 90 | MACCS_sims.append(DataStructs.FingerprintSimilarity(MACCSkeys.GenMACCSKeys(gt_m), MACCSkeys.GenMACCSKeys(ot_m), metric=DataStructs.TanimotoSimilarity)) 91 | RDK_sims.append(DataStructs.FingerprintSimilarity(Chem.RDKFingerprint(gt_m), Chem.RDKFingerprint(ot_m), metric=DataStructs.TanimotoSimilarity)) 92 | morgan_sims.append(DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprint(gt_m,morgan_r), AllChem.GetMorganFingerprint(ot_m, morgan_r))) 93 | 94 | maccs_sims_score = np.mean(MACCS_sims) 95 | rdk_sims_score = np.mean(RDK_sims) 96 | morgan_sims_score = np.mean(morgan_sims) 97 | if verbose: 98 | print('Average MACCS Similarity:', maccs_sims_score) 99 | print('Average RDK Similarity:', rdk_sims_score) 100 | print('Average Morgan Similarity:', morgan_sims_score) 101 | return validity_score, maccs_sims_score, rdk_sims_score, morgan_sims_score 102 | 103 | 104 | ## TEST ## 105 | def test_out_selfies_validity(args): 106 | with open(osp.join(args.input_file)) as f: 107 | results = json.load(f) 108 | bad_selfies = 0 109 | bad_mols = 0 110 | bad_gt_selfies = 0 111 | for i, result in enumerate(results): 112 | pred = result['pred_self'] 113 | smi = sf_encode(pred) 114 | if not smi: 115 | bad_selfies += 1 116 | else: 117 | try: 118 | m = Chem.MolFromSmiles(smi) 119 | if m is None: 120 | bad_mols += 1 121 | except: 122 | bad_mols += 1 123 | gt = result['gt_self'] 124 | gt_smi = sf_encode(gt) 125 | if not gt_smi: 126 | bad_gt_selfies += 1 127 | print('Pred: bad selfies:', bad_selfies) 128 | print('Pred: bad mols:', bad_mols) 129 | print('GT: bad selfies:', bad_gt_selfies) 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('--input_file', type=str, default='caption2smiles_example.json', help='path where test generations are saved') 134 | parser.add_argument('--morgan_r', type=int, default=2, help='morgan fingerprint radius') 135 | args = parser.parse_args() 136 | # test_out_selfies_validity(args) 137 | evaluate(args.input_file, args.morgan_r, True) 138 | 139 | 140 | """ 141 | # retrosynthesis 142 | python -m llava.eval.molecule_metrics.fingerprint_metrics \ 143 | --input_file=eval_result/moleculestm-retrosynthesis-5ep.jsonl 144 | 145 | # reagent_pred 146 | python -m llava.eval.molecule_metrics.fingerprint_metrics \ 147 | --input_file=eval_result/moleculestm-reagent_pred-5ep.jsonl 148 | 149 | # forward_pred 150 | python -m 
llava.eval.molecule_metrics.fingerprint_metrics \ 151 | --input_file=eval_result/moleculestm-forward_pred-5ep.jsonl 152 | """ -------------------------------------------------------------------------------- /llava/model/language_model/llava_graph_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_graph_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaGraphLlamaConfig(LlamaConfig): 31 | model_type = "llava_graph" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaGraphLlamaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaGraphLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaGraphLlamaConfig 43 | 44 | def __init__(self, config): 45 | super(LlavaGraphLlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | graphs: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, graphs) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | 
output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 | attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "graphs": kwargs.get("graphs", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | AutoConfig.register("llava_graph", LlavaGraphLlamaConfig) 140 | AutoModelForCausalLM.register(LlavaGraphLlamaConfig, LlavaGraphLlamaForCausalLM) 141 | -------------------------------------------------------------------------------- /llava/datasets/lazy_supervised_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import pickle 5 | from PIL import Image 6 | from typing import Dict, Optional, Sequence, List 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | import transformers 11 | from .preprocess import preprocess, preprocess_multimodal 12 | 13 | class LazySupervisedDataset(Dataset): 14 | """Dataset for supervised fine-tuning.""" 15 | 16 | def __init__(self, data_path: str, 17 | tokenizer: transformers.PreTrainedTokenizer, 18 | data_args): 19 | super(LazySupervisedDataset, self).__init__() 20 | list_data_dict = json.load(open(data_path, "r")) 21 | 22 | self.tokenizer = tokenizer 23 | self.list_data_dict = list_data_dict 24 | self.data_args = data_args 25 | 26 | def __len__(self): 27 | return len(self.list_data_dict) 28 | 29 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 30 | sources = self.list_data_dict[i] 31 | if isinstance(i, int): 32 | sources = [sources] 33 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 34 | if 'image' in sources[0]: 35 | image_file = self.list_data_dict[i]['image'] 36 | image_folder = self.data_args.image_folder 37 | processor = self.data_args.image_processor 38 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 39 | if self.data_args.image_aspect_ratio == 'pad': 40 | def expand2square(pil_img, background_color): 41 | 
width, height = pil_img.size 42 | if width == height: 43 | return pil_img 44 | elif width > height: 45 | result = Image.new(pil_img.mode, (width, width), background_color) 46 | result.paste(pil_img, (0, (width - height) // 2)) 47 | return result 48 | else: 49 | result = Image.new(pil_img.mode, (height, height), background_color) 50 | result.paste(pil_img, ((height - width) // 2, 0)) 51 | return result 52 | image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) 53 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 54 | else: 55 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 56 | sources = preprocess_multimodal( 57 | copy.deepcopy([e["conversations"] for e in sources]), 58 | self.data_args) 59 | else: 60 | sources = copy.deepcopy([e["conversations"] for e in sources]) 61 | data_dict = preprocess( 62 | sources, 63 | self.tokenizer, 64 | has_image=('image' in self.list_data_dict[i])) 65 | if isinstance(i, int): 66 | data_dict = dict(input_ids=data_dict["input_ids"][0], 67 | labels=data_dict["labels"][0]) 68 | 69 | # image exists in the data 70 | if 'image' in self.list_data_dict[i]: 71 | data_dict['image'] = image 72 | elif self.data_args.is_multimodal: 73 | # image does not exist in the data, but the model is multimodal 74 | crop_size = self.data_args.image_processor.crop_size 75 | data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 76 | return data_dict 77 | 78 | 79 | class LazySupervisedGraphDataset(Dataset): 80 | """Dataset for supervised fine-tuning.""" 81 | 82 | def __init__(self, data_path: str, 83 | tokenizer: transformers.PreTrainedTokenizer, 84 | data_args): 85 | super(LazySupervisedGraphDataset, self).__init__() 86 | with open(data_path, "rb") as f: 87 | list_data_dict = pickle.load(f) 88 | 89 | self.tokenizer = tokenizer 90 | self.list_data_dict = list_data_dict 91 | self.data_args = data_args 92 | 93 | def __len__(self): 94 | return len(self.list_data_dict) 95 | 96 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 97 | sources = self.list_data_dict[i] 98 | if isinstance(i, int): 99 | sources = [sources] 100 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 101 | 102 | if 'graph' in sources[0]: 103 | graph = self.list_data_dict[i]['graph'] 104 | 105 | if 'question' in sources[0] and 'answer' in sources[0]: 106 | for e in sources: 107 | e['conversations'] = [ 108 | {"from": "human", "value": "<image>\n" + e["question"]}, 109 | {"from": "gpt", "value": e["answer"]} 110 | ] 111 | 112 | sources = preprocess_multimodal( 113 | copy.deepcopy([e["conversations"] for e in sources]), 114 | self.data_args) 115 | else: 116 | if 'question' in sources[0] and 'answer' in sources[0]: 117 | for e in sources: 118 | e['conversations'] = [ 119 | {"from": "human", "value": "\n" + e["question"]}, 120 | {"from": "gpt", "value": e["answer"]} 121 | ] 122 | 123 | sources = copy.deepcopy([e["conversations"] for e in sources]) 124 | data_dict = preprocess( 125 | sources, 126 | self.tokenizer, 127 | has_image=('graph' in self.list_data_dict[i])) 128 | if isinstance(i, int): 129 | data_dict = dict(input_ids=data_dict["input_ids"][0], 130 | labels=data_dict["labels"][0]) 131 | 132 | # graph exists in the data 133 | if 'graph' in self.list_data_dict[i]: 134 | data_dict['graph'] = graph 135 | elif self.data_args.is_multimodal: 136 | raise ValueError("Graph does not exist in the data, but the model is multimodal") 137 | return data_dict
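A minimal usage sketch for `LazySupervisedGraphDataset` follows; it is not part of the repository, and the tokenizer checkpoint, pickle path, and the ad-hoc `DataArgs` fields are illustrative assumptions standing in for the training script's own argument dataclasses.

```python
# Hypothetical usage sketch -- paths and DataArgs fields are assumptions, not repo defaults.
from dataclasses import dataclass

import transformers

from llava.datasets.lazy_supervised_dataset import LazySupervisedGraphDataset


@dataclass
class DataArgs:
    is_multimodal: bool = True         # graph samples are required when True
    mm_use_im_start_end: bool = False  # consumed by preprocess_multimodal


tokenizer = transformers.AutoTokenizer.from_pretrained("checkpoints/vicuna-v1-3-7b")
dataset = LazySupervisedGraphDataset(
    data_path="ChEBI-20_data/train.pkl",  # pickle of dicts with 'graph', 'question', 'answer'
    tokenizer=tokenizer,
    data_args=DataArgs(),
)
sample = dataset[0]  # dict with 'input_ids', 'labels' and the molecular 'graph'
print(sample["input_ids"].shape, sample["labels"].shape)
```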
-------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | from .modeling_llama_local_moda import LlamaConfig, LlamaModel, LlamaForCausalLM 28 | 29 | class LlavaConfig(LlamaConfig): 30 | model_type = "llava" 31 | 32 | 33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 34 | config_class = LlavaConfig 35 | 36 | def __init__(self, config: LlamaConfig): 37 | super(LlavaLlamaModel, self).__init__(config) 38 | 39 | 40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaConfig 42 | 43 | def __init__(self, config): 44 | super(LlamaForCausalLM, self).__init__(config) 45 | self.model = LlavaLlamaModel(config) 46 | 47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 48 | 49 | # Initialize weights and apply final processing 50 | self.post_init() 51 | 52 | def get_model(self): 53 | return self.model 54 | 55 | def forward( 56 | self, 57 | input_ids: torch.LongTensor = None, 58 | attention_mask: Optional[torch.Tensor] = None, 59 | past_key_values: Optional[List[torch.FloatTensor]] = None, 60 | inputs_embeds: Optional[torch.FloatTensor] = None, 61 | labels: Optional[torch.LongTensor] = None, 62 | use_cache: Optional[bool] = None, 63 | output_attentions: Optional[bool] = None, 64 | output_hidden_states: Optional[bool] = None, 65 | images: Optional[torch.FloatTensor] = None, 66 | return_dict: Optional[bool] = None, 67 | vflag: Optional[torch.Tensor] = None, 68 | tflag: Optional[torch.Tensor] = None, 69 | ) -> Union[Tuple, CausalLMOutputWithPast]: 70 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 71 | output_hidden_states = ( 72 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 73 | ) 74 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels, vflag_pro, tflag_pro = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | 78 | vflag = vflag if vflag_pro is None else vflag_pro 79 | tflag = tflag if tflag_pro is None else tflag_pro 80 | 81 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 82 | outputs = self.model( 83 | input_ids=input_ids, 84 | attention_mask=attention_mask, 85 | past_key_values=past_key_values, 
86 | inputs_embeds=inputs_embeds, 87 | use_cache=use_cache, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | return_dict=return_dict, 91 | vflag = vflag, 92 | tflag = tflag 93 | ) 94 | 95 | hidden_states = outputs[0] 96 | logits = self.lm_head(hidden_states) 97 | 98 | loss = None 99 | if labels is not None: 100 | # Shift so that tokens < n predict n 101 | shift_logits = logits[..., :-1, :].contiguous() 102 | shift_labels = labels[..., 1:].contiguous() 103 | # Flatten the tokens 104 | loss_fct = CrossEntropyLoss() 105 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 106 | shift_labels = shift_labels.view(-1) 107 | # Enable model/pipeline parallelism 108 | shift_labels = shift_labels.to(shift_logits.device) 109 | loss = loss_fct(shift_logits, shift_labels) 110 | 111 | if not return_dict: 112 | output = (logits,) + outputs[1:] 113 | return (loss,) + output if loss is not None else output 114 | 115 | return CausalLMOutputWithPast( 116 | loss=loss, 117 | logits=logits, 118 | past_key_values=outputs.past_key_values, 119 | hidden_states=outputs.hidden_states, 120 | attentions=outputs.attentions, 121 | ) 122 | 123 | def prepare_inputs_for_generation( 124 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 125 | ): 126 | if past_key_values: 127 | input_ids = input_ids[:, -1:] 128 | 129 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 130 | if inputs_embeds is not None and past_key_values is None: 131 | model_inputs = {"inputs_embeds": inputs_embeds} 132 | else: 133 | model_inputs = {"input_ids": input_ids} 134 | 135 | model_inputs.update( 136 | { 137 | "past_key_values": past_key_values, 138 | "use_cache": kwargs.get("use_cache"), 139 | "attention_mask": attention_mask, 140 | "images": kwargs.get("images", None), 141 | } 142 | ) 143 | return model_inputs 144 | 145 | AutoConfig.register("llava", LlavaConfig) 146 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) -------------------------------------------------------------------------------- /llava/eval/molecule_metrics/mol_translation_selfies.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code from https://github.com/blender-nlp/MolT5 3 | 4 | ```bibtex 5 | @article{edwards2022translation, 6 | title={Translation between Molecules and Natural Language}, 7 | author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng}, 8 | journal={arXiv preprint arXiv:2204.11817}, 9 | year={2022} 10 | } 11 | ``` 12 | ''' 13 | 14 | import pickle 15 | import argparse 16 | import csv 17 | import json 18 | import os.path as osp 19 | import numpy as np 20 | from nltk.translate.bleu_score import corpus_bleu 21 | from Levenshtein import distance as lev 22 | from rdkit import Chem 23 | from rdkit import RDLogger 24 | RDLogger.DisableLog('rdApp.*') 25 | import selfies as sf 26 | 27 | def sf_encode(selfies): 28 | try: 29 | smiles = sf.decoder(selfies) 30 | return smiles 31 | except Exception: 32 | return None 33 | 34 | def convert_to_canonical_smiles(smiles): 35 | if smiles is None: 36 | return None 37 | molecule = Chem.MolFromSmiles(smiles) 38 | if molecule is not None: 39 | canonical_smiles = Chem.MolToSmiles(molecule, isomericSmiles=False, canonical=True) 40 | return canonical_smiles 41 | else: 42 | return None 43 | 44 | def build_evaluate_tuple(result:dict): 45 | # pred 46 | # func = lambda x: x.rsplit(']', 1)[0] + ']' if isinstance(x, str) else x 47 
| func = lambda x: x 48 | result["pred_smi"] = convert_to_canonical_smiles(func(sf_encode(result["pred_self"]))) 49 | # gt 50 | result["gt_smi"] = convert_to_canonical_smiles(sf_encode(result["gt_self"])) 51 | return result 52 | 53 | 54 | def evaluate(input_file, verbose=False): 55 | outputs = [] 56 | 57 | with open(osp.join(input_file)) as f: 58 | results = json.load(f) 59 | for i, result in enumerate(results): 60 | result = build_evaluate_tuple(result) 61 | gt_self = result['gt_self'] 62 | ot_self = result['pred_self'] 63 | gt_smi = result['gt_smi'] 64 | ot_smi = result['pred_smi'] 65 | if ot_smi is None: 66 | continue 67 | outputs.append((result['prompt'], gt_self, ot_self, gt_smi, ot_smi)) 68 | 69 | 70 | bleu_self_scores = [] 71 | bleu_smi_scores = [] 72 | 73 | references_self = [] 74 | hypotheses_self = [] 75 | 76 | references_smi = [] 77 | hypotheses_smi = [] 78 | 79 | for i, (des, gt_self, ot_self, gt_smi, ot_smi) in enumerate(outputs): 80 | 81 | if i % 100 == 0: 82 | if verbose: 83 | print(i, 'processed.') 84 | 85 | gt_self_tokens = [c for c in gt_self] 86 | out_self_tokens = [c for c in ot_self] 87 | 88 | references_self.append([gt_self_tokens]) 89 | hypotheses_self.append(out_self_tokens) 90 | 91 | if ot_smi is None: 92 | continue 93 | 94 | gt_smi_tokens = [c for c in gt_smi] 95 | ot_smi_tokens = [c for c in ot_smi] 96 | 97 | references_smi.append([gt_smi_tokens]) 98 | hypotheses_smi.append(ot_smi_tokens) 99 | 100 | 101 | # BLEU score 102 | bleu_score_self = corpus_bleu(references_self, hypotheses_self) 103 | if verbose: print(f'SELFIES BLEU score', bleu_score_self) 104 | 105 | references_self = [] 106 | hypotheses_self = [] 107 | 108 | references_smi = [] 109 | hypotheses_smi = [] 110 | 111 | levs_self = [] 112 | levs_smi = [] 113 | 114 | num_exact = 0 115 | 116 | bad_mols = 0 117 | 118 | for i, (des, gt_self, ot_self, gt_smi, ot_smi) in enumerate(outputs): 119 | 120 | hypotheses_self.append(ot_self) 121 | references_self.append(gt_self) 122 | 123 | hypotheses_smi.append(ot_smi) 124 | references_smi.append(gt_smi) 125 | 126 | try: 127 | m_out = Chem.MolFromSmiles(ot_smi) 128 | m_gt = Chem.MolFromSmiles(gt_smi) 129 | 130 | if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt): num_exact += 1 131 | #if gt == out: num_exact += 1 #old version that didn't standardize strings 132 | except: 133 | bad_mols += 1 134 | 135 | levs_self.append(lev(ot_self, gt_self)) 136 | levs_smi.append(lev(ot_smi, gt_smi)) 137 | 138 | 139 | # Exact matching score 140 | exact_match_score = num_exact/(i+1) 141 | if verbose: 142 | print('Exact Match:') 143 | print(exact_match_score) 144 | 145 | # Levenshtein score 146 | levenshtein_score_smi = np.mean(levs_smi) 147 | if verbose: 148 | print('SMILES Levenshtein:') 149 | print(levenshtein_score_smi) 150 | 151 | validity_score = 1 - bad_mols/len(outputs) 152 | if verbose: 153 | print('validity:', validity_score) 154 | 155 | 156 | ## TEST ## 157 | def test_out_selfies_validity(args): 158 | with open(osp.join(args.input_file)) as f: 159 | results = json.load(f) 160 | bad_selfies = 0 161 | for i, result in enumerate(results): 162 | pred = result['pred_self'] 163 | if not sf_encode(pred): 164 | print(i, pred, 'bad selfies') 165 | bad_selfies += 1 166 | print('bad selfies:', bad_selfies) 167 | 168 | 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument('--input_file', type=str, default='caption2smiles_example.json', help='path where test generations are saved') 173 | args = parser.parse_args() 174 | # 
test_out_selfies_validity(args) 175 | evaluate(args.input_file, verbose=True) 176 | 177 | 178 | """ 179 | # retrosynthesis 180 | python -m llava.eval.molecule_metrics.mol_translation_selfies \ 181 | --input_file=eval_result/moleculestm-retrosynthesis-5ep.jsonl 182 | 183 | # reagent prediction 184 | python -m llava.eval.molecule_metrics.mol_translation_selfies \ 185 | --input_file=eval_result/moleculestm-reagent_pred-5ep.jsonl 186 | 187 | # forward_pred 188 | python -m llava.eval.molecule_metrics.mol_translation_selfies \ 189 | --input_file=eval_result/moleculestm-forward_pred-5ep.jsonl 190 | """ -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: 
Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/serve/cli_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import 
disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, MM_ENCODER_CFG 9 | from llava.mol_utils import check_smiles_validity 10 | from llava.datasets.smiles2graph import smiles2graph 11 | 12 | from typing import Dict 13 | from transformers import TextStreamer 14 | from torch_geometric.data import Data 15 | 16 | 17 | def _convert_dict_to_Data(data_dict: Dict) -> Data: 18 | return Data( 19 | x=torch.asarray(data_dict['node_feat']), 20 | edge_attr=torch.asarray(data_dict['edge_feat']), 21 | edge_index=torch.asarray(data_dict['edge_index']), 22 | ) 23 | 24 | 25 | def main(args): 26 | # device 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | # Model 30 | disable_torch_init() 31 | model_name = get_model_name_from_path(args.model_path) 32 | # graph encoder config 33 | mm_encoder_cfg = MM_ENCODER_CFG(init_checkpoint=args.graph_checkpoint_path) 34 | mm_encoder_cfg = mm_encoder_cfg.dict() 35 | # load model 36 | tokenizer, model, _, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, mm_encoder_cfg=mm_encoder_cfg) 37 | 38 | if 'llama-2' in model_name.lower(): 39 | conv_mode = "llava_llama_2" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | # Input SMILES 59 | smiles = None 60 | while not smiles or not check_smiles_validity(smiles): 61 | smiles = input("Please enter a valid SMILES: ") 62 | graph = smiles2graph(smiles) 63 | graph_tensor = [_convert_dict_to_Data(graph).to(device)] 64 | 65 | while True: 66 | try: 67 | inp = input(f"{roles[0]}: ") 68 | except EOFError: 69 | inp = "" 70 | if inp.lower() in ["quit", "exit"]: 71 | print("exit...") 72 | break 73 | elif inp == "reset": 74 | conv = conv_templates[args.conv_mode].copy() 75 | print("reset conversation...") 76 | smiles = None 77 | while not smiles or not check_smiles_validity(smiles): 78 | smiles = input("Please enter a valid SMILES: ") 79 | graph = smiles2graph(smiles) 80 | graph_tensor = [_convert_dict_to_Data(graph).to(device)] 81 | continue 82 | 83 | print(f"{roles[1]}: ", end="") 84 | 85 | if graph is not None: 86 | # first message 87 | if model.config.mm_use_im_start_end: 88 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 89 | else: 90 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 91 | conv.append_message(conv.roles[0], inp) 92 | graph = None 93 | else: 94 | # later messages 95 | conv.append_message(conv.roles[0], inp) 96 | conv.append_message(conv.roles[1], None) 97 | prompt = conv.get_prompt() 98 | 99 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 100 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 101 | keywords = [stop_str] 102 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 103 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 104 | 
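        # `prompt` now contains a single DEFAULT_IMAGE_TOKEN placeholder for the molecule;
        # tokenizer_image_token maps it to IMAGE_TOKEN_INDEX so the model can substitute the
        # graph-encoder features at that position during generation.
        # NOTE: temperature and max_new_tokens are hard-coded in the generate() call below;
        # the --temperature / --max-new-tokens CLI flags parsed in __main__ are not used here.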
105 | with torch.inference_mode(): 106 | output_ids = model.generate( 107 | input_ids, 108 | graphs=graph_tensor, 109 | do_sample=True, 110 | temperature=0.2, 111 | max_new_tokens=1024, 112 | streamer=streamer, 113 | use_cache=True, 114 | stopping_criteria=[stopping_criteria]) 115 | 116 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 117 | conv.messages[-1][-1] = outputs 118 | 119 | if args.debug: 120 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 126 | parser.add_argument("--graph-checkpoint-path", type=str, required=True) 127 | parser.add_argument("--model-base", type=str, default=None) 128 | parser.add_argument("--num-gpus", type=int, default=1) 129 | # parser.add_argument("--smiles", type=str, help="SMILES string", default="C([C@H]([C@H]([C@@H]([C@H](CO)O)O)O)O)O") 130 | parser.add_argument("--conv-mode", type=str, default=None) 131 | parser.add_argument("--temperature", type=float, default=0.2) 132 | parser.add_argument("--max-new-tokens", type=int, default=512) 133 | parser.add_argument("--load-8bit", action="store_true") 134 | parser.add_argument("--load-4bit", action="store_true") 135 | parser.add_argument("--debug", action="store_true") 136 | args = parser.parse_args() 137 | main(args) 138 | 139 | 140 | """ 141 | python -m llava.serve.cli_graph \ 142 | --model-path checkpoints/Graph-LLaVA/molcap-llava-moleculestm-vicuna-v1-3-7b-finetune_lora \ 143 | --graph-checkpoint-path checkpoints/MoleculeSTM/molecule_model.pth \ 144 | --model-base checkpoints/vicuna-v1-3-7b \ 145 | """ -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /llava/datasets/smiles2graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | ref from https://github.com/UCSD-AI4H/drugchat/blob/main/dataset/smiles2graph.py 3 | """ 4 | from rdkit import Chem 5 | import numpy as np 6 | import json 7 | import pickle 8 | import os 9 | from tqdm import tqdm 10 | import random 11 | from typing import Dict 12 | from rdkit.Chem.rdchem import BondType, BondDir, ChiralType 13 | import selfies as sf 14 | 15 | 16 | BOND_TYPE = 
{BondType.SINGLE: 0, BondType.DOUBLE: 1, BondType.TRIPLE: 2, BondType.AROMATIC: 3} 17 | BOND_DIR = {BondDir.NONE: 0, BondDir.ENDUPRIGHT: 1, BondDir.ENDDOWNRIGHT: 2} 18 | CHI = {ChiralType.CHI_UNSPECIFIED: 0, ChiralType.CHI_TETRAHEDRAL_CW: 1, ChiralType.CHI_TETRAHEDRAL_CCW: 2, ChiralType.CHI_OTHER: 3} 19 | 20 | def bond_dir(bond): 21 | d = bond.GetBondDir() 22 | return BOND_DIR[d] 23 | 24 | def bond_type(bond): 25 | t = bond.GetBondType() 26 | return BOND_TYPE[t] 27 | 28 | def atom_chiral(atom): 29 | c = atom.GetChiralTag() 30 | return CHI[c] 31 | 32 | def atom_to_feature(atom): 33 | num = atom.GetAtomicNum() - 1 34 | if num == -1: 35 | # atom.GetAtomicNum() is 0, which is the generic wildcard atom *, may be used to symbolize an unknown atom of any element. 36 | # See https://biocyc.org/help.html?object=smiles 37 | num = 118 # normal num is [0, 117], so we use 118 to denote wildcard atom * 38 | return [num, atom_chiral(atom)] 39 | 40 | def bond_to_feature(bond): 41 | return [bond_type(bond), bond_dir(bond)] 42 | 43 | def smiles2graph(smiles_string)->Dict: 44 | """ 45 | Converts SMILES string to graph Data object 46 | :input: SMILES string (str) 47 | :return: graph object 48 | """ 49 | 50 | mol = Chem.MolFromSmiles(smiles_string) 51 | 52 | # atoms 53 | atom_features_list = [] 54 | for atom in mol.GetAtoms(): 55 | atom_features_list.append(atom_to_feature(atom)) 56 | x = np.array(atom_features_list, dtype = np.int64) 57 | 58 | # bonds 59 | num_bond_features = 2 60 | if len(mol.GetBonds()) > 0: # mol has bonds 61 | edges_list = [] 62 | edge_features_list = [] 63 | for bond in mol.GetBonds(): 64 | i = bond.GetBeginAtomIdx() 65 | j = bond.GetEndAtomIdx() 66 | 67 | edge_feature = bond_to_feature(bond) 68 | 69 | # add edges in both directions 70 | edges_list.append((i, j)) 71 | edge_features_list.append(edge_feature) 72 | edges_list.append((j, i)) 73 | edge_features_list.append(edge_feature) 74 | 75 | # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] 76 | edge_index = np.array(edges_list, dtype = np.int64).T 77 | 78 | # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] 79 | edge_attr = np.array(edge_features_list, dtype = np.int64) 80 | 81 | else: # mol has no bonds 82 | edge_index = np.empty((2, 0), dtype = np.int64) 83 | edge_attr = np.empty((0, num_bond_features), dtype = np.int64) 84 | 85 | graph = dict() 86 | graph['edge_index'] = edge_index 87 | graph['edge_feat'] = edge_attr 88 | graph['node_feat'] = x 89 | graph['num_nodes'] = len(x) 90 | 91 | return graph 92 | 93 | 94 | def construct_instruct_question(selfies_str:str=None): 95 | """ 96 | Construct instruct question for each graph 97 | """ 98 | question_pools = [ 99 | 'Could you give me a brief overview of this molecule?', 100 | 'Could you provide a description of this molecule?', 101 | 'Describe this molecule.', 102 | 'Please give me some details about this molecule.', 103 | 'Provide a brief overview of this molecule.', 104 | 'Provide a description of this molecule.', 105 | 'What can you tell me about this molecule?' 106 | ] 107 | question = random.choice(question_pools) 108 | if selfies_str is not None: 109 | question += f" The compound SELFIES sequence is: {selfies_str}." 
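    # Randomly prepend or append a newline so the generated instructions do not all
    # share exactly the same prompt layout.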
110 | if random.random() < 0.5: 111 | question = "\n" + question 112 | else: 113 | question = question + "\n" 114 | return question 115 | 116 | 117 | def convert_chembl(qa_json, out_dir=None): 118 | assert os.path.exists(qa_json), f"{qa_json} not exists" 119 | qa_name = os.path.basename(qa_json).split(".")[0] 120 | with open(qa_json, "rt") as f: 121 | js = json.load(f) 122 | out = [] 123 | for smi, rec in tqdm(js.items()): 124 | if len(rec) == 0: 125 | continue 126 | graph = smiles2graph(smi) 127 | for question, answer in rec: 128 | out.append({ 129 | "graph": graph, 130 | "conversations": [ 131 | {"from": "human", "value": construct_instruct_question() }, 132 | {"from": "gpt", "value": answer} 133 | ], 134 | }) 135 | print(f"Successfully convert {len(out)} samples.") 136 | 137 | if out_dir is None: 138 | out_dir = os.path.dirname(qa_json) 139 | if not os.path.exists(out_dir): 140 | os.makedirs(out_dir) 141 | with open(os.path.join(out_dir, qa_name+'.pkl'), "wb") as f: 142 | pickle.dump(out, f) 143 | 144 | 145 | def convert_chebi20(txt, out_dir=None, add_selfies=False): 146 | assert os.path.exists(txt), f"{txt} not exists" 147 | qa_name = os.path.basename(txt).split(".")[0] 148 | out = [] 149 | with open(txt, "rt") as f: 150 | f.readline() 151 | for i, line in enumerate(f.readlines()): 152 | cid, smi, desc = line.strip().split("\t") 153 | selfies_str = None 154 | if add_selfies: 155 | try: 156 | selfies_str = sf.encoder(smi) 157 | except: 158 | selfies_str = "" 159 | graph = smiles2graph(smi) 160 | out.append({ 161 | "graph": graph, 162 | "conversations": [ 163 | {"from": "human", "value": construct_instruct_question(selfies_str) }, 164 | {"from": "gpt", "value": desc} 165 | ], 166 | }) 167 | print(f"Successfully convert {len(out)} samples.") 168 | if out_dir is None: 169 | out_dir = os.path.dirname(txt) 170 | if not os.path.exists(out_dir): 171 | os.makedirs(out_dir) 172 | 173 | if add_selfies: 174 | qa_name += "+selfies" 175 | with open(os.path.join(out_dir, qa_name+'.pkl'), "wb") as f: 176 | pickle.dump(out, f) 177 | 178 | 179 | 180 | if __name__ == '__main__': 181 | # graph = smiles2graph('O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5') 182 | # print(graph) 183 | # qa_json = '/comp_robot/rentianhe/caohe/AIDD/DATA/MolFM/pubcgraphemsft_desc/test.json' 184 | # convert_chembl(qa_json) 185 | 186 | for split in ['train', 'test', 'validation']: 187 | txt = f'/cto_labs/AIDD/DATA/MolT5/ChEBI-20_data/{split}.txt' 188 | convert_chebi20(txt, add_selfies=True) 189 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery (COLING 2025) 2 | Codes for our paper *InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery* 3 | 4 | 5 | 6 | [[Project Page](https://idea-xl.github.io/InstructMol/)] [[Paper](https://arxiv.org/pdf/2311.16208.pdf)] 7 | 8 | ## Overview 9 |

10 | ![InstructMol teaser](assets/static/teaser.png)
11 | 
12 | The rapid evolution of artificial intelligence in drug discovery encounters challenges with generalization and extensive training, yet Large Language Models (LLMs) offer promise in reshaping interactions with complex molecular data. Our novel contribution, InstructMol, a multi-modal LLM, effectively aligns molecular structures with natural language via an instruction-tuning approach, utilizing a two-stage training strategy that adeptly combines limited domain-specific data with molecular and textual information. InstructMol showcases substantial performance improvements in drug discovery-related molecular tasks, surpassing leading LLMs and significantly reducing the gap with specialized models, thereby establishing a robust foundation for a versatile and dependable drug discovery assistant.
13 | 
14 | ## Architecture
15 | The diagram presented below provides an overview of the architectural design of the InstructMol model, along with its two-stage training paradigm. The example molecule in the figure is Terephthalaldehyde (CID 12173).
16 | 

17 | ![InstructMol architecture and two-stage training pipeline](assets/static/overview.png)
18 | 
19 | 
20 | ## Release
21 | - [2024/11/30] 🔥 Accepted by COLING 2025. (Jesus, finally got accepted)
22 | - [2023/11/27] 🔥 We first release our code (including training and evaluation scripts).
23 | 
24 | 
25 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
26 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
27 | **Usage and License Notices**: The data, code and checkpoints are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, LLaVA, Mol-Instructions and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
28 | 
29 | 
30 | ## Contents
31 | - [Install](#install)
32 | - [Weights](#weights)
33 | - [Dataset](#dataset)
34 | - [CLI Inference](#cli-inference)
35 | - [Train](#train)
36 | - [Evaluation](#evaluation)
37 | 
38 | ## Install
39 | Installation mostly follows the LLaVA setup.
40 | 1. Clone this repository and navigate to the project folder
41 | 
42 | 2. Install the package
43 | - If you have any trouble installing the torch-geometric related packages, please refer to [guide-to-pyg-install](https://github.com/chao1224/GraphMVP#environments) for detailed instructions.
44 | ```Shell
45 | conda create -n instructmol python=3.10 -y
46 | conda activate instructmol
47 | pip install --upgrade pip # enable PEP 660 support
48 | pip install -e .
49 | 
50 | # Install graph-related packages. We use torch 1.12 with CUDA 11.6; please change accordingly.
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | 3. Install additional packages for training
55 | ```
56 | pip install ninja
57 | pip install flash-attn --no-build-isolation
58 | ```
59 | 
60 | 
61 | ## Weights
62 | 
63 | ### Component Weights Download
64 | Create a folder named `checkpoints` in the root directory of this project.
65 | ```Shell
66 | mkdir checkpoints
67 | cd checkpoints
68 | ```
69 | Download the following weights and put them in the `checkpoints` folder.
70 | ```Shell
71 | # Under the checkpoints folder
72 | # get the weights for the Vicuna model (https://huggingface.co/lmsys/vicuna-7b-v1.3)
73 | ln -s YOUR_PATH_TO_vicuna_v1_3_7b vicuna-v1-3-7b
74 | # get the weights for the MoleculeSTM model
75 | mkdir MoleculeSTM
76 | wget https://huggingface.co/chao1224/MoleculeSTM/resolve/main/demo/demo_checkpoints_Graph/molecule_model.pth -P MoleculeSTM
77 | # download the weights for the scibert_scivocab_uncased model (https://huggingface.co/allenai/scibert_scivocab_uncased)
78 | ln -s YOUR_PATH_TO_scibert_scivocab_uncased scibert_scivocab_uncased
79 | cd .. # back to the root directory
80 | ```
81 | * [Optional] To get GraphMVP weights, please refer to the [GraphMVP weights download guidance](https://github.com/chao1224/GraphMVP#for-graphmvp-pre-training).
82 | ```Shell
83 | mv YOUR_PATH_TO_graphmvp.pth checkpoints/
84 | ```
85 | 
86 | ### InstructMol Weights
87 | * TODO: coming soon
88 | 
89 | ## Dataset
90 | * TODO: coming soon
91 | 
92 | ## CLI Inference
93 | Chat with InstructMol without the need of a Gradio interface.
94 | ```Shell
95 | #!/bin/bash
96 | # NOTE: Insert the path of the model here (e.g., checkpoints/Graph-LLaVA/llava-moleculestm-vicuna-v1-3-7b-finetune_lora)
97 | MODEL_PATH=""
98 | python -m llava.serve.cli_graph \
99 |     --model-path $MODEL_PATH \
100 |     --model-base checkpoints/vicuna-v1-3-7b \
101 |     --graph-checkpoint-path checkpoints/graphmvp.pth
102 | ```
103 | 
104 | 
105 | ## Train
106 | InstructMol training (following the LLaVA recipe) consists of two stages:
107 | 
108 | * **Stage 1: Alignment Pretraining.** The initial stage aligns molecules with text using a PubChem dataset of 330K pairs. It focuses on fine-tuning the alignment projector while keeping the graph encoder and LLM frozen to leverage pre-trained knowledge.
109 | * **Stage 2: Task-specific Instruction Tuning.** The second stage targets compound property prediction, chemical reaction analysis, and molecule description generation. It utilizes task-specific instruction datasets and LoRA for LLM adaptation, retaining common-sense reasoning capabilities. It also allows adaptable adaptors for specific needs or modular knowledge integration.
110 | 
111 | ### Stage 1: Alignment Pretraining
112 | See [pretrain.sh](scripts/pretrain.sh) for an example of how to run the pretraining stage.
113 | - `$GRAPH_TOWER` can be chosen from `moleculestm` or `graphmvp`.
114 | 
115 | ### Stage 2: Task-specific Instruction Tuning
116 | You can train all specific tasks together with [finetune_all.sh](scripts/all/finetune_lora_all.sh) or train them separately (e.g., the [molecule description generation task](scripts/finetune_lora_molcap.sh)). A minimal invocation sketch is included at the end of this README.
117 | 
118 | 
119 | ## Evaluation
120 | See [Evaluation.md](Evaluation.md) for detailed instructions on how to evaluate the model.
121 | 
122 | ## Citation
123 | If you find InstructMol useful for your research and applications, please cite using this BibTeX:
124 | ```bibtex
125 | @misc{cao2023instructmol,
126 |     title={InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery},
127 |     author={He Cao and Zijing Liu and Xingyu Lu and Yuan Yao and Yu Li},
128 |     year={2023},
129 |     eprint={2311.16208},
130 |     archivePrefix={arXiv},
131 |     primaryClass={q-bio.BM}
132 | }
133 | ```
134 | 
135 | ## Acknowledgement
136 | 
137 | - [Vicuna](https://github.com/lm-sys/FastChat): the main base-LLM we used.
138 | - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main): the codebase we built upon.
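
## Appendix: Two-Stage Training Sketch
The snippet below is only a rough sketch of how the two stages are typically launched; the authoritative hyperparameters, paths, and environment variables (including how `$GRAPH_TOWER` is set) live inside the scripts themselves, so treat these commands as assumptions and check `scripts/pretrain.sh` and `scripts/finetune_lora_molcap.sh` before running.
```Shell
# Stage 1: alignment pretraining (projector tuned; graph encoder and LLM frozen).
# GRAPH_TOWER (moleculestm or graphmvp) is assumed to be configured inside the script.
bash scripts/pretrain.sh

# Stage 2: task-specific LoRA instruction tuning, e.g. molecule captioning.
bash scripts/finetune_lora_molcap.sh

# Alternatively, train all task-specific adaptors in one run.
bash scripts/all/finetune_lora_all.sh
```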
139 | -------------------------------------------------------------------------------- /llava/eval/molecule_metrics/MoleculeNet_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pickle 5 | from tqdm import tqdm 6 | from typing import Generator, Dict 7 | import selfies 8 | from sklearn.metrics import roc_auc_score 9 | from torch_geometric.data import Data 10 | 11 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 12 | from llava.conversation import conv_templates, SeparatorStyle 13 | from llava.model.builder import load_pretrained_model 14 | from llava.utils import disable_torch_init 15 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, MM_ENCODER_CFG 16 | from llava.datasets.smiles2graph import smiles2graph 17 | 18 | 19 | SUPPORT_DATASETS = ["bace", "bbbp", "hiv"] 20 | 21 | def _convert_dict_to_Data(data_dict: Dict) -> Data: 22 | return Data( 23 | x=torch.asarray(data_dict['node_feat']), 24 | edge_attr=torch.asarray(data_dict['edge_feat']), 25 | edge_index=torch.asarray(data_dict['edge_index']), 26 | ) 27 | 28 | def selfies2smiles(selfies_str): 29 | try: 30 | smiles_str = selfies.decoder(selfies_str) 31 | except: 32 | smiles_str = None 33 | return smiles_str 34 | 35 | def smiles2selfies(smiles_str): 36 | try: 37 | selfies_str = selfies.encoder(smiles_str) 38 | except: 39 | selfies_str = None 40 | return selfies_str 41 | 42 | def convert_label_to_int(label): 43 | label = label.strip() 44 | if label.lower() in ['active', "yes", "true"]: 45 | return 1 46 | elif label.lower() in ['inactive', "no", "false"]: 47 | return 0 48 | else: 49 | print("Unknown label:", label) 50 | return -1 51 | 52 | 53 | def iterate_test_files( 54 | args, 55 | batch_size:int=4, 56 | )->Generator: 57 | if args.split == "random": 58 | in_file = os.path.join(args.dataspace, args.dataset, "processed", "instruct-random-test.pkl") 59 | else: 60 | in_file = os.path.join(args.dataspace, args.dataset, "processed", "instruct-test.pkl") 61 | with open(in_file, "rb") as f: 62 | list_data_dict = pickle.load(f) 63 | 64 | batch = [] 65 | for i, raw in enumerate(list_data_dict): 66 | instruction = raw['instruction'] 67 | if args.add_selfies: 68 | selfies_str = smiles2selfies(raw['SMILES']) 69 | instruction += f" The compound SELFIES sequence is: {selfies_str}" 70 | graph = raw['graph'] 71 | batch.append((instruction, graph, raw['label'])) 72 | if len(batch) == batch_size: 73 | yield zip(*batch) 74 | batch = [] 75 | if len(batch) > 0: 76 | yield zip(*batch) 77 | 78 | 79 | def _length_test_file(args)->int: 80 | in_file = os.path.join(args.dataspace, args.dataset, "processed", "instruct-test.pkl") 81 | with open(in_file, "rb") as f: 82 | list_data_dict = pickle.load(f) 83 | return len(list_data_dict) 84 | 85 | 86 | def main(args): 87 | # device 88 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 89 | 90 | # Model 91 | disable_torch_init() 92 | model_name = get_model_name_from_path(args.model_path) 93 | # graph encoder config 94 | mm_encoder_cfg = MM_ENCODER_CFG(init_checkpoint=args.graph_checkpoint_path) 95 | mm_encoder_cfg = mm_encoder_cfg.dict() 96 | 97 | tokenizer, model, _, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, mm_encoder_cfg=mm_encoder_cfg) 98 | model = model.to(torch.bfloat16) 99 | # Sampling 100 | batch_size 
= args.batch_size 101 | outs = [] 102 | 103 | samples = 0 104 | for instructions, graphs, gts in tqdm( 105 | iterate_test_files(args, batch_size=batch_size), total=_length_test_file(args)//batch_size, 106 | ): 107 | bs = len(instructions) 108 | graph_tensors = [_convert_dict_to_Data(graph).to(device) for graph in graphs] 109 | input_ids_batch = [] 110 | stopping_criteria_batch = [] 111 | for i in range(bs): 112 | qs = instructions[i] 113 | if model.config.mm_use_im_start_end: 114 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 115 | else: 116 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 117 | 118 | conv = conv_templates[args.conv_mode].copy() 119 | conv.append_message(conv.roles[0], qs) 120 | conv.append_message(conv.roles[1], None) 121 | prompt = conv.get_prompt() 122 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device) 123 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 124 | keywords = [stop_str] 125 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 126 | 127 | input_ids_batch.append(input_ids.squeeze(0)) 128 | stopping_criteria_batch.append(stopping_criteria) 129 | # pad input_ids 130 | input_ids_batch = torch.nn.utils.rnn.pad_sequence( 131 | input_ids_batch, 132 | batch_first=True, 133 | padding_value=tokenizer.pad_token_id 134 | ) 135 | 136 | with torch.inference_mode(): 137 | output_ids = model.generate( 138 | input_ids_batch, 139 | graphs=graph_tensors, 140 | do_sample=True, 141 | temperature=args.temperature, 142 | top_p=args.top_p, 143 | num_beams=args.num_beams, 144 | max_new_tokens=args.max_new_tokens, 145 | repetition_penalty=args.repetition_penalty, 146 | use_cache=True, 147 | stopping_criteria=stopping_criteria_batch 148 | ) 149 | 150 | outputs = [] # list of str 151 | for i in range(bs): 152 | output = tokenizer.decode(output_ids[i, input_ids.shape[1]:]).strip() 153 | if output.endswith(stop_str): 154 | output = output[:-len(stop_str)] 155 | output = output.strip() 156 | outputs.append(output) 157 | 158 | for instruction, gt, output in zip(instructions, gts, outputs): 159 | outs.append( 160 | { 161 | "prompt": instruction, 162 | "gt": gt, 163 | "pred": output, 164 | } 165 | ) 166 | if args.debug: 167 | print("\n", {"gt": gt, "outputs": output}, "\n") 168 | samples += bs 169 | # if samples > 20: 170 | # break 171 | 172 | # compute metrics (ROC-AUC) 173 | preds = [convert_label_to_int(out["pred"]) for out in outs] 174 | gts = [convert_label_to_int(out["gt"]) for out in outs] 175 | print("ROC-AUC:", roc_auc_score(gts, preds)) 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 181 | parser.add_argument("--dataspace", type=str, default="/cto_labs/AIDD/DATA/MoleculeNet") 182 | parser.add_argument("--dataset", type=str, choices=SUPPORT_DATASETS, default="bace") 183 | parser.add_argument("--graph-checkpoint-path", type=str, required=True) 184 | parser.add_argument("--model-base", type=str, default=None) 185 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 186 | parser.add_argument("--temperature", type=float, default=0.2) 187 | parser.add_argument("--top_p", type=float, default=None) 188 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 189 | parser.add_argument("--num_beams", type=int, default=1) 190 | parser.add_argument("--max-new-tokens", type=int, default=1024) 191 | 
parser.add_argument("--load-8bit", action="store_true") 192 | parser.add_argument("--load-4bit", action="store_true") 193 | parser.add_argument("--debug", action="store_true") 194 | parser.add_argument("--batch_size", type=int, default=4) 195 | parser.add_argument("--split", type=str, default="random") 196 | parser.add_argument("--add_selfies", action="store_true") 197 | args = parser.parse_args() 198 | main(args) 199 | 200 | 201 | """ 202 | TASK="MoleculeNet" 203 | DATASET=bbbp 204 | EPOCH=20 205 | python -m llava.eval.molecule_metrics.MoleculeNet_classification \ 206 | --dataspace /cto_labs/AIDD/DATA/MoleculeNet \ 207 | --dataset $DATASET \ 208 | --model-path checkpoints/Graph-LLaVA/$TASK-llava-moleculestm-vicuna-v1-3-7b-finetune_lora \ 209 | --graph-checkpoint-path checkpoints/MoleculeSTM/molecule_model.pth \ 210 | --model-base checkpoints/vicuna-v1-3-7b \ 211 | --batch_size 1 \ 212 | --debug 213 | """ -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/moleculeSTM_gnn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import (MessagePassing, global_add_pool, 5 | global_max_pool, global_mean_pool) 6 | from torch_geometric.nn.inits import glorot, zeros 7 | from torch_geometric.utils import add_self_loops, softmax, degree 8 | from torch_scatter import scatter_add 9 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 10 | from collections import OrderedDict 11 | 12 | 13 | class GINConv(MessagePassing): 14 | def __init__(self, emb_dim, aggr="add"): 15 | ''' 16 | emb_dim (int): node embedding dimensionality 17 | ''' 18 | super(GINConv, self).__init__(aggr=aggr) 19 | 20 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim)) 21 | self.eps = torch.nn.Parameter(torch.Tensor([0])) 22 | 23 | self.bond_encoder = BondEncoder(emb_dim = emb_dim) 24 | 25 | def forward(self, x, edge_index, edge_attr): 26 | edge_embedding = self.bond_encoder(edge_attr) 27 | # WARN: some weird thing happend if excute in bfloat16, so we force to cast to float32 28 | dtype = x.dtype 29 | inter = (1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding) 30 | if dtype == torch.bfloat16: 31 | inter = inter.float() 32 | out = self.mlp.float()(inter) 33 | out = out.to(dtype) 34 | else: 35 | out = self.mlp(inter) 36 | return out 37 | 38 | def message(self, x_j, edge_attr): 39 | return F.relu(x_j + edge_attr) 40 | 41 | def update(self, aggr_out): 42 | return aggr_out 43 | 44 | 45 | class GCNConv(MessagePassing): 46 | def __init__(self, emb_dim, aggr="add"): 47 | super(GCNConv, self).__init__(aggr=aggr) 48 | 49 | self.linear = torch.nn.Linear(emb_dim, emb_dim) 50 | self.root_emb = torch.nn.Embedding(1, emb_dim) 51 | self.bond_encoder = BondEncoder(emb_dim = emb_dim) 52 | 53 | def forward(self, x, edge_index, edge_attr): 54 | x = self.linear(x) 55 | edge_embedding = self.bond_encoder(edge_attr) 56 | 57 | row, col = edge_index 58 | 59 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) 60 | deg = degree(row, x.size(0), dtype = x.dtype) + 1 61 | deg_inv_sqrt = deg.pow(-0.5) 62 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 63 | 64 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] 65 | 66 | return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + 
F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) 67 | 68 | def message(self, x_j, edge_attr, norm): 69 | return norm.view(-1, 1) * F.relu(x_j + edge_attr) 70 | 71 | def update(self, aggr_out): 72 | return aggr_out 73 | 74 | 75 | class GNN(nn.Module): 76 | def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"): 77 | 78 | if num_layer < 2: 79 | raise ValueError("Number of GNN layers must be greater than 1.") 80 | 81 | super(GNN, self).__init__() 82 | self.drop_ratio = drop_ratio 83 | self.num_layer = num_layer 84 | self.JK = JK 85 | 86 | self.atom_encoder = AtomEncoder(emb_dim) 87 | 88 | ###List of MLPs 89 | self.gnns = nn.ModuleList() 90 | for layer in range(num_layer): 91 | if gnn_type == "gin": 92 | self.gnns.append(GINConv(emb_dim, aggr="add")) 93 | elif gnn_type == "gcn": 94 | self.gnns.append(GCNConv(emb_dim)) 95 | 96 | ###List of batchnorms 97 | self.batch_norms = nn.ModuleList() 98 | for layer in range(num_layer): 99 | self.batch_norms.append(nn.BatchNorm1d(emb_dim)) 100 | 101 | # def forward(self, x, edge_index, edge_attr): 102 | def forward(self, *argv): 103 | if len(argv) == 3: 104 | x, edge_index, edge_attr = argv[0], argv[1], argv[2] 105 | elif len(argv) == 1: 106 | data = argv[0] 107 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 108 | else: 109 | raise ValueError("unmatched number of arguments.") 110 | 111 | x = self.atom_encoder(x) 112 | 113 | h_list = [x] 114 | for layer in range(self.num_layer): 115 | h = self.gnns[layer](h_list[layer], edge_index, edge_attr) 116 | h = self.batch_norms[layer](h) 117 | # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 118 | if layer == self.num_layer - 1: 119 | # remove relu for the last layer 120 | h = F.dropout(h, self.drop_ratio, training=self.training) 121 | else: 122 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) 123 | h_list.append(h) 124 | 125 | ### Different implementations of Jk-concat 126 | if self.JK == "concat": 127 | node_representation = torch.cat(h_list, dim=1) 128 | elif self.JK == "last": 129 | node_representation = h_list[-1] 130 | elif self.JK == "max": 131 | h_list = [h.unsqueeze_(0) for h in h_list] 132 | node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0] 133 | elif self.JK == "sum": 134 | h_list = [h.unsqueeze_(0) for h in h_list] 135 | node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0] 136 | else: 137 | raise ValueError("not implemented.") 138 | return node_representation 139 | 140 | 141 | class GNN_graphpred(nn.Module): 142 | """ 143 | Extension of GIN to incorporate edge information by concatenation. 144 | 145 | Args: 146 | num_layer (int): the number of GNN layers 147 | arg.emb_dim (int): dimensionality of embeddings 148 | num_tasks (int): number of tasks in multi-task learning scenario 149 | JK (str): last, concat, max or sum. 
150 | graph_pooling (str): sum, mean, max, attention, set2set 151 | 152 | See https://arxiv.org/abs/1810.00826 153 | JK-net: https://arxiv.org/abs/1806.03536 """ 154 | 155 | def __init__( 156 | self, 157 | emb_dim, 158 | graph_pooling, 159 | projection_dim:int=None, 160 | molecule_node_model=None, 161 | init_checkpoint=None, 162 | ): 163 | super(GNN_graphpred, self).__init__() 164 | 165 | self.molecule_node_model = molecule_node_model 166 | self.emb_dim = emb_dim 167 | 168 | # Different kind of graph pooling 169 | if graph_pooling == "sum": 170 | self.pool = global_add_pool 171 | elif graph_pooling == "mean": 172 | self.pool = global_mean_pool 173 | elif graph_pooling == "max": 174 | self.pool = global_max_pool 175 | else: 176 | raise ValueError("Invalid graph pooling type.") 177 | 178 | if projection_dim is not None: 179 | self.projector = nn.Linear(emb_dim, projection_dim) 180 | self.output_dim = projection_dim 181 | else: 182 | self.projector = None 183 | self.output_dim = emb_dim 184 | 185 | if init_checkpoint is not None: 186 | self._load_state_dict(init_checkpoint, strict=False) 187 | 188 | def forward(self, *argv): 189 | if len(argv) == 4: 190 | x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3] 191 | elif len(argv) == 1: 192 | data = argv[0] 193 | x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch 194 | else: 195 | raise ValueError("unmatched number of arguments.") 196 | 197 | node_representation = self.molecule_node_model(x, edge_index, edge_attr) 198 | graph_representation = self.pool(node_representation, batch) 199 | return graph_representation, node_representation 200 | 201 | def encode_mol(self, mol, proj=False, return_node_feats=False, eval=True): 202 | if eval: 203 | self.molecule_node_model.eval() # hard code: set to eval mode 204 | with torch.no_grad(): 205 | h_graph, h_node = self.forward(mol) 206 | else: 207 | self.molecule_node_model.train() # set to train mode 208 | h_graph, h_node = self.forward(mol) 209 | if proj and self.projector is not None: 210 | h_graph = self.projector(h_graph) 211 | h_node = self.projector(h_node) 212 | if return_node_feats: 213 | return h_graph, h_node 214 | else: 215 | return h_graph 216 | 217 | def _load_state_dict(self, model_file, strict=False): 218 | print("Loading from {} ...".format(model_file)) 219 | state_dict = torch.load(model_file, map_location=torch.device('cpu')) 220 | self.load_state_dict(state_dict, strict=strict) 221 | return 222 | 223 | @property 224 | def dummy_feature(self): 225 | return self.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 226 | 227 | @property 228 | def hidden_size(self): 229 | return self.output_dim --------------------------------------------------------------------------------
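The `GNN_graphpred` wrapper above is the graph tower used by InstructMol's multimodal encoder. As a quick orientation, here is a minimal sketch (not part of the repository) of how it can be wired to `smiles2graph` to embed a single molecule; the 5-layer, 300-dimensional GIN settings are assumptions chosen to match common MoleculeSTM configurations, not values read from this file. Note also that the `dummy_feature` property as written calls `self.zeros` and `self.device`, which `nn.Module` does not define; it would need `torch.zeros` plus explicit device/dtype handling to run.

```python
import torch
from torch_geometric.data import Data, Batch

from llava.datasets.smiles2graph import smiles2graph
from llava.model.multimodal_encoder.moleculeSTM_gnn_model import GNN, GNN_graphpred

# Convert a SMILES string into the dict produced by smiles2graph, then into a PyG Data object
# (the same conversion that _convert_dict_to_Data performs in llava/serve/cli_graph.py).
graph = smiles2graph("c1ccccc1O")  # phenol
data = Data(
    x=torch.asarray(graph["node_feat"]),
    edge_attr=torch.asarray(graph["edge_feat"]),
    edge_index=torch.asarray(graph["edge_index"]),
)
batch = Batch.from_data_list([data])  # supplies the .batch vector GNN_graphpred.forward expects

# Assumed architecture: a 5-layer, 300-dim GIN node model with mean graph pooling.
node_model = GNN(num_layer=5, emb_dim=300, JK="last", drop_ratio=0.0, gnn_type="gin")
encoder = GNN_graphpred(emb_dim=300, graph_pooling="mean", molecule_node_model=node_model)
# A pretrained checkpoint (e.g. checkpoints/MoleculeSTM/molecule_model.pth) could be supplied
# via init_checkpoint=... instead of using randomly initialised weights.

h_graph, h_node = encoder.encode_mol(batch, proj=False, return_node_feats=True, eval=True)
print(h_graph.shape)  # torch.Size([1, 300]) - one pooled embedding per molecule
print(h_node.shape)   # torch.Size([num_atoms, 300]) - per-atom embeddings
```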