├── .gitignore ├── README.md ├── assets └── images │ ├── deer.png │ ├── figure3.png │ ├── icon-name.pdf │ ├── icon-name.png │ └── table2.png ├── bayesian_optimization.py ├── enrich_lang_annotations.json ├── eval_ckpts.py ├── eval_sequences.json ├── lang_annotation_cache.json ├── license ├── modeling_gpt_9b.py ├── mosaic_gpt_3b.py ├── open_flamingo ├── HISTORY.md ├── LICENSE ├── Makefile ├── README.md ├── TERMS_AND_CONDITIONS.md ├── _optim_utils.py ├── environment.yml ├── open_flamingo.egg-info │ └── PKG-INFO ├── open_flamingo │ ├── __init__.py │ ├── eval │ │ ├── README.md │ │ ├── __init__.py │ │ ├── classification_utils.py │ │ ├── coco_metric.py │ │ ├── data │ │ │ ├── textvqa │ │ │ │ ├── train_annotations_vqa_format.json │ │ │ │ ├── train_questions_vqa_format.json │ │ │ │ ├── val_annotations_vqa_format.json │ │ │ │ └── val_questions_vqa_format.json │ │ │ └── vizwiz │ │ │ │ ├── test_questions_vqa_format.json │ │ │ │ ├── train_annotations_vqa_format.json │ │ │ │ ├── train_questions_vqa_format.json │ │ │ │ ├── val_annotations_vqa_format.json │ │ │ │ └── val_questions_vqa_format.json │ │ ├── eval_datasets.py │ │ ├── eval_model.py │ │ ├── evaluate.py │ │ ├── models │ │ │ ├── blip.py │ │ │ └── open_flamingo.py │ │ ├── ok_vqa_utils.py │ │ ├── rices.py │ │ ├── utils.py │ │ └── vqa_metric.py │ ├── scripts │ │ ├── cache_rices_features.py │ │ ├── convert_mmc4_to_wds.py │ │ └── fill_vqa_testdev_results.py │ ├── src │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── flamingo.py │ │ ├── flamingo_lm.py │ │ ├── helpers.py │ │ └── utils.py │ └── train │ │ ├── README.md │ │ ├── __init__.py │ │ ├── data.py │ │ ├── data_utils.py │ │ ├── distributed.py │ │ ├── train.py │ │ └── train_utils.py └── setup.py ├── partial_task_data.json ├── requirements.txt └── robot_flamingo ├── data ├── data.py ├── real_dataset_hdf5.py └── vl_dataset.py ├── eval ├── eval_calvin.py └── eval_utils.py ├── models ├── action_head.py ├── factory.py ├── flamingo_bc.py ├── flamingo_mpt.py ├── normalizer.py ├── 
trajectory_gpt2.py ├── unets.py └── value_net.py ├── pt_eval_ckpts.bash ├── pt_run_gripper_post_ws_12_traj_aug_mpt_dolly_3b.bash ├── pt_run_gripper_post_ws_12_traj_aug_mpt_dolly_3b_co_train.bash ├── thresholds.bash ├── train ├── train_calvin_post_strategy.py └── train_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.pyc 3 | .ipynb_checkpoints 4 | *.ipynb 5 | **/wandb 6 | *.sh 7 | *.pt 8 | .vscode/ 9 | **/__pycache__ 10 | .idea 11 | .github 12 | .cache 13 | *.log 14 | logs/* 15 | *.txt 16 | *512.json 17 | *log* 18 | *vis* 19 | *.pth 20 | RobotFlamingoDBG* 21 | eval_ckpts_* 22 | cross_att_layer.json 23 | enrich* 24 | egl_check/ 25 | *opengl 26 | rollout*.py 27 | .swp 28 | real_dataset_mDT.py 29 | epoch_figs.py 30 | eval_hist* 31 | gen* 32 | gif_trans* 33 | *gpu* 34 | load* 35 | data_process.py 36 | lang_simi* 37 | *.deb 38 | *plot.py 39 | void.py 40 | test*.py 41 | *.safetensors 42 | lm_src 43 | evaluate 44 | *.zip 45 | *.bin 46 | 47 | 48 | # !arnold_before.sh 49 | # !run*.sh 50 | -------------------------------------------------------------------------------- /assets/images/deer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/assets/images/deer.png -------------------------------------------------------------------------------- /assets/images/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/assets/images/figure3.png -------------------------------------------------------------------------------- /assets/images/icon-name.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/assets/images/icon-name.pdf -------------------------------------------------------------------------------- /assets/images/icon-name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/assets/images/icon-name.png -------------------------------------------------------------------------------- /assets/images/table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/assets/images/table2.png -------------------------------------------------------------------------------- /bayesian_optimization.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import numpy as np 5 | from skopt import gp_minimize 6 | from skopt.space import Real 7 | from skopt.utils import use_named_args 8 | 9 | 10 | def get_observation(log_path): 11 | with open(log_path, 'r') as file: 12 | lines = file.readlines() 13 | thresholds_str = lines[-3] 14 | thresholds = list(map(float, thresholds_str.split(','))) 15 | avg_len = float(lines[-2]) 16 | avg_exit = float(lines[-1]) 17 | return thresholds, avg_len, avg_exit 18 | 19 | def get_score(avg_len, avg_exit): 20 | if avg_exit > budget: 21 | res = - avg_len + 1.0 * (avg_exit - budget) 22 | else: 23 | res = - avg_len 24 | return res 25 | 26 | parser = argparse.ArgumentParser() 27 | # bayesian optimization 28 | parser.add_argument('--num_seq', type=int) 29 | parser.add_argument("--evaluate_from_checkpoint", type=str) 30 | parser.add_argument("--acq_func", type=str, default='EI', choices=['EI', 'LCB', 'PI']) 31 | parser.add_argument('--n_calls', type=int) 32 | parser.add_argument('--init_exit_ratio', type=float) 33 | 
parser.add_argument('--seed', type=int, default=1) 34 | parser.add_argument('--port', type=int) 35 | args = parser.parse_args() 36 | 37 | assert os.environ['calvin_dataset_path'] and os.environ['calvin_conf_path'], "PLEASE SET CAVLIN DATASET PATH and CONFIG PATH!" 38 | args.calvin_dataset = os.environ['calvin_dataset_path'] 39 | args.calvin_conf_path = os.environ['calvin_conf_path'] 40 | 41 | 42 | ckpt_dir, ckpt_name = os.path.split(args.evaluate_from_checkpoint) 43 | log_dir = f'log_BO_{args.init_exit_ratio}_{ckpt_dir}/' 44 | if not os.path.exists(log_dir): 45 | os.makedirs(log_dir) 46 | 47 | 48 | iter_num = 0 49 | log_file = log_dir + f'seq{args.num_seq}_{args.acq_func}_seed{args.seed}_' + ckpt_name[:-4] + f'_iter{str(iter_num)}' + '.log' 50 | print(f'{log_file=}') 51 | 52 | # solve thresholds with exp distribution with a validation datast to get initial point for bayesian optimization 53 | if not os.path.exists(log_file): 54 | os.system(f""" 55 | torchrun --nnodes=1 --nproc_per_node=$ARNOLD_WORKER_GPU --master_port={args.port} robot_flamingo/eval/eval_calvin.py \ 56 | --precision fp32 \ 57 | --use_gripper \ 58 | --run_name DeeR \ 59 | --calvin_dataset {args.calvin_dataset} \ 60 | --cross_attn_every_n_layers 4 \ 61 | --evaluate_from_checkpoint {args.evaluate_from_checkpoint} \ 62 | --calvin_conf_path {args.calvin_conf_path} \ 63 | --amp 1 \ 64 | --exit_ratio {args.init_exit_ratio} \ 65 | --num_seq {args.num_seq} \ 66 | --validation_set \ 67 | --workers 1 > {log_file} 2>&1 68 | """) 69 | 70 | with open(log_file, 'r') as file: 71 | lines = file.readlines() 72 | thresholds_str = lines[-3] 73 | init_thresholds = list(map(float, thresholds_str.split(','))) 74 | init_avg_len = float(lines[-2]) 75 | # set the FLOPs of the running using demonstration dataset as budget constraint, 76 | # such that the search result by bayesian should cost less FLOPs than threshold only using demonstration dataset. 77 | # You can set other values manually. 
PLEASE that here all values represents the average exit layer. 78 | # Average exit layer * FLOPS per layer = Average FLOPs 79 | init_avg_exit = budget = float(lines[-1]) 80 | 81 | print('exp result:') 82 | print(init_thresholds) 83 | print(init_avg_len) 84 | print(init_avg_exit) 85 | 86 | 87 | # get existing observations as other initial points 88 | x0, y0 = [init_thresholds[:-1]], [-init_avg_len] 89 | from pathlib import Path 90 | for log in Path(log_dir).glob('*.log'): 91 | if 'iter0.log' in str(log): continue 92 | try: 93 | thresholds, avg_len, avg_exit = get_observation(log) 94 | score = get_score(avg_len, avg_exit) 95 | x0.append(thresholds[:-1]) 96 | y0.append(score) 97 | except: 98 | print(f'Error when parsing {log}') 99 | pass 100 | 101 | # define search space 102 | space = [ 103 | Real(init_thresholds[0]-0.02, init_thresholds[0]+0.02, name='t0'), 104 | Real(init_thresholds[1]-0.002, init_thresholds[1]+0.002, name='t1'), 105 | Real(init_thresholds[2]-0.002, init_thresholds[2]+0.002, name='t2'), 106 | Real(init_thresholds[3]-0.002, init_thresholds[3]+0.002, name='t3'), 107 | Real(init_thresholds[4]-0.002, init_thresholds[4]+0.002, name='t4'), 108 | ] 109 | # space = [ 110 | # Real(init_thresholds[0]-0.01, init_thresholds[0]+0.01, name='t0'), 111 | # Real(init_thresholds[1]-0.001, init_thresholds[1]+0.001, name='t1'), 112 | # Real(init_thresholds[2]-0.001, init_thresholds[2]+0.001, name='t2'), 113 | # Real(init_thresholds[3]-0.001, init_thresholds[3]+0.001, name='t3'), 114 | # Real(init_thresholds[4]-0.001, init_thresholds[4]+0.001, name='t4'), 115 | # ] 116 | 117 | @use_named_args(space) 118 | def objective_function(t0, t1, t2, t3, t4): 119 | global log_file 120 | global iter_num 121 | iter_num += 1 122 | log_file = log_file[:-10] + f'_iter{str(iter_num)}' + '.log' 123 | t5 = 100000.0 124 | 125 | print('') 126 | print(f'{iter_num=}') 127 | print(f'threshold={t0}, {t1}, {t2}, {t3}, {t4}, {t5}') 128 | 129 | if not os.path.exists(log_file): 130 | os.system(f""" 
131 | torchrun --nnodes=1 --nproc_per_node=$ARNOLD_WORKER_GPU --master_port={args.port} robot_flamingo/eval/eval_calvin.py \ 132 | --precision fp32 \ 133 | --use_gripper \ 134 | --run_name DeeR \ 135 | --calvin_dataset {args.calvin_dataset} \ 136 | --cross_attn_every_n_layers 4 \ 137 | --evaluate_from_checkpoint {args.evaluate_from_checkpoint} \ 138 | --calvin_conf_path {args.calvin_conf_path} \ 139 | --amp 1 \ 140 | --thresholds {t0} {t1} {t2} {t3} {t4} {t5} \ 141 | --num_seq {args.num_seq} \ 142 | --validation_set \ 143 | --workers 1 > {log_file} 2>&1 144 | """) 145 | 146 | thresholds, avg_len, avg_exit = get_observation(log_file) 147 | res = get_score(avg_len, avg_exit) 148 | print(f'{avg_len=}') 149 | print(f'{avg_exit=}') 150 | print(f'BO {res=}') 151 | return res 152 | 153 | # print('') 154 | # print('init x0:', x0) 155 | print('init y0:', y0) 156 | 157 | result = gp_minimize( 158 | objective_function, 159 | space, 160 | x0=x0, 161 | y0=y0, 162 | n_calls=20, 163 | random_state=args.seed, 164 | acq_func=args.acq_func, # 选择采集函数 165 | ) 166 | 167 | print("Optimal thresholds:", result.x) 168 | print("optimal avg exit:", -result.fun) -------------------------------------------------------------------------------- /eval_ckpts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | 5 | # Create the parser 6 | parser = argparse.ArgumentParser() 7 | 8 | # Add the arguments 9 | parser.add_argument('--ckpt_dir', type=str, help='The checkpoint directory') 10 | parser.add_argument('--exit_ratio', nargs='+', type=float, default=[1.0], help='a list') 11 | parser.add_argument('--node_num', type=int, help='how much GPUs/threads to parallelly evaluate') 12 | parser.add_argument("--num_seq", type=int, default=224, help="the number of task chains for elvaution. 
Maximum is 1000.") 13 | parser.add_argument( 14 | "--amp", 15 | default=0, 16 | type=int, 17 | ) 18 | parser.add_argument("--max_layer", type=int, default=-1) # use for constraining memory/max flop. 19 | 20 | parser.add_argument('--enrich_annotation', type=int, default=0, help='If set, eval in enriched annotation setting') 21 | parser.add_argument( 22 | "--precision", 23 | choices=["int4", "int8", "bf16", "fp16", "fp32"], 24 | default="fp32", 25 | help="Floating point precision.", 26 | ) 27 | 28 | parser.add_argument("--note", type=str, default='') 29 | 30 | # Parse the arguments 31 | args = parser.parse_args() 32 | 33 | search_path = os.path.join(args.ckpt_dir, r'*_[0-9].pth') 34 | ckpt_names = [os.path.basename(path) for path in glob.glob(search_path)] 35 | ckpt_names.sort(reverse=True) 36 | 37 | print(ckpt_names) 38 | for ckpt_name in ckpt_names: 39 | for r in args.exit_ratio: 40 | ckpt_path = os.path.join(args.ckpt_dir, ckpt_name) 41 | if not os.path.exists(ckpt_path): 42 | print("ckpt doesn't exist, skipped.") 43 | continue 44 | log_dir = f'log_{args.ckpt_dir}' 45 | os.makedirs(log_dir, exist_ok=True) 46 | prefix = f'evaluate{args.num_seq}{args.note}_{args.precision}' 47 | if args.enrich_annotation: 48 | prefix += '_enrich' 49 | if args.amp: 50 | prefix += '_amp' 51 | prefix += f'_maxL={args.max_layer}_{r}' 52 | prefix += '_exit' 53 | 54 | log_file = '{}/{}_{}.log'.format(log_dir, prefix, '.'.join(ckpt_name.split('.')[:-1])) 55 | 56 | if 'freeze_emb' in ckpt_name: 57 | log_file = log_file[:-4] + '_freeze_emb.log' 58 | if os.path.exists(log_file): 59 | print(f'skip {log_file}') 60 | continue 61 | ckpt_ix = ckpt_names.index(ckpt_name) 62 | print('evaluating {}/{} checkpoint'.format(ckpt_ix+1, len(ckpt_names))) 63 | 64 | window_size = 12 65 | ckpt_attrs = ckpt_name.split('_') 66 | if 'ws' in ckpt_attrs: 67 | window_size = int(ckpt_attrs[ckpt_attrs.index('ws')+1]) 68 | 69 | os.system('bash robot_flamingo/pt_eval_ckpts.bash {} {} {} {} {} {} {} {} {} {}'.format( 70 
| ckpt_path, 71 | log_file, 72 | window_size, 73 | args.node_num, 74 | args.amp, 75 | r, 76 | args.num_seq, 77 | args.max_layer, 78 | args.enrich_annotation, 79 | args.precision)) -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /open_flamingo/HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 2.0.0 2 | * Add gradient checkpointing, FullyShardedDataParallel 3 | * Model releases 4 | * (CLIP ViT-L-14 / MPT-1B) 5 | * (CLIP ViT-L-14 / MPT-1B Dolly) 6 | * (CLIP ViT-L-14 / RedPajama-3B) 7 | * (CLIP ViT-L-14 / RedPajama-3B Instruct) 8 | * (CLIP ViT-L-14 / MPT-7B) 9 | * Remove color jitter when training 10 | * Fix cross-attention bug when calling generate() 11 | 12 | ## 1.0.0 13 | 14 | * Initial code release 15 | * Early model release (CLIP ViT-L-14 / LLaMA-7B) -------------------------------------------------------------------------------- /open_flamingo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /open_flamingo/Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements-dev.txt 7 | 8 | lint: ## [Local development] Run mypy, pylint and black 9 | python -m mypy open_flamingo 10 | python -m pylint open_flamingo 11 | python -m black --check -l 120 open_flamingo 12 | 13 | black: ## [Local development] Auto-format python code using black 14 | python -m black -l 120 . 
15 | 16 | .PHONY: help 17 | 18 | help: # Run `make help` to get help on the make commands 19 | @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 20 | -------------------------------------------------------------------------------- /open_flamingo/README.md: -------------------------------------------------------------------------------- 1 | # 🦩 OpenFlamingo 2 | 3 | [![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo) 4 | 5 | [Paper](https://arxiv.org/abs/2308.01390) | Blog posts: [1](https://laion.ai/blog/open-flamingo/), [2](https://laion.ai/blog/open-flamingo-v2/) | [Demo](https://huggingface.co/spaces/openflamingo/OpenFlamingo) 6 | 7 | Welcome to our open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)! 8 | 9 | In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. 10 | If you have any questions, please feel free to open an issue. We also welcome contributions! 
11 | 12 | # Table of Contents 13 | - [Installation](#installation) 14 | - [Approach](#approach) 15 | * [Model architecture](#model-architecture) 16 | - [Usage](#usage) 17 | * [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model) 18 | * [Generating text](#generating-text) 19 | - [Training](#training) 20 | * [Dataset](#dataset) 21 | - [Evaluation](#evaluation) 22 | - [Future plans](#future-plans) 23 | - [Team](#team) 24 | - [Acknowledgments](#acknowledgments) 25 | - [Citing](#citing) 26 | 27 | # Installation 28 | 29 | To install the package in an existing environment, run 30 | ``` 31 | pip install open-flamingo 32 | ``` 33 | 34 | or to create a conda environment for running OpenFlamingo, run 35 | ``` 36 | conda env create -f environment.yml 37 | ``` 38 | 39 | To install training or eval dependencies, run one of the first two commands. To install everything, run the third command. 40 | ``` 41 | pip install open-flamingo[training] 42 | pip install open-flamingo[eval] 43 | pip install open-flamingo[all] 44 | ``` 45 | 46 | There are three `requirements.txt` files: 47 | - `requirements.txt` 48 | - `requirements-training.txt` 49 | - `requirements-eval.txt` 50 | 51 | Depending on your use case, you can install any of these with `pip install -r <requirements-file.txt>`. The base file contains only the dependencies needed for running the model. 52 | 53 | # Approach 54 | OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context learning. 55 | 56 | ## Model architecture 57 | OpenFlamingo combines a pretrained vision encoder and a language model using cross attention layers.
The model architecture is shown below. 58 | 59 | ![OpenFlamingo architecture](docs/flamingo.png) 60 | Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) 61 | 62 | # Usage 63 | ## Initializing an OpenFlamingo model 64 | We support pretrained vision encoders from the [OpenCLIP](https://github.com/mlfoundations/open_clip) package, which includes OpenAI's pretrained models. 65 | We also support pretrained language models from the `transformers` package, such as [MPT](https://huggingface.co/models?search=mosaicml%20mpt), [RedPajama](https://huggingface.co/models?search=redpajama), [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models. 66 | 67 | ``` python 68 | from open_flamingo import create_model_and_transforms 69 | 70 | model, image_processor, tokenizer = create_model_and_transforms( 71 | clip_vision_encoder_path="ViT-L-14", 72 | clip_vision_encoder_pretrained="openai", 73 | lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b", 74 | tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b", 75 | cross_attn_every_n_layers=1, 76 | cache_dir="PATH/TO/CACHE/DIR" # Defaults to ~/.cache 77 | ) 78 | ``` 79 | 80 | ## Released OpenFlamingo models 81 | We have trained the following OpenFlamingo models so far. 
82 | 83 | |# params|Language model|Vision encoder|Xattn interval*|COCO 4-shot CIDEr|VQAv2 4-shot Accuracy|Weights| 84 | |------------|--------------|--------------|----------|-----------|-------|----| 85 | |3B| mosaicml/mpt-1b-redpajama-200b | openai CLIP ViT-L/14 | 1 | 77.3 | 45.8 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b)| 86 | |3B| mosaicml/mpt-1b-redpajama-200b-dolly | openai CLIP ViT-L/14 | 1 | 82.7 | 45.7 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct)| 87 | |4B| togethercomputer/RedPajama-INCITE-Base-3B-v1 | openai CLIP ViT-L/14 | 2 | 81.8 | 49.0 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b)| 88 | |4B| togethercomputer/RedPajama-INCITE-Instruct-3B-v1 | openai CLIP ViT-L/14 | 2 | 85.8 | 49.0 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b-langinstruct)| 89 | |9B| mosaicml/mpt-7b | openai CLIP ViT-L/14 | 4 | 89.0 | 54.8 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b)| 90 | 91 | *\* Xattn interval refers to the `--cross_attn_every_n_layers` argument.* 92 | 93 | Note: as part of our v2 release, we have deprecated a previous LLaMA-based checkpoint. However, you can continue to use our older checkpoint using the new codebase. 94 | 95 | ## Downloading pretrained weights 96 | 97 | To instantiate an OpenFlamingo model with one of our released weights, initialize the model as above and use the following code. 98 | 99 | ```python 100 | # grab model checkpoint from huggingface hub 101 | from huggingface_hub import hf_hub_download 102 | import torch 103 | 104 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") 105 | model.load_state_dict(torch.load(checkpoint_path), strict=False) 106 | ``` 107 | 108 | ## Generating text 109 | Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning. 
110 | 111 | ``` python 112 | from PIL import Image 113 | import requests 114 | import torch 115 | 116 | """ 117 | Step 1: Load images 118 | """ 119 | demo_image_one = Image.open( 120 | requests.get( 121 | "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True 122 | ).raw 123 | ) 124 | 125 | demo_image_two = Image.open( 126 | requests.get( 127 | "http://images.cocodataset.org/test-stuff2017/000000028137.jpg", 128 | stream=True 129 | ).raw 130 | ) 131 | 132 | query_image = Image.open( 133 | requests.get( 134 | "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 135 | stream=True 136 | ).raw 137 | ) 138 | 139 | 140 | """ 141 | Step 2: Preprocessing images 142 | Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 143 | batch_size x num_media x num_frames x channels x height x width. 144 | In this case batch_size = 1, num_media = 3, num_frames = 1, 145 | channels = 3, height = 224, width = 224. 146 | """ 147 | vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)] 148 | vision_x = torch.cat(vision_x, dim=0) 149 | vision_x = vision_x.unsqueeze(1).unsqueeze(0) 150 | 151 | """ 152 | Step 3: Preprocessing text 153 | Details: In the text we expect an <image> special token to indicate where an image is. 154 | We also expect an <|endofchunk|> special token to indicate the end of the text 155 | portion associated with an image.
156 | """ 157 | tokenizer.padding_side = "left" # For generation padding tokens should be on the left 158 | lang_x = tokenizer( 159 | ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"], 160 | return_tensors="pt", 161 | ) 162 | 163 | 164 | """ 165 | Step 4: Generate text 166 | """ 167 | generated_text = model.generate( 168 | vision_x=vision_x, 169 | lang_x=lang_x["input_ids"], 170 | attention_mask=lang_x["attention_mask"], 171 | max_new_tokens=20, 172 | num_beams=3, 173 | ) 174 | 175 | print("Generated text: ", tokenizer.decode(generated_text[0])) 176 | ``` 177 | 178 | # Training 179 | We provide training scripts in `open_flamingo/train`. We provide an example Slurm script in `open_flamingo/scripts/run_train.py`, as well as the following example command: 180 | ``` 181 | torchrun --nnodes=1 --nproc_per_node=4 open_flamingo/train/train.py \ 182 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \ 183 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ 184 | --cross_attn_every_n_layers 1 \ 185 | --dataset_resampled \ 186 | --batch_size_mmc4 32 \ 187 | --batch_size_laion 64 \ 188 | --train_num_samples_mmc4 125000 \ 189 | --train_num_samples_laion 250000 \ 190 | --loss_multiplier_laion 0.2 \ 191 | --workers=4 \ 192 | --run_name OpenFlamingo-3B-vitl-mpt1b \ 193 | --num_epochs 480 \ 194 | --warmup_steps 1875 \ 195 | --mmc4_textsim_threshold 0.24 \ 196 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ 197 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ 198 | --report_to_wandb 199 | ``` 200 | 201 | *Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase.
We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).* 202 | 203 | For more details, see our [training README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/train). 204 | 205 | 206 | # Evaluation 207 | An example evaluation script is at `open_flamingo/scripts/run_eval.sh`. Please see our [evaluation README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/eval) for more details. 208 | 209 | 210 | To run evaluations on OKVQA you will need to run the following command: 211 | ``` 212 | import nltk 213 | nltk.download('wordnet') 214 | ``` 215 | 216 | 217 | # Future plans 218 | - [ ] Add support for video input 219 | 220 | # Team 221 | 222 | OpenFlamingo is developed by: 223 | 224 | [Anas Awadalla*](https://anas-awadalla.streamlit.app/), [Irena Gao*](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Shiori Sagawa](https://cs.stanford.edu/~ssagawa/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/). 225 | 226 | The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google. 
227 | 228 | # Acknowledgments 229 | This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design. 230 | 231 | We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models. 
232 | 233 | # Citing 234 | If you found this repository useful, please consider citing: 235 | 236 | ``` 237 | @article{awadalla2023openflamingo, 238 | title={OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models}, 239 | author={Anas Awadalla and Irena Gao and Josh Gardner and Jack Hessel and Yusuf Hanafy and Wanrong Zhu and Kalyani Marathe and Yonatan Bitton and Samir Gadre and Shiori Sagawa and Jenia Jitsev and Simon Kornblith and Pang Wei Koh and Gabriel Ilharco and Mitchell Wortsman and Ludwig Schmidt}, 240 | journal={arXiv preprint arXiv:2308.01390}, 241 | year={2023} 242 | } 243 | ``` 244 | 245 | ``` 246 | @software{anas_awadalla_2023_7733589, 247 | author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig}, 248 | title = {OpenFlamingo}, 249 | month = mar, 250 | year = 2023, 251 | publisher = {Zenodo}, 252 | version = {v0.1.1}, 253 | doi = {10.5281/zenodo.7733589}, 254 | url = {https://doi.org/10.5281/zenodo.7733589} 255 | } 256 | ``` 257 | 258 | ``` 259 | @article{Alayrac2022FlamingoAV, 260 | title={Flamingo: a Visual Language Model for Few-Shot Learning}, 261 | author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan}, 262 | journal={ArXiv}, 263 | year={2022}, 264 | volume={abs/2204.14198} 265 | } 266 | ``` 267 | 
-------------------------------------------------------------------------------- /open_flamingo/TERMS_AND_CONDITIONS.md: -------------------------------------------------------------------------------- 1 | **Please read the following information carefully before proceeding.** 2 | 3 | OpenFlamingo is a **research prototype** that aims to enable users to interact with AI through both language and images. AI agents equipped with both language and visual understanding can be useful on a larger variety of tasks compared to models that communicate solely via language. By releasing an open-source research prototype, we hope to help the research community better understand the risks and limitations of modern visual-language AI models and accelerate the development of safer and more reliable methods. 4 | 5 | - [ ] I understand that OpenFlamingo is a research prototype and I will only use it for non-commercial research purposes. 6 | 7 | **Limitations.** OpenFlamingo is built on top of the LLaMA large language model developed by Meta AI. Large language models, including LLaMA, are trained on mostly unfiltered internet data, and have been shown to be able to produce toxic, unethical, inaccurate, and harmful content. On top of this, OpenFlamingo’s ability to support visual inputs creates additional risks, since it can be used in a wider variety of applications; image+text models may carry additional risks specific to multimodality. Please use discretion when assessing the accuracy or appropriateness of the model’s outputs, and be mindful before sharing its results. 8 | 9 | - [ ] I understand that OpenFlamingo may produce unintended, inappropriate, offensive, and/or inaccurate results. I agree to take full responsibility for any use of the OpenFlamingo outputs that I generate. 10 | 11 | **Privacy and data collection.** This demo does NOT store any personal information on its users, and it does NOT store user queries. 
12 | 13 | **Licensing.** As OpenFlamingo is built on top of the LLaMA large language model from Meta AI, the LLaMA license agreement (as documented in the Meta request form) also applies. 14 | 15 | - [ ] I have read and agree to the terms of the LLaMA license agreement. 16 | -------------------------------------------------------------------------------- /open_flamingo/environment.yml: -------------------------------------------------------------------------------- 1 | name: openflamingo 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - conda-forge::openjdk 7 | - pip 8 | - pip: 9 | - -r requirements.txt 10 | - -r requirements-training.txt 11 | - -r requirements-eval.txt 12 | - -e . 13 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: open-flamingo 3 | Version: 2.0.1 4 | Summary: An open-source framework for training large multimodal models 5 | Home-page: UNKNOWN 6 | License: MIT 7 | Keywords: machine learning 8 | Platform: UNKNOWN 9 | Classifier: Development Status :: 4 - Beta 10 | Classifier: Intended Audience :: Developers 11 | Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence 12 | Classifier: License :: OSI Approved :: MIT License 13 | Classifier: Programming Language :: Python :: 3.9 14 | Description-Content-Type: text/markdown 15 | Provides-Extra: eval 16 | Provides-Extra: training 17 | Provides-Extra: all 18 | License-File: LICENSE 19 | 20 | # 🦩 OpenFlamingo 21 | 22 | [![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo) 23 | 24 | [Paper](https://arxiv.org/abs/2308.01390) | Blog posts: [1](https://laion.ai/blog/open-flamingo/), [2](https://laion.ai/blog/open-flamingo-v2/) | [Demo](https://huggingface.co/spaces/openflamingo/OpenFlamingo) 25 | 26 | Welcome to our open source 
implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)! 27 | 28 | In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. 29 | If you have any questions, please feel free to open an issue. We also welcome contributions! 30 | 31 | # Table of Contents 32 | - [Installation](#installation) 33 | - [Approach](#approach) 34 | * [Model architecture](#model-architecture) 35 | - [Usage](#usage) 36 | * [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model) 37 | * [Generating text](#generating-text) 38 | - [Training](#training) 39 | * [Dataset](#dataset) 40 | - [Evaluation](#evaluation) 41 | - [Future plans](#future-plans) 42 | - [Team](#team) 43 | - [Acknowledgments](#acknowledgments) 44 | - [Citing](#citing) 45 | 46 | # Installation 47 | 48 | To install the package in an existing environment, run 49 | ``` 50 | pip install open-flamingo 51 | ``` 52 | 53 | or to create a conda environment for running OpenFlamingo, run 54 | ``` 55 | conda env create -f environment.yml 56 | ``` 57 | 58 | To install training or eval dependencies, run one of the first two commands. To install everything, run the third command. 59 | ``` 60 | pip install open-flamingo[training] 61 | pip install open-flamingo[eval] 62 | pip install open-flamingo[all] 63 | ``` 64 | 65 | There are three `requirements.txt` files: 66 | - `requirements.txt` 67 | - `requirements-training.txt` 68 | - `requirements-eval.txt` 69 | 70 | Depending on your use case, you can install any of these with `pip install -r <requirements-file.txt>`. The base file contains only the dependencies needed for running the model. 71 | 72 | # Approach 73 | OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text.
For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context learning. 74 | 75 | ## Model architecture 76 | OpenFlamingo combines a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below. 77 | 78 | ![OpenFlamingo architecture](docs/flamingo.png) 79 | Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) 80 | 81 | # Usage 82 | ## Initializing an OpenFlamingo model 83 | We support pretrained vision encoders from the [OpenCLIP](https://github.com/mlfoundations/open_clip) package, which includes OpenAI's pretrained models. 84 | We also support pretrained language models from the `transformers` package, such as [MPT](https://huggingface.co/models?search=mosaicml%20mpt), [RedPajama](https://huggingface.co/models?search=redpajama), [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models. 85 | 86 | ``` python 87 | from open_flamingo import create_model_and_transforms 88 | 89 | model, image_processor, tokenizer = create_model_and_transforms( 90 | clip_vision_encoder_path="ViT-L-14", 91 | clip_vision_encoder_pretrained="openai", 92 | lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b", 93 | tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b", 94 | cross_attn_every_n_layers=1, 95 | cache_dir="PATH/TO/CACHE/DIR" # Defaults to ~/.cache 96 | ) 97 | ``` 98 | 99 | ## Released OpenFlamingo models 100 | We have trained the following OpenFlamingo models so far. 
101 | 102 | |# params|Language model|Vision encoder|Xattn interval*|COCO 4-shot CIDEr|VQAv2 4-shot Accuracy|Weights| 103 | |------------|--------------|--------------|----------|-----------|-------|----| 104 | |3B| mosaicml/mpt-1b-redpajama-200b | openai CLIP ViT-L/14 | 1 | 77.3 | 45.8 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b)| 105 | |3B| mosaicml/mpt-1b-redpajama-200b-dolly | openai CLIP ViT-L/14 | 1 | 82.7 | 45.7 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct)| 106 | |4B| togethercomputer/RedPajama-INCITE-Base-3B-v1 | openai CLIP ViT-L/14 | 2 | 81.8 | 49.0 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b)| 107 | |4B| togethercomputer/RedPajama-INCITE-Instruct-3B-v1 | openai CLIP ViT-L/14 | 2 | 85.8 | 49.0 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b-langinstruct)| 108 | |9B| mosaicml/mpt-7b | openai CLIP ViT-L/14 | 4 | 89.0 | 54.8 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b)| 109 | 110 | *\* Xattn interval refers to the `--cross_attn_every_n_layers` argument.* 111 | 112 | Note: as part of our v2 release, we have deprecated a previous LLaMA-based checkpoint. However, you can continue to use our older checkpoint using the new codebase. 113 | 114 | ## Downloading pretrained weights 115 | 116 | To instantiate an OpenFlamingo model with one of our released weights, initialize the model as above and use the following code. 117 | 118 | ```python 119 | # grab model checkpoint from huggingface hub 120 | from huggingface_hub import hf_hub_download 121 | import torch 122 | 123 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") 124 | model.load_state_dict(torch.load(checkpoint_path), strict=False) 125 | ``` 126 | 127 | ## Generating text 128 | Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning. 
129 | 130 | ``` python 131 | from PIL import Image 132 | import requests 133 | import torch 134 | 135 | """ 136 | Step 1: Load images 137 | """ 138 | demo_image_one = Image.open( 139 | requests.get( 140 | "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True 141 | ).raw 142 | ) 143 | 144 | demo_image_two = Image.open( 145 | requests.get( 146 | "http://images.cocodataset.org/test-stuff2017/000000028137.jpg", 147 | stream=True 148 | ).raw 149 | ) 150 | 151 | query_image = Image.open( 152 | requests.get( 153 | "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 154 | stream=True 155 | ).raw 156 | ) 157 | 158 | 159 | """ 160 | Step 2: Preprocessing images 161 | Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 162 | batch_size x num_media x num_frames x channels x height x width. 163 | In this case batch_size = 1, num_media = 3, num_frames = 1, 164 | channels = 3, height = 224, width = 224. 165 | """ 166 | vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)] 167 | vision_x = torch.cat(vision_x, dim=0) 168 | vision_x = vision_x.unsqueeze(1).unsqueeze(0) 169 | 170 | """ 171 | Step 3: Preprocessing text 172 | Details: In the text we expect an <image> special token to indicate where an image is. 173 | We also expect an <|endofchunk|> special token to indicate the end of the text 174 | portion associated with an image.
175 | """ 176 | tokenizer.padding_side = "left" # For generation padding tokens should be on the left 177 | lang_x = tokenizer( 178 | ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"], 179 | return_tensors="pt", 180 | ) 181 | 182 | 183 | """ 184 | Step 4: Generate text 185 | """ 186 | generated_text = model.generate( 187 | vision_x=vision_x, 188 | lang_x=lang_x["input_ids"], 189 | attention_mask=lang_x["attention_mask"], 190 | max_new_tokens=20, 191 | num_beams=3, 192 | ) 193 | 194 | print("Generated text: ", tokenizer.decode(generated_text[0])) 195 | ``` 196 | 197 | # Training 198 | We provide training scripts in `open_flamingo/train`. We provide an example Slurm script in `open_flamingo/scripts/run_train.py`, as well as the following example command: 199 | ``` 200 | torchrun --nnodes=1 --nproc_per_node=4 open_flamingo/train/train.py \ 201 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \ 202 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ 203 | --cross_attn_every_n_layers 1 \ 204 | --dataset_resampled \ 205 | --batch_size_mmc4 32 \ 206 | --batch_size_laion 64 \ 207 | --train_num_samples_mmc4 125000 \ 208 | --train_num_samples_laion 250000 \ 209 | --loss_multiplier_laion 0.2 \ 210 | --workers=4 \ 211 | --run_name OpenFlamingo-3B-vitl-mpt1b \ 212 | --num_epochs 480 \ 213 | --warmup_steps 1875 \ 214 | --mmc4_textsim_threshold 0.24 \ 215 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ 216 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ 217 | --report_to_wandb 218 | ``` 219 | 220 | *Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase.
We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).* 221 | 222 | For more details, see our [training README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/train). 223 | 224 | 225 | # Evaluation 226 | An example evaluation script is at `open_flamingo/scripts/run_eval.sh`. Please see our [evaluation README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/eval) for more details. 227 | 228 | 229 | To run evaluations on OKVQA you will need to run the following command: 230 | ``` 231 | import nltk 232 | nltk.download('wordnet') 233 | ``` 234 | 235 | 236 | # Future plans 237 | - [ ] Add support for video input 238 | 239 | # Team 240 | 241 | OpenFlamingo is developed by: 242 | 243 | [Anas Awadalla*](https://anas-awadalla.streamlit.app/), [Irena Gao*](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Shiori Sagawa](https://cs.stanford.edu/~ssagawa/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/). 244 | 245 | The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google. 
246 | 247 | # Acknowledgments 248 | This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design. 249 | 250 | We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models. 
251 | 252 | # Citing 253 | If you found this repository useful, please consider citing: 254 | 255 | ``` 256 | @article{awadalla2023openflamingo, 257 | title={OpenFlamingo: An Open-Source Framework for Training Large Autoregressive Vision-Language Models}, 258 | author={Anas Awadalla and Irena Gao and Josh Gardner and Jack Hessel and Yusuf Hanafy and Wanrong Zhu and Kalyani Marathe and Yonatan Bitton and Samir Gadre and Shiori Sagawa and Jenia Jitsev and Simon Kornblith and Pang Wei Koh and Gabriel Ilharco and Mitchell Wortsman and Ludwig Schmidt}, 259 | journal={arXiv preprint arXiv:2308.01390}, 260 | year={2023} 261 | } 262 | ``` 263 | 264 | ``` 265 | @software{anas_awadalla_2023_7733589, 266 | author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig}, 267 | title = {OpenFlamingo}, 268 | month = mar, 269 | year = 2023, 270 | publisher = {Zenodo}, 271 | version = {v0.1.1}, 272 | doi = {10.5281/zenodo.7733589}, 273 | url = {https://doi.org/10.5281/zenodo.7733589} 274 | } 275 | ``` 276 | 277 | ``` 278 | @article{Alayrac2022FlamingoAV, 279 | title={Flamingo: a Visual Language Model for Few-Shot Learning}, 280 | author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan}, 281 | journal={ArXiv}, 282 | year={2022}, 283 | volume={abs/2204.14198} 284 | } 285 | ``` 286 | 287 | 288 | 
-------------------------------------------------------------------------------- /open_flamingo/open_flamingo/__init__.py: -------------------------------------------------------------------------------- 1 | from .src.flamingo import Flamingo 2 | from .src.factory import create_model_and_transforms 3 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/README.md: -------------------------------------------------------------------------------- 1 | # OpenFlamingo Evaluation Suite 2 | 3 | This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets. 4 | 5 | *This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).* 6 | 7 | ## Supported datasets 8 | 9 | |Dataset|Task|Metric|Evaluation method| 10 | |-------|----|------|-----------------| 11 | |[COCO](https://arxiv.org/abs/1405.0312)|Captioning|CIDEr|Generation| 12 | |[Flickr-30K](https://aclanthology.org/Q14-1006/)|Captioning|CIDEr|Generation| 13 | |[VQAv2](https://arxiv.org/abs/1612.00837v3)|VQA|VQA accuracy|Generation| 14 | |[OK-VQA](https://arxiv.org/abs/1906.00067)|VQA|VQA accuracy|Generation| 15 | |[TextVQA](https://arxiv.org/abs/1904.08920)|VQA|VQA accuracy|Generation| 16 | |[VizWiz](https://arxiv.org/abs/1802.08218)|VQA|VQA accuracy|Generation| 17 | |[Hateful Memes](https://arxiv.org/abs/2005.04790)|Classification|ROC AUC|Logprobs| 18 | |[ImageNet](https://arxiv.org/abs/1409.0575)|Classification|Top-1 accuracy|Logprobs| 19 | 20 | When evaluating a model using `num_shots` shots, we sample the exemplars from the training split. Performance is evaluated on a disjoint test split, subsampled to `--num_samples` examples (or using the full test split if `--num_samples=-1`). 
21 | 22 | ## Sample scripts 23 | Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun`. We provide a sample Slurm evaluation script in `open_flamingo/open_flamingo/scripts/run_eval.sh`. 24 | 25 | We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16. 26 | 27 | To evaluate one of our pretrained checkpoints, we suggest first downloading a local copy of the weights, as follows: 28 | 29 | ``` 30 | # grab model checkpoint from huggingface hub 31 | from huggingface_hub import hf_hub_download 32 | HF_TOKEN="" 33 | 34 | checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt") 35 | checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", 36 | "checkpoint.pt", 37 | local_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 38 | cache_dir="openflamingo/OpenFlamingo-3B-vitl-mpt1b", 39 | local_dir_use_symlinks=False, 40 | token=HF_TOKEN) 41 | print(checkpoint_path) 42 | ## openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt 43 | ``` 44 | 45 | This should place the OpenFlamingo model at the expected location in the evaluation script. 46 | 47 | For TextVQA and VizWiz we expect annotations to be formatted differently than the original datasets. We provide the custom annotations in `open_flamingo/open_flamingo/eval/data/`. We have also uploaded all the annotation files in a [huggingface dataset](https://huggingface.co/datasets/openflamingo/eval_benchmark/tree/main) for easy access. 48 | 49 | # Evaluating using RICES (Retrieval-based In-Context Example Selection) 50 | 51 | We provide the option to evaluate using RICES, which is a method for selecting exemplars from the training set based on image similarity. 
This method was used in DeepMind's implementation for evaluating on ImageNet, but can be used for any dataset in our evaluation suite. 52 | 53 | To use RICES, you must first create features for a benchmark's training set. We provide a script for doing so in `open_flamingo/open_flamingo/scripts/cache_rices_features.py`. This script will extract image features for a given dataset using a given CLIP model checkpoint. For example, to extract features for the COCO training set, you can run: 54 | 55 | ```bash 56 | python cache_rices_features.py \ 57 | --vision_encoder_path ViT-L-14 \ 58 | --vision_encoder_pretrained openai \ 59 | --batch_size 128 \ 60 | --eval_coco \ 61 | --coco_train_image_dir_path /path/to/coco/train2014 \ 62 | --coco_val_image_dir_path /path/to/coco/val2014 \ 63 | --coco_karpathy_json_path /path/to/coco/dataset_coco.json \ 64 | --coco_annotations_json_path /path/to/coco/annotations/captions_train2014.json \ 65 | --output_dir /path/to/coco/features 66 | ``` 67 | 68 | This will create a directory at `/path/to/coco/features` containing a file named `coco.pkl` with the extracted features. You can then use this directory to evaluate using RICES by passing the `--rices` flag to the evaluation script, specifying the path to the features directory using the `--cached_demonstration_features` flag, and specifying the vision encoder to use for RICES using the `--rices_vision_encoder_path` and `--rices_vision_encoder_pretrained` flags. 
69 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/coco_metric.py: -------------------------------------------------------------------------------- 1 | from pycocoevalcap.eval import COCOEvalCap 2 | from pycocotools.coco import COCO 3 | 4 | 5 | def compute_cider( 6 | result_path, 7 | annotations_path, 8 | ): 9 | # create coco object and coco_result object 10 | coco = COCO(annotations_path) 11 | coco_result = coco.loadRes(result_path) 12 | 13 | # create coco_eval object by taking coco and coco_result 14 | coco_eval = COCOEvalCap(coco, coco_result) 15 | coco_eval.params["image_id"] = coco_result.getImgIds() 16 | coco_eval.evaluate() 17 | 18 | return coco_eval.eval 19 | 20 | 21 | def postprocess_captioning_generation(predictions): 22 | return predictions.split("Output", 1)[0] 23 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/eval_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets import ImageFolder 7 | 8 | from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES 9 | 10 | 11 | class CaptionDataset(Dataset): 12 | def __init__( 13 | self, 14 | image_train_dir_path, 15 | annotations_path, 16 | is_train, 17 | dataset_name, 18 | image_val_dir_path=None, 19 | ): 20 | self.image_train_dir_path = image_train_dir_path 21 | self.image_val_dir_path = image_val_dir_path 22 | self.annotations = [] 23 | self.is_train = is_train 24 | self.dataset_name = dataset_name 25 | 26 | full_annotations = json.load(open(annotations_path))["images"] 27 | 28 | 
for i in range(len(full_annotations)): 29 | if self.is_train and full_annotations[i]["split"] != "train": 30 | continue 31 | elif not self.is_train and full_annotations[i]["split"] != "test": 32 | continue 33 | 34 | self.annotations.append(full_annotations[i]) 35 | 36 | def __len__(self): 37 | return len(self.annotations) 38 | 39 | def __getitem__(self, idx): 40 | if self.dataset_name == "coco": 41 | image = Image.open( 42 | os.path.join( 43 | self.image_train_dir_path, self.annotations[idx]["filename"] 44 | ) 45 | if self.annotations[idx]["filepath"] == "train2014" 46 | else os.path.join( 47 | self.image_val_dir_path, self.annotations[idx]["filename"] 48 | ) 49 | ) 50 | elif self.dataset_name == "flickr": 51 | image = Image.open( 52 | os.path.join( 53 | self.image_train_dir_path, self.annotations[idx]["filename"] 54 | ) 55 | ) 56 | image.load() 57 | caption = self.annotations[idx]["sentences"][0]["raw"] 58 | return { 59 | "image": image, 60 | "caption": caption, 61 | "image_id": self.annotations[idx]["cocoid"] 62 | if self.dataset_name == "coco" 63 | else self.annotations[idx]["filename"].split(".")[0], 64 | } 65 | 66 | 67 | class VQADataset(Dataset): 68 | def __init__( 69 | self, image_dir_path, question_path, annotations_path, is_train, dataset_name 70 | ): 71 | self.questions = json.load(open(question_path, "r"))["questions"] 72 | if annotations_path is not None: 73 | self.answers = json.load(open(annotations_path, "r"))["annotations"] 74 | else: 75 | self.answers = None 76 | self.image_dir_path = image_dir_path 77 | self.is_train = is_train 78 | self.dataset_name = dataset_name 79 | if self.dataset_name in {"vqav2", "ok_vqa"}: 80 | self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] 81 | assert self.img_coco_split in {"train2014", "val2014", "test2015"} 82 | 83 | def __len__(self): 84 | return len(self.questions) 85 | 86 | def get_img_path(self, question): 87 | if self.dataset_name in {"vqav2", "ok_vqa"}: 88 | return os.path.join( 89 | 
self.image_dir_path, 90 | f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" 91 | if self.is_train 92 | else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg", 93 | ) 94 | elif self.dataset_name == "vizwiz": 95 | return os.path.join(self.image_dir_path, question["image_id"]) 96 | elif self.dataset_name == "textvqa": 97 | return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") 98 | else: 99 | raise Exception(f"Unknown VQA dataset {self.dataset_name}") 100 | 101 | def __getitem__(self, idx): 102 | question = self.questions[idx] 103 | img_path = self.get_img_path(question) 104 | image = Image.open(img_path) 105 | image.load() 106 | results = { 107 | "image": image, 108 | "question": question["question"], 109 | "question_id": question["question_id"], 110 | } 111 | if self.answers is not None: 112 | answers = self.answers[idx] 113 | results["answers"] = [a["answer"] for a in answers["answers"]] 114 | return results 115 | 116 | 117 | class ImageNetDataset(ImageFolder): 118 | """Class to represent the ImageNet1k dataset.""" 119 | 120 | def __init__(self, root, **kwargs): 121 | super().__init__(root=root, **kwargs) 122 | self.class_id_to_name = dict( 123 | zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES) 124 | ) 125 | 126 | def __getitem__(self, idx): 127 | sample, target = super().__getitem__(idx) 128 | target_label = self.class_id_to_name[target] 129 | return { 130 | "id": idx, 131 | "image": sample, 132 | "class_id": target, # numeric ID of the ImageNet class 133 | "class_name": target_label, # human-readable name of ImageNet class 134 | } 135 | 136 | 137 | class HatefulMemesDataset(Dataset): 138 | def __init__(self, image_dir_path, annotations_path): 139 | self.image_dir_path = image_dir_path 140 | with open(annotations_path, "r") as f: 141 | self.annotations = [json.loads(line) for line in f] 142 | 143 | def __len__(self): 144 | return len(self.annotations) 145 | 146 | def __getitem__(self, idx): 147 | annotation = 
self.annotations[idx] 148 | img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1]) 149 | image = Image.open(img_path) 150 | image.load() 151 | return { 152 | "id": annotation["id"], 153 | "image": image, 154 | "ocr": annotation["text"], 155 | "class_name": "yes" if annotation["label"] == 1 else "no", 156 | "class_id": annotation["label"], 157 | } 158 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/eval_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import argparse 3 | from typing import List 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from PIL import Image 6 | 7 | 8 | class BaseEvalModel(abc.ABC): 9 | """Base class encapsulating functionality needed to evaluate a model.""" 10 | 11 | def __init__(self, args: List[str]): 12 | """Initialize model. 13 | 14 | Args: 15 | args: arguments to model. These should be parsed, or if the model 16 | has no applicable arguments, an error should be thrown if `args` 17 | is non-empty. 18 | """ 19 | 20 | def init_distributed(self): 21 | """Wrap model as DDP.""" 22 | self.model = DDP(self.model, device_ids=[self.device]) 23 | 24 | def set_device(self, device): 25 | """Set device for model.""" 26 | self.device = device 27 | self.model = self.model.to(device) 28 | 29 | def get_outputs( 30 | self, 31 | batch_text: List[str], 32 | batch_images: List[List[Image.Image]], 33 | min_generation_length: int, 34 | max_generation_length: int, 35 | num_beams: int, 36 | length_penalty: float, 37 | ) -> List[str]: 38 | """Get outputs for a batch of images and text. 39 | 40 | Args: 41 | batch_text: list of text strings, with the text "" in place 42 | of any images to be included. 43 | batch_images: images to provide to model. Should be a list of lists, 44 | where each list contains the images for a single example. 
45 | max_generation_length: maximum length of the generated caption. 46 | Defaults to 10. 47 | num_beams: number of beams to use for beam search. Defaults to 3. 48 | length_penalty: length penalty for beam search. Defaults to -2.0. 49 | 50 | Returns: 51 | List of decoded output strings. 52 | """ 53 | 54 | def vqa_prompt(self, question, answer=None) -> str: 55 | """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model. 56 | 57 | Returns: 58 | The prompt to use for VQA. 59 | """ 60 | 61 | def caption_prompt(self, caption=None) -> str: 62 | """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model. 63 | 64 | Returns: 65 | The prompt to use for captioning. 66 | """ 67 | 68 | def get_rank_classifications( 69 | self, 70 | batch_text: List[str], 71 | batch_images: List[List[Image.Image]], 72 | all_class_names: List[str], 73 | use_cache: bool, 74 | normalize_length: bool, 75 | ): 76 | """ 77 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. 78 | Args: 79 | batch_text: list of text strings, with the text "" in place 80 | of any images to be included. 81 | batch_images: images to provide to model. Should be a list of lists, 82 | where each list contains the images for a single example. 83 | all_class_names: list of all class names. 84 | use_cache: whether to cache the context to speed up evaluations. 85 | normalize_length: whether to normalize logprobs by the length of the 86 | class name 87 | Returns: 88 | (B, |all_class_names|) tensor containing the logprobs for each class name. 
89 | """ 90 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/models/blip.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from PIL import Image 4 | import torch 5 | 6 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 7 | from open_flamingo.eval.eval_model import BaseEvalModel 8 | from open_flamingo.eval.utils import unwrap_model 9 | 10 | 11 | class EvalModel(BaseEvalModel): 12 | """BLIP-2 model evaluation. 13 | 14 | Attributes: 15 | model (nn.Module): Underlying Torch model. 16 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. 17 | device: Index of GPU to use, or the string "cpu" 18 | """ 19 | 20 | def __init__(self, model_args): 21 | assert ( 22 | "processor_path" in model_args and "lm_path" in model_args 23 | ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" 24 | 25 | self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) 26 | self.model = Blip2ForConditionalGeneration.from_pretrained( 27 | model_args["lm_path"] 28 | ) 29 | self.model.eval() 30 | self.processor.tokenizer.padding_side = "left" 31 | self.lm_name = model_args["lm_path"].split("/")[-1] 32 | 33 | def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor: 34 | """Preprocess images and stack them. 35 | 36 | Args: 37 | batch: A list of lists of images. 38 | 39 | Returns: 40 | A Tensor of shape 41 | (batch_size, channels, height, width). 
42 | """ 43 | batch_images = None 44 | assert all( 45 | len(example) == 1 for example in batch 46 | ), "BLIP-2 only supports one image per example" 47 | 48 | for example in batch: 49 | assert len(example) == 1, "BLIP-2 only supports one image per example" 50 | batch_images = torch.cat( 51 | [ 52 | batch_images, 53 | self.processor.image_processor(example, return_tensors="pt")[ 54 | "pixel_values" 55 | ], 56 | ] 57 | if batch_images is not None 58 | else [ 59 | self.processor.image_processor(example, return_tensors="pt")[ 60 | "pixel_values" 61 | ] 62 | ], 63 | dim=0, 64 | ) 65 | return batch_images 66 | 67 | def get_outputs( 68 | self, 69 | batch_text: List[str], 70 | batch_images: List[List[Image.Image]], 71 | min_generation_length: int, 72 | max_generation_length: int, 73 | num_beams: int, 74 | length_penalty: float, 75 | ) -> List[str]: 76 | encodings = self.processor.tokenizer( 77 | batch_text, 78 | padding="longest", 79 | truncation=True, 80 | return_tensors="pt", 81 | max_length=2000, 82 | ) 83 | input_ids = encodings["input_ids"] 84 | attention_mask = encodings["attention_mask"] 85 | 86 | with torch.inference_mode(): 87 | outputs = unwrap_model(self.model).generate( 88 | self._prepare_images(batch_images).to(self.device), 89 | input_ids.to(self.device), 90 | attention_mask=attention_mask.to(self.device), 91 | max_new_tokens=max_generation_length, 92 | min_new_tokens=min_generation_length, 93 | num_beams=num_beams, 94 | length_penalty=length_penalty, 95 | ) 96 | 97 | return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) 98 | 99 | def get_vqa_prompt(self, question, answer=None) -> str: 100 | return ( 101 | f"Question:{question} Short answer:{answer if answer is not None else ''}" 102 | ) 103 | 104 | def get_caption_prompt(self, caption=None) -> str: 105 | return f"A photo of {caption if caption is not None else ''}" 106 | 107 | def get_rank_classifications( 108 | self, 109 | batch_text: List[str], 110 | batch_images: 
List[List[Image.Image]], 111 | all_class_names: List[str], 112 | use_cache: bool, 113 | normalize_length: bool, 114 | ): 115 | raise NotImplementedError( 116 | "BLIP-2 classification-based evaluation not implemented" 117 | ) 118 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/models/open_flamingo.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from PIL import Image 4 | import torch 5 | from einops import repeat 6 | 7 | from open_flamingo.eval.eval_model import BaseEvalModel 8 | from open_flamingo.src.factory import create_model_and_transforms 9 | from open_flamingo.eval.utils import unwrap_model, get_autocast, get_cast_dtype 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | 12 | 13 | class EvalModel(BaseEvalModel): 14 | """OpenFlamingo model evaluation. 15 | 16 | Attributes: 17 | model (nn.Module): Underlying Torch model. 18 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. 
19 | device: Index of GPU to use, or the string "CPU" 20 | """ 21 | 22 | def __init__(self, model_args): 23 | assert ( 24 | "vision_encoder_path" in model_args 25 | and "lm_path" in model_args 26 | and "checkpoint_path" in model_args 27 | and "lm_tokenizer_path" in model_args 28 | and "cross_attn_every_n_layers" in model_args 29 | and "vision_encoder_pretrained" in model_args 30 | and "precision" in model_args 31 | ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained, and precision arguments to be specified" 32 | 33 | self.device = ( 34 | model_args["device"] 35 | if ("device" in model_args and model_args["device"] >= 0) 36 | else "cpu" 37 | ) 38 | 39 | ( 40 | self.model, 41 | self.image_processor, 42 | self.tokenizer, 43 | ) = create_model_and_transforms( 44 | model_args["vision_encoder_path"], 45 | model_args["vision_encoder_pretrained"], 46 | model_args["lm_path"], 47 | model_args["lm_tokenizer_path"], 48 | cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), 49 | ) 50 | checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) 51 | if "model_state_dict" in checkpoint: 52 | checkpoint = checkpoint["model_state_dict"] 53 | checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} 54 | self.model.load_state_dict(checkpoint, strict=False) 55 | self.model.to(self.device) 56 | self.model.eval() 57 | self.tokenizer.padding_side = "left" 58 | 59 | self.lm_name = model_args["lm_path"].split("/")[-1] 60 | 61 | # autocast 62 | self.autocast = get_autocast(model_args["precision"]) 63 | self.cast_dtype = get_cast_dtype(model_args["precision"]) 64 | 65 | def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: 66 | """ 67 | Convert images to tensors, reshape them, and stack them. 68 | Args: 69 | batch: A list of lists of images. 
70 | Returns: 71 | preprocessed images (tensors) or None 72 | shape (B, T_img, F, C, H, W) 73 | None if no images in batch 74 | """ 75 | images_per_example = max(len(x) for x in batch) 76 | batch_images = None 77 | for iexample, example in enumerate(batch): 78 | for iimage, image in enumerate(example): 79 | preprocessed = self.image_processor(image) 80 | if batch_images is None: 81 | batch_images = torch.zeros( 82 | (len(batch), images_per_example, 1) + preprocessed.shape, 83 | dtype=preprocessed.dtype, 84 | ) 85 | batch_images[iexample, iimage, 0] = preprocessed 86 | if batch_images is not None: 87 | batch_images = batch_images.to( 88 | self.device, dtype=self.cast_dtype, non_blocking=True 89 | ) 90 | return batch_images 91 | 92 | def _prepare_text( 93 | self, 94 | batch: List[List[str]], 95 | padding="longest", 96 | truncation=True, 97 | max_length=2000, 98 | ): 99 | """ 100 | Tokenize the text and stack them. 101 | Args: 102 | batch: A list of lists of strings. 103 | Returns: 104 | input_ids (tensor) 105 | shape (B, T_txt) 106 | attention_mask (tensor) 107 | shape (B, T_txt) 108 | """ 109 | encodings = self.tokenizer( 110 | batch, 111 | padding=padding, 112 | truncation=truncation, 113 | return_tensors="pt", 114 | max_length=max_length, 115 | ) 116 | input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] 117 | input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True) 118 | attention_mask = attention_mask.to( 119 | self.device, dtype=self.cast_dtype, non_blocking=True 120 | ) 121 | return input_ids, attention_mask.bool() 122 | 123 | def get_outputs( 124 | self, 125 | batch_text: List[str], 126 | batch_images: List[List[Image.Image]], 127 | min_generation_length: int, 128 | max_generation_length: int, 129 | num_beams: int, 130 | length_penalty: float, 131 | ) -> List[str]: 132 | """ 133 | Get generation outputs. 
134 | """ 135 | batch_images = self._prepare_images(batch_images) 136 | input_ids, attention_mask = self._prepare_text(batch_text) 137 | 138 | with torch.inference_mode(): 139 | with self.autocast(): 140 | outputs = unwrap_model(self.model).generate( 141 | batch_images, 142 | input_ids, 143 | attention_mask, 144 | min_new_tokens=min_generation_length, 145 | max_new_tokens=max_generation_length, 146 | num_beams=num_beams, 147 | length_penalty=length_penalty, 148 | ) 149 | 150 | # Extract only the new gnerated tokens 151 | outputs = outputs[:, len(input_ids[0]) :] 152 | 153 | return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 154 | 155 | def get_rank_classifications( 156 | self, 157 | batch_text: List[str], 158 | batch_images: List[List[Image.Image]], 159 | all_class_names: List[str], 160 | use_cache: bool, 161 | normalize_length: bool, 162 | ): 163 | """ 164 | Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. 165 | """ 166 | batch_images = self._prepare_images(batch_images) 167 | ctx_input_ids, ctx_attention_mask = self._prepare_text(batch_text) 168 | 169 | # Cache the context 170 | if use_cache: 171 | # reserve the last token in the context for the main forward pass 172 | self.cache_media( 173 | input_ids=ctx_input_ids, 174 | vision_x=batch_images, 175 | ) 176 | precomputed = self.__call__( 177 | vision_x=None, 178 | lang_x=ctx_input_ids, 179 | attention_mask=ctx_attention_mask, 180 | clear_conditioned_layers=False, 181 | use_cache=True, 182 | ) 183 | precomputed_logits = precomputed.logits 184 | precomputed_pkvs = precomputed.past_key_values 185 | else: 186 | precomputed_pkvs = None 187 | 188 | # Loop through class names and get log-likelihoods 189 | # Note: if all classnames are one token, this code is redundant, since we could 190 | # get all logits after one pass. However, if there are multi-token classnames, 191 | # we need to loop through each classname separately. 
192 | overall_probs = [] 193 | for class_name in all_class_names: 194 | # Tokenize only the class name 195 | classname_tokens = self.tokenizer( 196 | class_name, add_special_tokens=False, return_tensors="pt" 197 | )["input_ids"].to(self.device) 198 | assert classname_tokens.ndim == 2 199 | classname_tokens = repeat( 200 | classname_tokens, "b s -> (repeat b) s", repeat=len(batch_text) 201 | ) 202 | num_tokens_in_classname = classname_tokens.shape[1] 203 | 204 | # Concatenate the class name tokens 205 | if not use_cache: 206 | _lang_x = torch.cat([ctx_input_ids, classname_tokens], dim=1) 207 | _attention_mask = torch.cat( 208 | [ 209 | ctx_attention_mask, 210 | torch.ones_like(classname_tokens).bool(), 211 | ], 212 | dim=1, 213 | ) 214 | _vision_x = batch_images 215 | else: 216 | _lang_x = classname_tokens 217 | _attention_mask = None 218 | _vision_x = None 219 | 220 | # Call forward to get the logits 221 | outputs = self.__call__( 222 | vision_x=_vision_x, 223 | lang_x=_lang_x, 224 | attention_mask=_attention_mask, 225 | clear_conditioned_layers=(not use_cache), 226 | past_key_values=precomputed_pkvs, 227 | ) 228 | 229 | # Get the logits of the classname 230 | # logits shape is either (B, num_tokens_in_classname, vocab_len) with use_cache 231 | # or (B, len(_lang_x), vocab_len) without use_cache 232 | # remember that the logits at index t on dim 1 correspond to predictions for the t+1st token 233 | logits = outputs.logits 234 | if use_cache: 235 | logits = torch.cat([precomputed_logits, logits], dim=1) 236 | 237 | logprobs = torch.log_softmax(logits, dim=-1) 238 | gen_probs = logprobs[ 239 | :, -num_tokens_in_classname - 1 : -1, : 240 | ] # (B, num_tokens_in_classname, vocab_len) 241 | gen_probs = torch.gather( 242 | gen_probs, 2, classname_tokens[:, :, None] 243 | ).squeeze(-1) 244 | 245 | # Aggregate over tokens in the classname 246 | if normalize_length: 247 | class_prob = torch.mean(gen_probs, dim=1) 248 | else: 249 | class_prob = torch.sum(gen_probs, dim=1) 
250 | overall_probs.append(class_prob) # (B, 1) 251 | 252 | self.uncache_media() 253 | overall_probs = torch.vstack(overall_probs).T.cpu() # shape (B, num_classes) 254 | return overall_probs 255 | 256 | def __call__( 257 | self, 258 | lang_x: torch.Tensor, 259 | vision_x: torch.Tensor, 260 | attention_mask: torch.Tensor, 261 | past_key_values: torch.Tensor = None, 262 | clear_conditioned_layers: bool = False, 263 | use_cache: bool = False, 264 | ): 265 | """ 266 | Calls the forward function of the model. 267 | Special logic to handle the case if past_key_values is not None: 268 | then lang_x is assumed to contain the tokens to be generated 269 | *excluding* the tokens already in past_key_values. 270 | We then repeatedly call forward, updating the past_key_values. 271 | """ 272 | # standard forward pass 273 | if past_key_values is None: 274 | with torch.inference_mode(): 275 | with self.autocast(): 276 | outputs = self.model( 277 | vision_x=vision_x, 278 | lang_x=lang_x, 279 | attention_mask=attention_mask, 280 | clear_conditioned_layers=clear_conditioned_layers, 281 | past_key_values=past_key_values, 282 | use_cache=use_cache, 283 | ) 284 | return outputs 285 | 286 | # loop to handle updating past_key_values 287 | logits = [] 288 | for token_idx in range(lang_x.shape[1]): 289 | _lang_x = lang_x[:, token_idx].reshape((-1, 1)) 290 | if attention_mask is not None: 291 | _attention_mask = attention_mask[:, token_idx].reshape((-1, 1)) 292 | else: 293 | _attention_mask = None 294 | 295 | with torch.inference_mode(): 296 | with self.autocast(): 297 | outputs = self.model( 298 | vision_x=vision_x, 299 | lang_x=_lang_x, 300 | attention_mask=_attention_mask, 301 | clear_conditioned_layers=False, 302 | past_key_values=past_key_values, 303 | use_cache=True, 304 | ) 305 | 306 | past_key_values = outputs.past_key_values 307 | logits.append(outputs.logits) 308 | 309 | logits = torch.cat(logits, dim=1) 310 | return CausalLMOutputWithPast( 311 | logits=logits, 312 | 
past_key_values=past_key_values, 313 | ) 314 | 315 | def encode_vision_x(self, image_tensor: torch.Tensor): 316 | unwrap_model(self.model)._encode_vision_x(image_tensor.to(self.device)) 317 | 318 | def uncache_media(self): 319 | unwrap_model(self.model).uncache_media() 320 | 321 | def cache_media(self, input_ids, vision_x): 322 | unwrap_model(self.model).cache_media(input_ids=input_ids, vision_x=vision_x) 323 | 324 | def get_vqa_prompt(self, question, answer=None) -> str: 325 | return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" 326 | 327 | def get_caption_prompt(self, caption=None) -> str: 328 | return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" 329 | 330 | def get_imagenet_prompt(self, label=None) -> str: 331 | return f"Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" 332 | 333 | def get_hateful_memes_prompt(self, text, label=None) -> str: 334 | return f"is an image with: '{text}' written on it. Is it hateful? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" 335 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/ok_vqa_utils.py: -------------------------------------------------------------------------------- 1 | # Those are manual mapping that are not caught by our stemming rules or would 2 | # would be done incorrectly by our automatic stemming rule. In details, 3 | # the keys of the _MANUAL_MATCHES dict contains the original word and the value 4 | # contains the transformation of the word expected by the OKVQA stemming rule. 5 | # These manual rules were found by checking the `raw_answers` and the `answers` 6 | # fields of the released OKVQA dataset and checking all things that were not 7 | # properly mapped by our automatic rules. 
# In particular, some of the mappings are identity mappings, e.g.
# christmas -> christmas, which was incorrectly singularized by
# inflection.singularize.
import re
import nltk
from nltk.corpus.reader import VERB
import inflection

_MANUAL_MATCHES = {
    "police": "police",
    "las": "las",
    "vegas": "vegas",
    "yes": "yes",
    "jeans": "jean",
    "hell's": "hell",
    "domino's": "domino",
    "morning": "morn",
    "clothes": "cloth",
    "are": "are",
    "riding": "ride",
    "leaves": "leaf",
    "dangerous": "danger",
    "clothing": "cloth",
    "texting": "text",
    "kiting": "kite",
    "firefighters": "firefight",
    "ties": "tie",
    "married": "married",
    "teething": "teeth",
    "gloves": "glove",
    "tennis": "tennis",
    "dining": "dine",
    "directions": "direct",
    "waves": "wave",
    "christmas": "christmas",
    "drives": "drive",
    "pudding": "pud",
    "coding": "code",
    "plating": "plate",
    "quantas": "quanta",
    "hornes": "horn",
    "graves": "grave",
    "mating": "mate",
    "paned": "pane",
    "alertness": "alert",
    "sunbathing": "sunbath",
    "tenning": "ten",
    "wetness": "wet",
    "urinating": "urine",
    "sickness": "sick",
    "braves": "brave",
    "firefighting": "firefight",
    "lenses": "lens",
    "reflections": "reflect",
    "backpackers": "backpack",
    "eatting": "eat",
    "designers": "design",
    "curiousity": "curious",
    "playfulness": "play",
    "blindness": "blind",
    "hawke": "hawk",
    "tomatoe": "tomato",
    "rodeoing": "rodeo",
    "brightness": "bright",
    "circuses": "circus",
    "skateboarders": "skateboard",
    "staring": "stare",
    "electronics": "electron",
    "electicity": "elect",
    "mountainous": "mountain",
    "socializing": "social",
    "hamburgers": "hamburg",
    "caves": "cave",
    "transitions": "transit",
    "wading": "wade",
    "creame": "cream",
    "toileting": "toilet",
    "sautee": "saute",
    "buildings": "build",
    "belongings": "belong",
    "stockings": "stock",
    "walle": "wall",
    "cumulis": "cumuli",
    "travelers": "travel",
    "conducter": "conduct",
    "browsing": "brows",
    "pooping": "poop",
    "haircutting": "haircut",
    "toppings": "top",
    "hearding": "heard",
    "sunblocker": "sunblock",
    "bases": "base",
    "markings": "mark",
    "mopeds": "mope",
    "kindergartener": "kindergarten",
    "pies": "pie",
    "scrapbooking": "scrapbook",
    "couponing": "coupon",
    "meetings": "meet",
    "elevators": "elev",
    "lowes": "low",
    "men's": "men",
    "childrens": "children",
    "shelves": "shelve",
    "paintings": "paint",
    "raines": "rain",
    "paring": "pare",
    "expressions": "express",
    "routes": "rout",
    "pease": "peas",
    "vastness": "vast",
    "awning": "awn",
    "boy's": "boy",
    "drunkenness": "drunken",
    "teasing": "teas",
    "conferences": "confer",
    "ripeness": "ripe",
    "suspenders": "suspend",
    "earnings": "earn",
    "reporters": "report",
    "kid's": "kid",
    "containers": "contain",
    "corgie": "corgi",
    "porche": "porch",
    "microwaves": "microwave",
    "batter's": "batter",
    "sadness": "sad",
    "apartments": "apart",
    "oxygenize": "oxygen",
    "striping": "stripe",
    "purring": "pure",
    "professionals": "profession",
    "piping": "pipe",
    "farmer's": "farmer",
    "potatoe": "potato",
    "emirates": "emir",
    "womens": "women",
    "veteran's": "veteran",
    "wilderness": "wilder",
    "propellers": "propel",
    "alpes": "alp",
    "charioteering": "chariot",
    "swining": "swine",
    "illness": "ill",
    "crepte": "crept",
    "adhesives": "adhesive",
    "regent's": "regent",
    "decorations": "decor",
    "rabbies": "rabbi",
    "overseas": "oversea",
    "travellers": "travel",
    "casings": "case",
    "smugness": "smug",
    "doves": "dove",
    "nationals": "nation",
    "mustange": "mustang",
    "ringe": "ring",
    "gondoliere": "gondolier",
    "vacationing": "vacate",
    "reminders": "remind",
    "baldness": "bald",
    "settings": "set",
    "glaced": "glace",
    "coniferous": "conifer",
    "revelations": "revel",
    "personals": "person",
    "daughter's": "daughter",
    "badness": "bad",
    "projections": "project",
    "polarizing": "polar",
    "vandalizers": "vandal",
    "minerals": "miner",
    "protesters": "protest",
    "controllers": "control",
    "weddings": "wed",
    "sometimes": "sometime",
    "earing": "ear",
}

# Prefixes that mark the model starting a new prompt segment after its answer.
_ANSWER_STOP_RE = re.compile("Question|Answer|Short")
# A comma-space separates alternative answers; only the first one is kept.
_ANSWER_LIST_SEP_RE = re.compile(", ")


class OKVQAStemmer:
    """Stemmer to match OKVQA v1.1 procedure."""

    def __init__(self):
        self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()

    def stem(self, input_string):
        """Apply stemming.

        Each token is mapped, in priority order, by the manual table, by the
        WordNet verb lemmatizer (for words ending in "ing"), or by
        ``inflection.singularize`` (for plural-noun POS tags NNS/NNPS).
        Tokens are re-joined with single spaces.
        """
        word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
        stemmed_words = []
        for w, p in word_and_pos:
            if w in _MANUAL_MATCHES:
                w = _MANUAL_MATCHES[w]
            elif w.endswith("ing"):
                w = self._wordnet_lemmatizer.lemmatize(w, VERB)
            elif p.startswith(("NNS", "NNPS")):
                w = inflection.singularize(w)
            stemmed_words.append(w)
        return " ".join(stemmed_words)


stemmer = OKVQAStemmer()


def postprocess_ok_vqa_generation(predictions) -> str:
    """Truncate a raw OK-VQA generation to its first answer and stem it.

    ``maxsplit`` is passed by keyword: passing it positionally to ``re.split``
    is deprecated since Python 3.13.
    """
    prediction = _ANSWER_STOP_RE.split(predictions, maxsplit=1)[0]
    prediction = _ANSWER_LIST_SEP_RE.split(prediction, maxsplit=1)[0]
    prediction_stem = stemmer.stem(prediction)
    return prediction_stem


# ---------- /open_flamingo/open_flamingo/eval/rices.py:
-------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | from tqdm import tqdm 4 | import torch 5 | from utils import custom_collate_fn 6 | 7 | 8 | class RICES: 9 | def __init__( 10 | self, 11 | dataset, 12 | device, 13 | batch_size, 14 | vision_encoder_path="ViT-B-32", 15 | vision_encoder_pretrained="openai", 16 | cached_features=None, 17 | ): 18 | self.dataset = dataset 19 | self.device = device 20 | self.batch_size = batch_size 21 | 22 | # Load the model and processor 23 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 24 | vision_encoder_path, 25 | pretrained=vision_encoder_pretrained, 26 | ) 27 | self.model = vision_encoder.to(self.device) 28 | self.image_processor = image_processor 29 | 30 | # Precompute features 31 | if cached_features is None: 32 | self.features = self._precompute_features() 33 | else: 34 | self.features = cached_features 35 | 36 | def _precompute_features(self): 37 | features = [] 38 | 39 | # Switch to evaluation mode 40 | self.model.eval() 41 | 42 | # Set up loader 43 | loader = torch.utils.data.DataLoader( 44 | self.dataset, 45 | batch_size=self.batch_size, 46 | collate_fn=custom_collate_fn, 47 | ) 48 | 49 | with torch.no_grad(): 50 | for batch in tqdm( 51 | loader, 52 | desc="Precomputing features for RICES", 53 | ): 54 | batch = batch["image"] 55 | inputs = torch.stack( 56 | [self.image_processor(image) for image in batch] 57 | ).to(self.device) 58 | image_features = self.model.encode_image(inputs) 59 | image_features /= image_features.norm(dim=-1, keepdim=True) 60 | features.append(image_features.detach()) 61 | 62 | features = torch.cat(features) 63 | return features 64 | 65 | def find(self, batch, num_examples): 66 | """ 67 | Get the top num_examples most similar examples to the images. 
68 | """ 69 | # Switch to evaluation mode 70 | self.model.eval() 71 | 72 | with torch.no_grad(): 73 | inputs = torch.stack([self.image_processor(image) for image in batch]).to( 74 | self.device 75 | ) 76 | 77 | # Get the feature of the input image 78 | query_feature = self.model.encode_image(inputs) 79 | query_feature /= query_feature.norm(dim=-1, keepdim=True) 80 | query_feature = query_feature.detach().cpu() 81 | 82 | if query_feature.ndim == 1: 83 | query_feature = query_feature.unsqueeze(0) 84 | 85 | # Compute the similarity of the input image to the precomputed features 86 | similarity = (query_feature @ self.features.T).squeeze() 87 | 88 | if similarity.ndim == 1: 89 | similarity = similarity.unsqueeze(0) 90 | 91 | # Get the indices of the 'num_examples' most similar images 92 | indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples] 93 | 94 | # Return with the most similar images last 95 | return [[self.dataset[i] for i in reversed(row)] for row in indices] 96 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/eval/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import torch.nn as nn 5 | from contextlib import suppress 6 | 7 | 8 | def random_seed(seed=42, rank=0): 9 | torch.manual_seed(seed + rank) 10 | np.random.seed(seed + rank) 11 | random.seed(seed + rank) 12 | 13 | 14 | def custom_collate_fn(batch): 15 | """ 16 | Collate function for DataLoader that collates a list of dicts into a dict of lists. 17 | """ 18 | collated_batch = {} 19 | for key in batch[0].keys(): 20 | collated_batch[key] = [item[key] for item in batch] 21 | return collated_batch 22 | 23 | 24 | def compute_effective_num_shots(num_shots, model_type): 25 | """ 26 | Compute the effective number of shots for a given model type. 27 | For example, following Flamingo, 0-shot OF evaluations use two text-only shots. 
28 | """ 29 | if model_type == "open_flamingo": 30 | return num_shots if num_shots > 0 else 2 31 | return num_shots 32 | 33 | 34 | def sample_batch_demos_from_query_set(query_set, num_samples, batch_size): 35 | """ 36 | Sample random demonstrations from the query set. 37 | """ 38 | return [random.sample(query_set, num_samples) for _ in range(batch_size)] 39 | 40 | 41 | def get_query_set(train_dataset, query_set_size): 42 | """ 43 | Get a subset of the training dataset to use as the query set. 44 | """ 45 | query_set = np.random.choice(len(train_dataset), query_set_size, replace=False) 46 | return [train_dataset[i] for i in query_set] 47 | 48 | 49 | def prepare_eval_samples(test_dataset, num_samples, batch_size): 50 | """ 51 | Subset the test dataset and return a DataLoader. 52 | """ 53 | random_indices = np.random.choice(len(test_dataset), num_samples, replace=False) 54 | dataset = torch.utils.data.Subset(test_dataset, random_indices) 55 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 56 | loader = torch.utils.data.DataLoader( 57 | dataset, 58 | batch_size=batch_size, 59 | sampler=sampler, 60 | collate_fn=custom_collate_fn, 61 | ) 62 | return loader 63 | 64 | 65 | def get_indices_of_unique(x): 66 | """ 67 | Return the indices of x that correspond to unique elements. 68 | If value v is unique and two indices in x have value v, the first index is returned. 69 | """ 70 | unique_elements = torch.unique(x) 71 | first_indices = [] 72 | for v in unique_elements: 73 | indices = torch.where(x == v)[0] 74 | first_indices.append(indices[0]) # Take the first index for each unique element 75 | return torch.tensor(first_indices) 76 | 77 | 78 | def unwrap_model(model): 79 | """ 80 | Unwrap a model from a DataParallel or DistributedDataParallel wrapper. 
81 | """ 82 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): 83 | return model.module 84 | else: 85 | return model 86 | 87 | 88 | def get_predicted_classnames(logprobs, k, class_id_to_name): 89 | """ 90 | Args: 91 | - logprobs shape (B, Y) containing logprobs for each classname 92 | - k: number for top-k 93 | - class_id_to_name: dict mapping class index to classname 94 | 95 | Returns: 96 | - top-k predicted classnames shape (B, k) type str 97 | - top-k logprobs shape (B, k) type float 98 | """ 99 | # convert indices to classnames 100 | _, predictions = torch.topk(logprobs, k=k, dim=1) # shape (B, k) 101 | predicted_classnames = [ 102 | [class_id_to_name[ix] for ix in item] for item in predictions.tolist() 103 | ] 104 | predicted_logprobs = torch.gather(logprobs, 1, predictions) 105 | return predicted_classnames, predicted_logprobs 106 | 107 | 108 | def get_cast_dtype(precision: str): 109 | cast_dtype = None 110 | if precision == "bf16": 111 | cast_dtype = torch.bfloat16 112 | elif precision == "fp16": 113 | cast_dtype = torch.float16 114 | return cast_dtype 115 | 116 | 117 | def get_autocast(precision): 118 | if precision == "amp": 119 | return torch.cuda.amp.autocast 120 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 121 | # amp_bfloat16 is more stable than amp float16 for clip training 122 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 123 | else: 124 | return suppress 125 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/scripts/cache_rices_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cache CLIP features for all images in training split in preparation for RICES 3 | """ 4 | import argparse 5 | import sys 6 | import os 7 | 8 | sys.path.append( 9 | os.path.join( 10 | os.path.dirname(os.path.abspath(__file__)), 11 | "..", 12 | ) 13 | ) 14 | from eval.rices import RICES 15 | from 
eval.eval_datasets import ( 16 | CaptionDataset, 17 | VQADataset, 18 | ImageNetDataset, 19 | HatefulMemesDataset, 20 | ) 21 | import os 22 | import torch 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--output_dir", 27 | type=str, 28 | required=True, 29 | help="Directory to save the cached features.", 30 | ) 31 | parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) 32 | parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) 33 | parser.add_argument("--batch_size", default=256) 34 | 35 | # Per-dataset flags 36 | parser.add_argument( 37 | "--eval_coco", 38 | action="store_true", 39 | default=False, 40 | help="Whether to cache COCO.", 41 | ) 42 | parser.add_argument( 43 | "--eval_vqav2", 44 | action="store_true", 45 | default=False, 46 | help="Whether to cache VQAV2.", 47 | ) 48 | parser.add_argument( 49 | "--eval_ok_vqa", 50 | action="store_true", 51 | default=False, 52 | help="Whether to cache OK-VQA.", 53 | ) 54 | parser.add_argument( 55 | "--eval_vizwiz", 56 | action="store_true", 57 | default=False, 58 | help="Whether to cache VizWiz.", 59 | ) 60 | parser.add_argument( 61 | "--eval_textvqa", 62 | action="store_true", 63 | default=False, 64 | help="Whether to cache TextVQA.", 65 | ) 66 | parser.add_argument( 67 | "--eval_imagenet", 68 | action="store_true", 69 | default=False, 70 | help="Whether to cache ImageNet.", 71 | ) 72 | parser.add_argument( 73 | "--eval_flickr30", 74 | action="store_true", 75 | default=False, 76 | help="Whether to cache Flickr30.", 77 | ) 78 | parser.add_argument( 79 | "--eval_hateful_memes", 80 | action="store_true", 81 | default=False, 82 | help="Whether to cache Hateful Memes.", 83 | ) 84 | 85 | # Dataset arguments 86 | 87 | ## Flickr30 Dataset 88 | parser.add_argument( 89 | "--flickr_image_dir_path", 90 | type=str, 91 | help="Path to the flickr30/flickr30k_images directory.", 92 | default=None, 93 | ) 94 | parser.add_argument( 95 | 
"--flickr_karpathy_json_path", 96 | type=str, 97 | help="Path to the dataset_flickr30k.json file.", 98 | default=None, 99 | ) 100 | parser.add_argument( 101 | "--flickr_annotations_json_path", 102 | type=str, 103 | help="Path to the dataset_flickr30k_coco_style.json file.", 104 | ) 105 | ## COCO Dataset 106 | parser.add_argument( 107 | "--coco_train_image_dir_path", 108 | type=str, 109 | default=None, 110 | ) 111 | parser.add_argument( 112 | "--coco_val_image_dir_path", 113 | type=str, 114 | default=None, 115 | ) 116 | parser.add_argument( 117 | "--coco_karpathy_json_path", 118 | type=str, 119 | default=None, 120 | ) 121 | parser.add_argument( 122 | "--coco_annotations_json_path", 123 | type=str, 124 | default=None, 125 | ) 126 | 127 | ## VQAV2 Dataset 128 | parser.add_argument( 129 | "--vqav2_train_image_dir_path", 130 | type=str, 131 | default=None, 132 | ) 133 | parser.add_argument( 134 | "--vqav2_train_questions_json_path", 135 | type=str, 136 | default=None, 137 | ) 138 | parser.add_argument( 139 | "--vqav2_train_annotations_json_path", 140 | type=str, 141 | default=None, 142 | ) 143 | 144 | ## OK-VQA Dataset 145 | parser.add_argument( 146 | "--ok_vqa_train_image_dir_path", 147 | type=str, 148 | help="Path to the vqav2/train2014 directory.", 149 | default=None, 150 | ) 151 | parser.add_argument( 152 | "--ok_vqa_train_questions_json_path", 153 | type=str, 154 | help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.", 155 | default=None, 156 | ) 157 | parser.add_argument( 158 | "--ok_vqa_train_annotations_json_path", 159 | type=str, 160 | help="Path to the v2_mscoco_train2014_annotations.json file.", 161 | default=None, 162 | ) 163 | 164 | ## VizWiz Dataset 165 | parser.add_argument( 166 | "--vizwiz_train_image_dir_path", 167 | type=str, 168 | help="Path to the vizwiz train images directory.", 169 | default=None, 170 | ) 171 | parser.add_argument( 172 | "--vizwiz_train_questions_json_path", 173 | type=str, 174 | help="Path to the vizwiz questions 
json file.", 175 | default=None, 176 | ) 177 | parser.add_argument( 178 | "--vizwiz_train_annotations_json_path", 179 | type=str, 180 | help="Path to the vizwiz annotations json file.", 181 | default=None, 182 | ) 183 | 184 | # TextVQA Dataset 185 | parser.add_argument( 186 | "--textvqa_image_dir_path", 187 | type=str, 188 | help="Path to the textvqa images directory.", 189 | default=None, 190 | ) 191 | parser.add_argument( 192 | "--textvqa_train_questions_json_path", 193 | type=str, 194 | help="Path to the textvqa questions json file.", 195 | default=None, 196 | ) 197 | parser.add_argument( 198 | "--textvqa_train_annotations_json_path", 199 | type=str, 200 | help="Path to the textvqa annotations json file.", 201 | default=None, 202 | ) 203 | 204 | 205 | ## Imagenet dataset 206 | parser.add_argument("--imagenet_root", type=str, default="/tmp") 207 | 208 | ## Hateful Memes dataset 209 | parser.add_argument( 210 | "--hateful_memes_image_dir_path", 211 | type=str, 212 | default=None, 213 | ) 214 | parser.add_argument( 215 | "--hateful_memes_train_annotations_json_path", 216 | type=str, 217 | default=None, 218 | ) 219 | 220 | 221 | def main(): 222 | args, leftovers = parser.parse_known_args() 223 | device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 224 | if args.eval_flickr30: 225 | print("Caching Flickr30k...") 226 | train_dataset = CaptionDataset( 227 | image_train_dir_path=args.flickr_image_dir_path, 228 | image_val_dir_path=None, 229 | annotations_path=args.flickr_karpathy_json_path, 230 | is_train=True, 231 | dataset_name="flickr", 232 | ) 233 | rices_dataset = RICES( 234 | train_dataset, 235 | device_id, 236 | args.batch_size, 237 | vision_encoder_path=args.vision_encoder_path, 238 | vision_encoder_pretrained=args.vision_encoder_pretrained, 239 | ) 240 | torch.save( 241 | rices_dataset.features, 242 | os.path.join(args.output_dir, "flickr30.pkl"), 243 | ) 244 | 245 | if args.eval_coco: 246 | print("Caching COCO...") 247 | 
train_dataset = CaptionDataset( 248 | image_train_dir_path=args.coco_train_image_dir_path, 249 | image_val_dir_path=args.coco_val_image_dir_path, 250 | annotations_path=args.coco_karpathy_json_path, 251 | is_train=True, 252 | dataset_name="coco", 253 | ) 254 | rices_dataset = RICES( 255 | train_dataset, 256 | device_id, 257 | args.batch_size, 258 | vision_encoder_path=args.vision_encoder_path, 259 | vision_encoder_pretrained=args.vision_encoder_pretrained, 260 | ) 261 | torch.save( 262 | rices_dataset.features, 263 | os.path.join(args.output_dir, "coco.pkl"), 264 | ) 265 | 266 | if args.eval_ok_vqa: 267 | print("Caching OK-VQA...") 268 | train_dataset = VQADataset( 269 | image_dir_path=args.ok_vqa_train_image_dir_path, 270 | question_path=args.ok_vqa_train_questions_json_path, 271 | annotations_path=args.ok_vqa_train_annotations_json_path, 272 | is_train=True, 273 | dataset_name="ok_vqa", 274 | ) 275 | rices_dataset = RICES( 276 | train_dataset, 277 | device_id, 278 | args.batch_size, 279 | vision_encoder_path=args.vision_encoder_path, 280 | vision_encoder_pretrained=args.vision_encoder_pretrained, 281 | ) 282 | torch.save( 283 | rices_dataset.features, 284 | os.path.join(args.output_dir, "ok_vqa.pkl"), 285 | ) 286 | 287 | if args.eval_vizwiz: 288 | print("Caching VizWiz...") 289 | train_dataset = VQADataset( 290 | image_dir_path=args.vizwiz_train_image_dir_path, 291 | question_path=args.vizwiz_train_questions_json_path, 292 | annotations_path=args.vizwiz_train_annotations_json_path, 293 | is_train=True, 294 | dataset_name="vizwiz", 295 | ) 296 | rices_dataset = RICES( 297 | train_dataset, 298 | device_id, 299 | args.batch_size, 300 | vision_encoder_path=args.vision_encoder_path, 301 | vision_encoder_pretrained=args.vision_encoder_pretrained, 302 | ) 303 | torch.save( 304 | rices_dataset.features, 305 | os.path.join(args.output_dir, "vizwiz.pkl"), 306 | ) 307 | 308 | if args.eval_vqav2: 309 | print("Caching VQAv2...") 310 | train_dataset = VQADataset( 311 | 
image_dir_path=args.vqav2_train_image_dir_path, 312 | question_path=args.vqav2_train_questions_json_path, 313 | annotations_path=args.vqav2_train_annotations_json_path, 314 | is_train=True, 315 | dataset_name="vqav2", 316 | ) 317 | rices_dataset = RICES( 318 | train_dataset, 319 | device_id, 320 | args.batch_size, 321 | vision_encoder_path=args.vision_encoder_path, 322 | vision_encoder_pretrained=args.vision_encoder_pretrained, 323 | ) 324 | torch.save( 325 | rices_dataset.features, 326 | os.path.join(args.output_dir, "vqav2.pkl"), 327 | ) 328 | 329 | if args.eval_textvqa: 330 | print("Caching TextVQA...") 331 | train_dataset = VQADataset( 332 | image_dir_path=args.textvqa_image_dir_path, 333 | question_path=args.textvqa_train_questions_json_path, 334 | annotations_path=args.textvqa_train_annotations_json_path, 335 | is_train=True, 336 | dataset_name="textvqa", 337 | ) 338 | rices_dataset = RICES( 339 | train_dataset, 340 | device_id, 341 | args.batch_size, 342 | vision_encoder_path=args.vision_encoder_path, 343 | vision_encoder_pretrained=args.vision_encoder_pretrained, 344 | ) 345 | torch.save( 346 | rices_dataset.features, 347 | os.path.join(args.output_dir, "textvqa.pkl"), 348 | ) 349 | 350 | if args.eval_hateful_memes: 351 | print("Caching Hateful Memes...") 352 | train_dataset = HatefulMemesDataset( 353 | image_dir_path=args.hateful_memes_image_dir_path, 354 | annotations_path=args.hateful_memes_train_annotations_json_path, 355 | ) 356 | rices_dataset = RICES( 357 | train_dataset, 358 | device_id, 359 | args.batch_size, 360 | vision_encoder_path=args.vision_encoder_path, 361 | vision_encoder_pretrained=args.vision_encoder_pretrained, 362 | ) 363 | torch.save( 364 | rices_dataset.features, 365 | os.path.join(args.output_dir, "hateful_memes.pkl"), 366 | ) 367 | 368 | 369 | if __name__ == "__main__": 370 | main() 371 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/scripts/convert_mmc4_to_wds.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import uuid 5 | import zipfile 6 | from PIL import Image 7 | import base64 8 | from io import BytesIO 9 | 10 | import braceexpand 11 | import webdataset as wds 12 | 13 | arg_parser = argparse.ArgumentParser() 14 | arg_parser.add_argument( 15 | "--output_dir", 16 | type=str, 17 | help="Pass in the directory where the output shards (as tar files) will be written to.", 18 | ) 19 | arg_parser.add_argument( 20 | "--zip_files", 21 | type=str, 22 | help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip", 23 | ) 24 | arg_parser.add_argument( 25 | "--image_dir", 26 | type=str, 27 | help="Pass in the directory where the images have been downloaded to.", 28 | ) 29 | arg_parser.add_argument( 30 | "--num_files_per_shard", 31 | type=int, 32 | default=1000, 33 | ) 34 | args = arg_parser.parse_args() 35 | 36 | 37 | def main(): 38 | os.makedirs(args.output_dir, exist_ok=True) 39 | 40 | doc_shards = list(braceexpand.braceexpand(args.zip_files)) 41 | 42 | with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink: 43 | for idx in range(len(doc_shards)): 44 | # Open the ZIP archive and extract the JSON file 45 | with zipfile.ZipFile(doc_shards[idx], "r") as zip_file: 46 | # Assumes the JSON file is the first file in the archive 47 | json_filename = zip_file.namelist()[0] 48 | with zip_file.open(json_filename, "r") as json_file: 49 | for sample_data in json_file: 50 | # get image names from json 51 | sample_data = json.loads(sample_data) 52 | image_info = sample_data["image_info"] 53 | image_names = [image["image_name"] for image in image_info] 54 | 55 | # Add each image to the tar file 56 | for img_idx, image_name in enumerate(image_names): 57 | try: 58 | # load image 59 | img = Image.open( 60 | os.path.join(args.image_dir, str(idx), image_name) 61 | ).convert("RGB") 62 | buffered = BytesIO() 63 | img.save(buffered, 
format="JPEG") 64 | img_str = base64.b64encode(buffered.getvalue()) 65 | 66 | # convert to base64 67 | sample_data["image_info"][img_idx][ 68 | "image_base64" 69 | ] = img_str.decode("utf-8") 70 | except FileNotFoundError: 71 | print( 72 | f"Did not find {image_name} downloaded. This can happen if the url is now 404." 73 | ) 74 | except Exception as e: 75 | print(f"Error processing {image_name}: {e}") 76 | 77 | key_str = uuid.uuid4().hex 78 | sink.write({"__key__": key_str, "json": sample_data}) 79 | 80 | if (idx + 1) % args.num_files_per_shard == 0: 81 | sink.next_stream() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/scripts/fill_vqa_testdev_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper scripts to prepare a vqa test-dev evaluation for EvalAI submission. 3 | Note: EvalAI requires VQAv2 submissions to have predictions for all the questions in the test2015 set, not just the test-dev set. 4 | Given a json with a subset of the vqa questions, fill in the rest of the questions with an empty string as the model prediction. 
5 | """ 6 | import json 7 | import sys 8 | import os 9 | 10 | sys.path.append( 11 | os.path.join( 12 | os.path.dirname(os.path.abspath(__file__)), 13 | "..", 14 | ) 15 | ) 16 | from eval.vqa_metric import VQAEval 17 | 18 | postprocessor = VQAEval(None, None) 19 | 20 | 21 | def fill_vizwiz_test_json( 22 | input_path, 23 | output_path, 24 | vqa_test_questions_json_path, 25 | ): 26 | # read the input json and build a set with all question_ids 27 | with open(input_path, "r") as f: 28 | input_json = json.load(f) 29 | 30 | # postprocess answers 31 | question_id_to_answer = {} 32 | for q in input_json: 33 | resAns = q["answer"] 34 | resAns = resAns.replace("\n", " ") 35 | resAns = resAns.replace("\t", " ") 36 | resAns = resAns.strip() 37 | resAns = postprocessor.processPunctuation(resAns) 38 | resAns = postprocessor.processDigitArticle(resAns) 39 | question_id_to_answer[q["question_id"]] = resAns 40 | 41 | # read the vqa test json to get all the qustion_ids that need to be filled 42 | with open(vqa_test_questions_json_path, "r") as f: 43 | vqa_test_json = json.load(f) 44 | vqa_test_json = vqa_test_json["questions"] 45 | 46 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 47 | output_json = [] 48 | for q in vqa_test_json: 49 | output_json.append( 50 | { 51 | "image": q["image_id"], 52 | "answer": question_id_to_answer.get(q["question_id"], ""), 53 | } 54 | ) 55 | 56 | # write the json to the output path 57 | with open(output_path, "w") as f: 58 | json.dump(output_json, f) 59 | 60 | 61 | def fill_vqav2_test_json( 62 | input_path, 63 | output_path, 64 | vqa_test_questions_json_path, 65 | ): 66 | # read the input json and build a set with all question_ids 67 | with open(input_path, "r") as f: 68 | input_json = json.load(f) 69 | question_ids = set() 70 | for q in input_json: 71 | question_ids.add(q["question_id"]) 72 | 73 | # make a copy of the input json 74 | output_json = [] 75 | for q in input_json: 76 | 
resAns = q["answer"] 77 | resAns = resAns.replace("\n", " ") 78 | resAns = resAns.replace("\t", " ") 79 | resAns = resAns.strip() 80 | resAns = postprocessor.processPunctuation(resAns) 81 | resAns = postprocessor.processDigitArticle(resAns) 82 | q["answer"] = resAns 83 | output_json.append(q) 84 | 85 | # read the vqa test json to get all the qustion_ids that need to be filled 86 | with open(vqa_test_questions_json_path, "r") as f: 87 | vqa_test_json = json.load(f) 88 | vqa_test_json = vqa_test_json["questions"] 89 | 90 | # if the question_id is not in the set, add it to the copy of the input json with an empty string as the answer 91 | for q in vqa_test_json: 92 | if q["question_id"] not in question_ids: 93 | output_json.append( 94 | { 95 | "question_id": q["question_id"], 96 | "answer": "", 97 | } 98 | ) 99 | 100 | # write the json to the output path 101 | with open(output_path, "w") as f: 102 | json.dump(output_json, f) 103 | 104 | 105 | if __name__ == "__main__": 106 | import argparse 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--dataset", 111 | type=str, 112 | choices=["vqav2", "vizwiz"], 113 | ) 114 | parser.add_argument( 115 | "--input_path", 116 | type=str, 117 | help="Path to the json file with the subset of the vqa test-dev questions.", 118 | ) 119 | parser.add_argument( 120 | "--vqa_test_questions_json_path", 121 | type=str, 122 | help="Path to the json file with all the vqa test questions.", 123 | ) 124 | parser.add_argument( 125 | "--output_path", 126 | type=str, 127 | help="Path to store the filled json.", 128 | ) 129 | args = parser.parse_args() 130 | 131 | if args.dataset == "vqav2": 132 | fill_vqav2_test_json( 133 | args.input_path, 134 | args.output_path, 135 | args.vqa_test_questions_json_path, 136 | ) 137 | else: 138 | fill_vizwiz_test_json( 139 | args.input_path, 140 | args.output_path, 141 | args.vqa_test_questions_json_path, 142 | ) 143 | 
-------------------------------------------------------------------------------- /open_flamingo/open_flamingo/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yueyang130/DeeR-VLA/b30de502b8d40cd7326cabb8131d4ad477748e56/open_flamingo/open_flamingo/src/__init__.py -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/src/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | import open_clip 5 | 6 | from .flamingo import Flamingo 7 | from .flamingo_lm import FlamingoLMMixin 8 | from .utils import extend_instance 9 | 10 | 11 | def create_model_and_transforms( 12 | clip_vision_encoder_path: str, 13 | clip_vision_encoder_pretrained: str, 14 | lang_encoder_path: str, 15 | tokenizer_path: str, 16 | cross_attn_every_n_layers: int = 1, 17 | use_local_files: bool = False, 18 | decoder_layers_attr_name: str = None, 19 | freeze_lm_embeddings: bool = False, 20 | cache_dir: Optional[str] = None, 21 | **flamingo_kwargs, 22 | ): 23 | """ 24 | Initialize a Flamingo model from a pretrained vision encoder and language encoder. 25 | Appends special tokens to the tokenizer and freezes backbones. 26 | 27 | Args: 28 | clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") 29 | clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") 30 | lang_encoder_path (str): path to pretrained language encoder 31 | tokenizer_path (str): path to pretrained tokenizer 32 | cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. 33 | use_local_files (bool, optional): whether to use local files. Defaults to False. 34 | decoder_layers_attr_name (str, optional): name of the decoder layers attribute. 
Defaults to None. 35 | freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver. 36 | cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. 37 | Returns: 38 | Flamingo: Flamingo model from pretrained vision and language encoders 39 | Image processor: Pipeline to preprocess input images 40 | Tokenizer: A tokenizer for the language model 41 | """ 42 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 43 | clip_vision_encoder_path, 44 | pretrained=clip_vision_encoder_pretrained, 45 | cache_dir=cache_dir, 46 | ) 47 | # set the vision encoder to output the visual features 48 | vision_encoder.visual.output_tokens = True 49 | 50 | text_tokenizer = AutoTokenizer.from_pretrained( 51 | tokenizer_path, 52 | local_files_only=use_local_files, 53 | trust_remote_code=True, 54 | cache_dir=cache_dir, 55 | ) 56 | # add Flamingo special tokens to the tokenizer 57 | text_tokenizer.add_special_tokens( 58 | {"additional_special_tokens": ["<|endofchunk|>", "", ""]} 59 | ) 60 | if text_tokenizer.pad_token is None: 61 | # Issue: GPT models don't have a pad token, which we use to 62 | # modify labels for the loss. 
63 | text_tokenizer.add_special_tokens({"pad_token": ""}) 64 | 65 | lang_encoder = AutoModelForCausalLM.from_pretrained( 66 | lang_encoder_path, 67 | local_files_only=use_local_files, 68 | trust_remote_code=True, 69 | cache_dir=cache_dir, 70 | ) 71 | 72 | # hacks for MPT-1B, which doesn't have a get_input_embeddings method 73 | if "mpt-1b-redpajama-200b" in lang_encoder_path: 74 | 75 | class EmbeddingFnMixin: 76 | def get_input_embeddings(self): 77 | return self.transformer.wte 78 | 79 | def set_input_embeddings(self, new_embeddings): 80 | self.transformer.wte = new_embeddings 81 | 82 | extend_instance(lang_encoder, EmbeddingFnMixin) 83 | 84 | # convert LM to FlamingoLM 85 | extend_instance(lang_encoder, FlamingoLMMixin) 86 | 87 | if decoder_layers_attr_name is None: 88 | decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) 89 | lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) 90 | lang_encoder.resize_token_embeddings(len(text_tokenizer)) 91 | 92 | model = Flamingo( 93 | vision_encoder, 94 | lang_encoder, 95 | text_tokenizer.encode("<|endofchunk|>")[-1], 96 | text_tokenizer.encode("")[-1], 97 | vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][ 98 | "width" 99 | ], 100 | cross_attn_every_n_layers=cross_attn_every_n_layers, 101 | **flamingo_kwargs, 102 | ) 103 | 104 | # Freeze all parameters 105 | model.requires_grad_(False) 106 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 107 | 108 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 109 | model.perceiver.requires_grad_(True) 110 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 111 | if not freeze_lm_embeddings: 112 | model.lang_encoder.get_input_embeddings().requires_grad_(True) 113 | # TODO: investigate also training the output embeddings when untied 114 | 115 | print( 116 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable 
parameters" 117 | ) 118 | 119 | return model, image_processor, text_tokenizer 120 | 121 | 122 | def _infer_decoder_layers_attr_name(model): 123 | for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES: 124 | if k.lower() in model.__class__.__name__.lower(): 125 | return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] 126 | 127 | raise ValueError( 128 | f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." 129 | ) 130 | 131 | 132 | __KNOWN_DECODER_LAYERS_ATTR_NAMES = { 133 | "opt": "model.decoder.layers", 134 | "gptj": "transformer.h", 135 | "gpt-j": "transformer.h", 136 | "pythia": "gpt_neox.layers", 137 | "llama": "model.layers", 138 | "gptneoxforcausallm": "gpt_neox.layers", 139 | "mpt": "transformer.blocks", 140 | "mosaicgpt": "transformer.blocks", 141 | } 142 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/src/flamingo_lm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .helpers import GatedCrossAttentionBlock 3 | from .utils import getattr_recursive, setattr_recursive 4 | import copy 5 | 6 | class FlamingoLayer(nn.Module): 7 | """ 8 | FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. 
9 | """ 10 | 11 | def __init__( 12 | self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False, residual=False 13 | ): 14 | super().__init__() 15 | self.gated_cross_attn_layer = gated_cross_attn_layer 16 | self.decoder_layer = decoder_layer 17 | self.vis_x = None 18 | self.media_locations = None 19 | self.residual = residual 20 | 21 | if self.gated_cross_attn_layer is not None: 22 | self.gated_cross_attn_layer._use_gradient_checkpointing = ( 23 | gradient_checkpointing 24 | ) 25 | self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing 26 | 27 | def clone_parameters(self): 28 | self.res_layer = copy.deepcopy(self.gated_cross_attn_layer) 29 | if self.res_layer is not None: 30 | self.res_layer.requires_grad_(False) 31 | 32 | def is_conditioned(self) -> bool: 33 | """Check whether the layer is conditioned.""" 34 | return self.vis_x is not None and self.media_locations is not None 35 | 36 | # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) 37 | def condition_vis_x(self, vis_x): 38 | self.vis_x = vis_x 39 | 40 | def condition_media_locations(self, media_locations): 41 | self.media_locations = media_locations 42 | 43 | def condition_use_cached_media(self, use_cached_media): 44 | self.use_cached_media = use_cached_media 45 | 46 | def forward( 47 | self, 48 | lang_x, 49 | attention_mask=None, 50 | **decoder_layer_kwargs, 51 | ): 52 | # Cross attention 53 | if self.gated_cross_attn_layer is not None: 54 | if self.vis_x is None: 55 | raise ValueError("vis_x must be conditioned before forward pass") 56 | 57 | if self.media_locations is None: 58 | raise ValueError( 59 | "media_locations must be conditioned before forward pass" 60 | ) 61 | 62 | lang_x = self.gated_cross_attn_layer( 63 | lang_x, 64 | self.vis_x, 65 | media_locations=self.media_locations, 66 | use_cached_media=self.use_cached_media, 67 | ) 68 | 69 | # Residual 70 | if self.residual and self.res_layer is not None: 71 | 
lang_x_res = self.res_layer( 72 | lang_x, 73 | self.vis_x, 74 | media_locations=self.media_locations, 75 | attend_previous=self.attend_previous, 76 | ) 77 | lang_x = (lang_x + lang_x_res) / 2.0 78 | 79 | # Normal decoder layer 80 | lang_x = self.decoder_layer( 81 | lang_x, attention_mask=attention_mask, **decoder_layer_kwargs 82 | ) 83 | return lang_x 84 | 85 | # def forward( 86 | # self, 87 | # lang_x, 88 | # past_key_value=None, 89 | # attn_bias=None, 90 | # attention_mask=None, 91 | # is_causal: bool = True, 92 | # ): 93 | # # Cross attention 94 | # if self.gated_cross_attn_layer is not None: 95 | # if self.vis_x is None: 96 | # raise ValueError("vis_x must be conditioned before forward pass") 97 | 98 | # if self.media_locations is None: 99 | # raise ValueError( 100 | # "media_locations must be conditioned before forward pass" 101 | # ) 102 | 103 | # lang_x = self.gated_cross_attn_layer( 104 | # lang_x, 105 | # self.vis_x, 106 | # media_locations=self.media_locations, 107 | # use_cached_media=self.use_cached_media, 108 | # ) 109 | 110 | # # Residual 111 | # if self.residual and self.res_layer is not None: 112 | # lang_x_res = self.res_layer( 113 | # lang_x, 114 | # self.vis_x, 115 | # media_locations=self.media_locations, 116 | # attend_previous=self.attend_previous, 117 | # ) 118 | # lang_x = (lang_x + lang_x_res) / 2.0 119 | 120 | # # Normal decoder layer 121 | # lang_x = self.decoder_layer( 122 | # lang_x, 123 | # past_key_value=past_key_value, 124 | # attn_bias=attn_bias, 125 | # attention_mask=attention_mask, 126 | # is_causal=is_causal 127 | # ) 128 | # return lang_x 129 | 130 | 131 | class FlamingoLMMixin(nn.Module): 132 | """ 133 | Mixin to add cross-attention layers to a language model. 
134 | """ 135 | 136 | def set_decoder_layers_attr_name(self, decoder_layers_attr_name): 137 | self.decoder_layers_attr_name = decoder_layers_attr_name 138 | 139 | def _get_decoder_layers(self): 140 | return getattr_recursive(self, self.decoder_layers_attr_name) 141 | 142 | def _set_decoder_layers(self, value): 143 | setattr_recursive(self, self.decoder_layers_attr_name, value) 144 | 145 | def _delete_decoder_layers(self, indices): 146 | indices = sorted(indices, reverse=True) 147 | print(f'deleting layers {indices} in Flamingo...') 148 | layers = self._get_decoder_layers() 149 | for i in indices: 150 | del layers[i] 151 | del self.gated_cross_attn_layers[i] 152 | del self.old_decoder_blocks[i] # original language self-attention layers 153 | self.config.n_layers = len(self._get_decoder_layers()) 154 | print(f'Now the number of layer is {len(self._get_decoder_layers())}') 155 | 156 | def init_flamingo( 157 | self, 158 | media_token_id, 159 | lang_hidden_size, 160 | vis_hidden_size, 161 | cross_attn_every_n_layers, 162 | gradient_checkpointing, 163 | residual=False, 164 | ): 165 | """ 166 | Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
167 | """ 168 | print('-'*100) 169 | print(self.decoder_layers_attr_name) 170 | self.old_decoder_blocks = self._get_decoder_layers() 171 | self.gated_cross_attn_layers = nn.ModuleList( 172 | [ 173 | GatedCrossAttentionBlock( 174 | dim=lang_hidden_size, dim_visual=vis_hidden_size 175 | ) 176 | if (layer_idx + 1) % cross_attn_every_n_layers == 0 177 | else None 178 | for layer_idx, _ in enumerate(self._get_decoder_layers()) 179 | ] 180 | ) 181 | self.init_flamingo_layers(gradient_checkpointing, residual=residual) 182 | self.media_token_id = media_token_id 183 | self.initialized_flamingo = True 184 | self._use_cached_vision_x = False 185 | 186 | def init_flamingo_layers(self, gradient_checkpointing, residual=False): 187 | """ 188 | Re initializes the FlamingoLayers. 189 | Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks 190 | """ 191 | self._set_decoder_layers( 192 | nn.ModuleList( 193 | [ 194 | FlamingoLayer( 195 | gated_cross_attn_layer, decoder_layer, gradient_checkpointing, residual=residual 196 | ) 197 | for gated_cross_attn_layer, decoder_layer in zip( 198 | self.gated_cross_attn_layers, self.old_decoder_blocks 199 | ) 200 | ] 201 | ) 202 | ) 203 | 204 | def forward(self, input_ids, attention_mask, **kwargs): 205 | """Condition the Flamingo layers on the media locations before forward()""" 206 | if not self.initialized_flamingo: 207 | raise ValueError( 208 | "Flamingo layers are not initialized. Please call `init_flamingo` first." 209 | ) 210 | 211 | media_locations = input_ids == self.media_token_id 212 | 213 | # if there are media already cached and we're generating and there are no media tokens in the input, 214 | # we'll assume that ALL input tokens should attend to the last previous media that is cached. 215 | # this is especially important for HF generate() compatibility, since generate() calls forward() 216 | # repeatedly one token at a time (with no media tokens). 
217 | # without this check, the model would not attend to any images when generating (after the first token) 218 | use_cached_media_locations = ( 219 | self._use_cached_vision_x 220 | and self.is_conditioned() 221 | and not media_locations.any() 222 | ) 223 | 224 | for layer in self._get_decoder_layers(): 225 | if not use_cached_media_locations: 226 | layer.condition_media_locations(media_locations) 227 | layer.condition_use_cached_media(use_cached_media_locations) 228 | 229 | # package arguments for the other parent's forward. since we don't know the order of the arguments, 230 | # make them all kwargs 231 | kwargs["input_ids"] = input_ids 232 | kwargs["attention_mask"] = attention_mask 233 | return super().forward(**kwargs) # Call the other parent's forward method 234 | 235 | def is_conditioned(self) -> bool: 236 | """Check whether all decoder layers are already conditioned.""" 237 | return all(l.is_conditioned() for l in self._get_decoder_layers()) 238 | 239 | def clone_parameters(self): 240 | for layer in self._get_decoder_layers(): 241 | layer.clone_parameters() 242 | 243 | def clear_conditioned_layers(self): 244 | for layer in self._get_decoder_layers(): 245 | layer.condition_vis_x(None) 246 | layer.condition_media_locations(None) 247 | layer.condition_use_cached_media(None) 248 | 249 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/src/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | 
nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = ( 84 | nn.Parameter(torch.randn(max_num_frames, dim)) 85 | if exists(max_num_frames) 86 | else None 87 | ) 88 | self.media_time_embs = ( 89 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 90 | if exists(max_num_media) 91 | else None 92 | ) 93 | 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append( 97 | nn.ModuleList( 98 | [ 99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 100 | FeedForward(dim=dim, mult=ff_mult), 101 | ] 102 | ) 103 | ) 104 | 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (torch.Tensor): image features 111 | shape (b, T, F, v, D) 112 | Returns: 113 | shape (b, T, n, D) where n is self.num_latents 114 | """ 115 | b, T, F, v = x.shape[:4] 116 | 117 | # frame and media time embeddings 118 | if exists(self.frame_embs): 119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 120 | x = x + frame_embs 121 | x = rearrange( 122 | x, "b T F v d -> b T (F v) d" 123 | ) # flatten the frame and spatial dimensions 124 | if exists(self.media_time_embs): 125 | x = x + self.media_time_embs[:T] 126 | 127 | # blocks 128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 129 | for attn, ff in self.layers: 130 | latents = attn(x, latents) + latents 131 | latents = ff(latents) + latents 132 | return self.norm(latents) 133 | 134 | 135 | # gated cross attention 136 | class MaskedCrossAttention(nn.Module): 137 | def __init__( 138 | self, 139 | *, 140 | dim, 141 | dim_visual, 142 | 
dim_head=64, 143 | heads=8, 144 | only_attend_immediate_media=True, 145 | ): 146 | super().__init__() 147 | self.scale = dim_head**-0.5 148 | self.heads = heads 149 | inner_dim = dim_head * heads 150 | 151 | self.norm = nn.LayerNorm(dim) 152 | 153 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 154 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 155 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 156 | 157 | # whether for text to only attend to immediate preceding image, or all previous images 158 | self.only_attend_immediate_media = only_attend_immediate_media 159 | 160 | def forward(self, x, media, media_locations=None, use_cached_media=False): 161 | """ 162 | Args: 163 | x (torch.Tensor): text features 164 | shape (B, T_txt, D_txt) 165 | media (torch.Tensor): image features 166 | shape (B, T_img, n, D_img) where n is the dim of the latents 167 | media_locations: boolean mask identifying the media tokens in x 168 | shape (B, T_txt) 169 | use_cached_media: bool 170 | If true, treat all of x as if they occur after the last media 171 | registered in media_locations. T_txt does not need to exactly 172 | equal media_locations.shape[1] in this case 173 | """ 174 | 175 | if not use_cached_media: 176 | assert ( 177 | media_locations.shape[1] == x.shape[1] 178 | ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}" 179 | 180 | T_txt = x.shape[1] 181 | _, T_img, n = media.shape[:3] 182 | h = self.heads 183 | 184 | x = self.norm(x) 185 | 186 | q = self.to_q(x) 187 | media = rearrange(media, "b t n d -> b (t n) d") 188 | 189 | k, v = self.to_kv(media).chunk(2, dim=-1) 190 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 191 | 192 | q = q * self.scale 193 | 194 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 195 | 196 | if exists(media_locations): 197 | media_time = torch.arange(T_img, device=x.device) + 1 198 | 199 | if use_cached_media: 200 | # text time is set to the last cached media location 201 | text_time = repeat( 202 | torch.count_nonzero(media_locations, dim=1), 203 | "b -> b i", 204 | i=T_txt, 205 | ) 206 | else: 207 | # at each boolean of True, increment the time counter (relative to media time) 208 | text_time = media_locations.cumsum(dim=-1) 209 | 210 | # text time must equal media time if only attending to most immediate image 211 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 212 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 213 | 214 | text_to_media_mask = mask_op( 215 | rearrange(text_time, "b i -> b 1 i 1"), 216 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 217 | ) 218 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 219 | 220 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 221 | attn = sim.softmax(dim=-1) 222 | 223 | if exists(media_locations) and self.only_attend_immediate_media: 224 | # any text without a preceding media needs to have attention zeroed out 225 | text_without_media_mask = text_time == 0 226 | text_without_media_mask = rearrange( 227 | text_without_media_mask, "b i -> b 1 i 1" 228 | ) 229 | attn = attn.masked_fill(text_without_media_mask, 0.0) 230 | 231 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 232 | out = rearrange(out, "b h n d -> b n (h d)") 233 | return self.to_out(out) 234 | 235 | 236 | class GatedCrossAttentionBlock(nn.Module): 237 | def __init__( 238 | self, 239 | *, 240 | dim, 241 | dim_visual, 242 | dim_head=64, 243 | heads=8, 244 | ff_mult=4, 245 | only_attend_immediate_media=True, 246 | ): 247 | super().__init__() 248 | self.attn = MaskedCrossAttention( 249 | dim=dim, 250 | dim_visual=dim_visual, 251 | dim_head=dim_head, 252 | heads=heads, 253 | only_attend_immediate_media=only_attend_immediate_media, 254 | ) 255 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 256 | 257 | self.ff = FeedForward(dim, mult=ff_mult) 258 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 259 | 260 | def forward( 261 | self, 262 | x, 263 | media, 264 | media_locations=None, 265 | use_cached_media=False, 266 | ): 267 | x = ( 268 | self.attn( 269 | x, 270 | media, 271 | media_locations=media_locations, 272 | use_cached_media=use_cached_media, 273 | ) 274 | * self.attn_gate.tanh() 275 | + x 276 | ) 277 | x = self.ff(x) * self.ff_gate.tanh() + x 278 | 279 | return x 280 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/src/utils.py: -------------------------------------------------------------------------------- 1 | def extend_instance(obj, mixin): 2 | """Apply mixins to a class instance after creation""" 3 | base_cls = obj.__class__ 4 | base_cls_name = obj.__class__.__name__ 5 | obj.__class__ = type( 6 | base_cls_name, (mixin, base_cls), {} 7 | ) # mixin needs to go first for our forward() logic to work 8 | 9 | 10 | def getattr_recursive(obj, att): 11 | """ 12 | Return nested attribute of obj 13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 14 | """ 15 | if att == "": 16 | return obj 17 | i = att.find(".") 18 | if i < 0: 19 | return getattr(obj, att) 20 | else: 21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 22 | 23 | 24 | def 
setattr_recursive(obj, att, val): 25 | """ 26 | Set nested attribute of obj 27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 28 | """ 29 | if "." in att: 30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 31 | setattr(obj, att.split(".")[-1], val) 32 | 33 | 34 | def apply_with_stopping_condition( 35 | module, apply_fn, apply_condition=None, stopping_condition=None, **other_args 36 | ): 37 | if stopping_condition(module): 38 | return 39 | if apply_condition(module): 40 | apply_fn(module, **other_args) 41 | for child in module.children(): 42 | apply_with_stopping_condition( 43 | child, 44 | apply_fn, 45 | apply_condition=apply_condition, 46 | stopping_condition=stopping_condition, 47 | **other_args 48 | ) 49 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/train/README.md: -------------------------------------------------------------------------------- 1 | # OpenFlamingo Training 2 | To train OpenFlamingo, please ensure your environment matches that of `environment.yml`. 3 | 4 | ## Data 5 | Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag. 6 | 7 | ### LAION-2B Dataset 8 | [LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs. 9 | We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files. 10 | 11 | ### Multimodal C4 Dataset 12 | We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence. 
13 | 14 | Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64. 15 | We provide scripts to convert MMC4 to this format: 16 | 17 | 1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`). 18 | 2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`). 19 | 3. Run `scripts/convert_mmc4_to_wds.py` to convert the downloaded items into the expected tar files. 20 | 21 | ### ChatGPT-generated sequences 22 | A subset of our models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the URLs in the JSON files and converted to base64 before using this data for training in our codebase. 23 | 24 | Models trained with ChatGPT-generated sequences: 25 | 26 | * OpenFlamingo-4B-vitl-rpj3b 27 | * OpenFlamingo-4B-vitl-rpj3b-langinstruct 28 | 29 | ## Example training command 30 | We provide a sample Slurm training script in `scripts/`.
You can also modify the following command: 31 | 32 | ``` 33 | torchrun --nnodes=1 --nproc_per_node=4 train.py \ 34 | --lm_path anas-awadalla/mpt-1b-redpajama-200b \ 35 | --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ 36 | --cross_attn_every_n_layers 1 \ 37 | --dataset_resampled \ 38 | --batch_size_mmc4 32 \ 39 | --batch_size_laion 64 \ 40 | --train_num_samples_mmc4 125000 \ 41 | --train_num_samples_laion 250000 \ 42 | --loss_multiplier_laion 0.2 \ 43 | --workers=4 \ 44 | --run_name OpenFlamingo-3B-vitl-mpt1b \ 45 | --num_epochs 480 \ 46 | --warmup_steps 1875 \ 47 | --mmc4_textsim_threshold 0.24 \ 48 | --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ 49 | --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ 50 | --report_to_wandb 51 | ``` 52 | *Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).* 53 | 54 | ## Distributed training 55 | 56 | By default, `train.py` uses PyTorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training. 57 | To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag. 58 | 59 | Some notes on FSDP: 60 | 61 | * We recommend using the `--fsdp_use_orig_params` flag. If `--fsdp` is on without this flag, all language model embeddings will be unfrozen during training. (In contrast, the default behavior is to only train the newly added `<image>` and `<|endofchunk|>` tokens.) 62 | * Note: we've encountered issues using OPT with this flag. Other language models should be compatible.
63 | * Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input / output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the `--freeze_lm_embeddings` flag. 64 | 65 | We also implement gradient checkpointing and mixed precision training. Use the `--gradient_checkpointing` and `--precision` arguments respectively. -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/train/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/train/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for initializing webdataset objects 3 | """ 4 | 5 | import ast 6 | import json 7 | import logging 8 | import os 9 | import random 10 | import sys 11 | from dataclasses import dataclass 12 | from multiprocessing import Value 13 | 14 | import braceexpand 15 | import numpy as np 16 | import webdataset as wds 17 | from PIL import Image 18 | from torch.utils.data import DataLoader, IterableDataset, get_worker_info 19 | from torch.utils.data.distributed import DistributedSampler 20 | from webdataset.filters import _shuffle 21 | from webdataset.tariterators import ( 22 | base_plus_ext, 23 | tar_file_expander, 24 | url_opener, 25 | valid_sample, 26 | ) 27 | 28 | try: 29 | import horovod.torch as hvd 30 | except ImportError: 31 | hvd = None 32 | 33 | 34 | class SharedEpoch: 35 | def __init__(self, epoch: int = 0): 36 | self.shared_epoch = Value("i", epoch) 37 | 38 | def set_value(self, epoch): 39 | self.shared_epoch.value = epoch 40 | 41 | def get_value(self): 42 | return self.shared_epoch.value 43 | 44 | 45 | @dataclass 46 | class DataInfo: 47 | dataloader: DataLoader 48 | sampler: DistributedSampler = 
None 49 | shared_epoch: SharedEpoch = None 50 | 51 | def set_epoch(self, epoch): 52 | if self.shared_epoch is not None: 53 | self.shared_epoch.set_value(epoch) 54 | if self.sampler is not None and isinstance(self.sampler, DistributedSampler): 55 | self.sampler.set_epoch(epoch) 56 | 57 | 58 | def get_dataset_size(shards): 59 | shards_list = list(braceexpand.braceexpand(shards)) 60 | dir_path = os.path.dirname(shards[0]) 61 | sizes_filename = os.path.join(dir_path, "sizes.json") 62 | len_filename = os.path.join(dir_path, "__len__") 63 | if os.path.exists(sizes_filename): 64 | sizes = json.load(open(sizes_filename, "r")) 65 | total_size = sum( 66 | [ 67 | int(sizes[os.path.basename(shard)]) 68 | if os.path.basename(shard) in sizes 69 | else 0 70 | for shard in shards_list 71 | ] 72 | ) 73 | elif os.path.exists(len_filename): 74 | # FIXME this used to be eval(open(...)) but that seemed rather unsafe 75 | total_size = ast.literal_eval(open(len_filename, "r").read()) 76 | else: 77 | total_size = None # num samples undefined 78 | # some common dataset sizes (at time of authors last download) 79 | # CC3M (train): 2905954 80 | # CC12M: 10968539 81 | # LAION-400M: 407332084 82 | # LAION-2B (english): 2170337258 83 | num_shards = len(shards_list) 84 | return total_size, num_shards 85 | 86 | 87 | def count_samples(dataloader): 88 | os.environ["WDS_EPOCH"] = "0" 89 | n_elements, n_batches = 0, 0 90 | for images, texts in dataloader: 91 | n_batches += 1 92 | n_elements += len(images) 93 | assert len(images) == len(texts) 94 | return n_elements, n_batches 95 | 96 | 97 | def log_and_continue(exn): 98 | """Call in an exception handler to ignore any exception, issue a warning, and continue.""" 99 | logging.warning(f"Handling webdataset error ({repr(exn)}). 
Ignoring.") 100 | return True 101 | 102 | 103 | def group_by_keys_nothrow( 104 | data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None 105 | ): 106 | """Return function over iterator that groups key, value pairs into samples. 107 | 108 | :param keys: function that splits the key into key and extension (base_plus_ext) 109 | :param lcase: convert suffixes to lower case (Default value = True) 110 | """ 111 | current_sample = None 112 | for filesample in data: 113 | assert isinstance(filesample, dict) 114 | fname, value = filesample["fname"], filesample["data"] 115 | prefix, suffix = keys(fname) 116 | if prefix is None: 117 | continue 118 | if lcase: 119 | suffix = suffix.lower() 120 | # FIXME webdataset version throws if suffix in current_sample, but we have a potential for 121 | # this happening in the current LAION400m dataset if a tar ends with same prefix as the next 122 | # begins, rare, but can happen since prefix aren't unique across tar files in that dataset 123 | if ( 124 | current_sample is None 125 | or prefix != current_sample["__key__"] 126 | or suffix in current_sample 127 | ): 128 | if valid_sample(current_sample): 129 | yield current_sample 130 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 131 | if suffixes is None or suffix in suffixes: 132 | current_sample[suffix] = value 133 | if valid_sample(current_sample): 134 | yield current_sample 135 | 136 | 137 | def tarfile_to_samples_nothrow(src, handler=log_and_continue): 138 | # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw 139 | streams = url_opener(src, handler=handler) 140 | files = tar_file_expander(streams, handler=handler) 141 | samples = group_by_keys_nothrow(files, handler=handler) 142 | return samples 143 | 144 | 145 | def pytorch_worker_seed(increment=0): 146 | """get dataloader worker seed from pytorch""" 147 | worker_info = get_worker_info() 148 | if worker_info is not None: 149 | # favour using the seed already created 
for pytorch dataloader workers if it exists 150 | seed = worker_info.seed 151 | if increment: 152 | # space out seed increments so they can't overlap across workers in different iterations 153 | seed += increment * max(1, worker_info.num_workers) 154 | return seed 155 | # fallback to wds rank based seed 156 | return wds.utils.pytorch_worker_seed() 157 | 158 | 159 | class detshuffle2(wds.PipelineStage): 160 | def __init__( 161 | self, 162 | bufsize=1000, 163 | initial=100, 164 | seed=0, 165 | epoch=-1, 166 | ): 167 | self.bufsize = bufsize 168 | self.initial = initial 169 | self.seed = seed 170 | self.epoch = epoch 171 | 172 | def run(self, src): 173 | if isinstance(self.epoch, SharedEpoch): 174 | epoch = self.epoch.get_value() 175 | else: 176 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 177 | # situation as different workers may wrap at different times (or not at all). 178 | self.epoch += 1 179 | epoch = self.epoch 180 | rng = random.Random() 181 | if self.seed < 0: 182 | # If seed is negative, we use the worker's seed, this will be different across all nodes/workers 183 | seed = pytorch_worker_seed(epoch) 184 | else: 185 | # This seed to be deterministic AND the same across all nodes/workers in each epoch 186 | seed = self.seed + epoch 187 | rng.seed(seed) 188 | return _shuffle(src, self.bufsize, self.initial, rng) 189 | 190 | 191 | class ResampledShards2(IterableDataset): 192 | """An iterable dataset yielding a list of urls.""" 193 | 194 | def __init__( 195 | self, 196 | urls, 197 | nshards=sys.maxsize, 198 | worker_seed=None, 199 | deterministic=False, 200 | epoch=-1, 201 | ): 202 | """Sample shards from the shard list with replacement. 
203 | :param urls: a list of URLs as a Python list or brace notation string 204 | """ 205 | super().__init__() 206 | urls = wds.shardlists.expand_urls(urls) 207 | self.urls = urls 208 | assert isinstance(self.urls[0], str) 209 | self.nshards = nshards 210 | self.rng = random.Random() 211 | self.worker_seed = worker_seed 212 | self.deterministic = deterministic 213 | self.epoch = epoch 214 | 215 | def __iter__(self): 216 | """Return an iterator over the shards.""" 217 | if isinstance(self.epoch, SharedEpoch): 218 | epoch = self.epoch.get_value() 219 | else: 220 | # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) 221 | # situation as different workers may wrap at different times (or not at all). 222 | self.epoch += 1 223 | epoch = self.epoch 224 | 225 | if self.deterministic: 226 | # reset seed w/ epoch if deterministic 227 | if self.worker_seed is None: 228 | # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id 229 | seed = pytorch_worker_seed(epoch) 230 | else: 231 | seed = self.worker_seed() + epoch 232 | self.rng.seed(seed) 233 | for _ in range(self.nshards): 234 | yield dict(url=self.rng.choice(self.urls)) 235 | -------------------------------------------------------------------------------- /open_flamingo/open_flamingo/train/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for setting up distributed training. 
3 | Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py 4 | """ 5 | 6 | import os 7 | import torch 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def is_global_master(args): 16 | return args.rank == 0 17 | 18 | 19 | def is_local_master(args): 20 | return args.local_rank == 0 21 | 22 | 23 | def is_master(args, local=False): 24 | return is_local_master(args) if local else is_global_master(args) 25 | 26 | 27 | def is_using_horovod(): 28 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 29 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 30 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 31 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 32 | if all([var in os.environ for var in ompi_vars]) or all( 33 | [var in os.environ for var in pmi_vars] 34 | ): 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def is_using_distributed(): 41 | if "WORLD_SIZE" in os.environ: 42 | return int(os.environ["WORLD_SIZE"]) > 1 43 | if "SLURM_NTASKS" in os.environ: 44 | return int(os.environ["SLURM_NTASKS"]) > 1 45 | return False 46 | 47 | 48 | def world_info_from_env(): 49 | local_rank = 0 50 | for v in ( 51 | "LOCAL_RANK", 52 | "MPI_LOCALRANKID", 53 | "SLURM_LOCALID", 54 | "OMPI_COMM_WORLD_LOCAL_RANK", 55 | ): 56 | if v in os.environ: 57 | local_rank = int(os.environ[v]) 58 | break 59 | global_rank = 0 60 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 61 | if v in os.environ: 62 | global_rank = int(os.environ[v]) 63 | break 64 | world_size = 1 65 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 66 | if v in os.environ: 67 | world_size = int(os.environ[v]) 68 | break 69 | 70 | return local_rank, global_rank, world_size 71 | 72 | 73 | def init_distributed_device(args): 74 | # Distributed training = training on more than one GPU. 
75 | # Works in both single and multi-node scenarios. 76 | args.distributed = False 77 | args.world_size = 1 78 | args.rank = 0 # global rank 79 | args.local_rank = 0 80 | if args.horovod: 81 | assert hvd is not None, "Horovod is not installed" 82 | hvd.init() 83 | args.local_rank = int(hvd.local_rank()) 84 | args.rank = hvd.rank() 85 | args.world_size = hvd.size() 86 | args.distributed = True 87 | os.environ["LOCAL_RANK"] = str(args.local_rank) 88 | os.environ["RANK"] = str(args.rank) 89 | os.environ["WORLD_SIZE"] = str(args.world_size) 90 | elif is_using_distributed(): 91 | if "SLURM_PROCID" in os.environ: 92 | # DDP via SLURM 93 | args.local_rank, args.rank, args.world_size = world_info_from_env() 94 | # SLURM var -> torch.distributed vars in case needed 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | os.environ["RANK"] = str(args.rank) 97 | os.environ["WORLD_SIZE"] = str(args.world_size) 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | world_size=args.world_size, 102 | rank=args.rank, 103 | ) 104 | else: 105 | # DDP via torchrun, torch.distributed.launch 106 | args.local_rank, _, _ = world_info_from_env() 107 | torch.distributed.init_process_group( 108 | backend=args.dist_backend, init_method=args.dist_url 109 | ) 110 | args.world_size = torch.distributed.get_world_size() 111 | args.rank = torch.distributed.get_rank() 112 | args.distributed = True 113 | else: 114 | # needed to run on single gpu 115 | torch.distributed.init_process_group( 116 | backend=args.dist_backend, 117 | init_method=args.dist_url, 118 | world_size=1, 119 | rank=0, 120 | ) 121 | 122 | if torch.cuda.is_available(): 123 | if args.distributed and not args.no_set_device_rank: 124 | device = "cuda:%d" % args.local_rank 125 | else: 126 | device = "cuda:0" 127 | torch.cuda.set_device(device) 128 | else: 129 | device = "cpu" 130 | args.device = device 131 | device = torch.device(device) 132 | return device 133 | 
-------------------------------------------------------------------------------- /open_flamingo/setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | if __name__ == "__main__": 6 | with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: 7 | long_description = file.read() 8 | 9 | REQUIREMENTS = [ 10 | "einops", 11 | "einops-exts", 12 | "transformers>=4.28.1", 13 | "torch==2.0.1", 14 | "pillow", 15 | "open_clip_torch>=2.16.0", 16 | "sentencepiece", 17 | ] 18 | 19 | EVAL = [ 20 | "scipy", 21 | "torchvision", 22 | "nltk", 23 | "inflection", 24 | "pycocoevalcap", 25 | "pycocotools", 26 | "tqdm", 27 | ] 28 | 29 | TRAINING = [ 30 | "wandb", 31 | "torchvision", 32 | "braceexpand", 33 | "webdataset", 34 | "tqdm", 35 | ] 36 | 37 | setup( 38 | name="open_flamingo", 39 | packages=find_packages(), 40 | include_package_data=True, 41 | version="2.0.1", 42 | license="MIT", 43 | description="An open-source framework for training large multimodal models", 44 | long_description=long_description, 45 | long_description_content_type="text/markdown", 46 | data_files=[(".", ["README.md"])], 47 | keywords=["machine learning"], 48 | install_requires=REQUIREMENTS, 49 | extras_require={ 50 | "eval": EVAL, 51 | "training": TRAINING, 52 | "all": list(set(REQUIREMENTS + EVAL + TRAINING)), 53 | }, 54 | classifiers=[ 55 | "Development Status :: 4 - Beta", 56 | "Intended Audience :: Developers", 57 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 58 | "License :: OSI Approved :: MIT License", 59 | "Programming Language :: Python :: 3.9", 60 | ], 61 | ) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | einops-exts==0.0.4 3 | entrypoints==0.4 4 | fasteners==0.18 5 | ffmpeg==1.4 6 | 
filelock==3.12.4 7 | flamingo-pytorch==0.1.2 8 | h5py==3.10.0 9 | huggingface-hub==0.18.0 10 | imageio==2.31.5 11 | imageio-ffmpeg==0.4.8 12 | joblib==1.2.0 13 | numpy==1.24.4 14 | numpy-quaternion==2022.4.3 15 | omegaconf==2.3.0 16 | open-clip-torch==2.20.0 17 | opencv-python 18 | # opencv-python==3.4.8.29 19 | opencv-python 20 | packaging==23.2 21 | pandas==2.0.3 22 | pathtools==0.1.2 23 | Pillow==9.5.0 24 | pyparsing==3.1.1 25 | pyrender==0.1.45 26 | python-dateutil==2.8.2 27 | pytorch-lightning==1.8.6 28 | # pytorch3d==0.3.0 29 | # pytorch3d 30 | PyYAML==6.0.1 31 | regex==2023.5.5 32 | requests==2.31.0 33 | requests-oauthlib==1.3.1 34 | responses==0.18.0 35 | rich==13.3.5 36 | rsa==4.9 37 | sacremoses==0.0.53 38 | safetensors==0.4.0 39 | scikit-image==0.19.3 40 | scikit-learn==1.2.2 41 | scipy==1.5.2 42 | sentence-transformers==2.2.2 43 | torch==1.12.1+cu113 44 | torchaudio==0.12.1+cu113 45 | torchvision==0.13.1+cu113 46 | tqdm==4.66.1 47 | transformers==4.33.1 48 | wandb==0.15.2 49 | yarl==1.9.2 50 | zarr==2.15.0 51 | zipp==3.17.0 52 | braceexpand 53 | webdataset 54 | einops_exts 55 | open_clip_torch>=2.16.0 56 | thop 57 | fvcore 58 | scikit-optimize 59 | bitsandbytes 60 | accelerate 61 | 62 | -------------------------------------------------------------------------------- /robot_flamingo/data/vl_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | vqa_data_dir = "path/to/vqav2/train2014" 3 | vqa_questions = "path/to/vqav2/v2_OpenEnded_mscoco_train2014_questions.json" 4 | vqa_ann = "path/to/vqav2/v2_mscoco_train2014_annotations.json" 5 | coco_data_dir = "path/to/coco/train2014" 6 | coco_ann = "path/to/coco/annotations/captions_train2014.json" 7 | 8 | import json 9 | import os 10 | import random 11 | from PIL import Image 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | 16 | class CaptionDataset(Dataset): 17 | def __init__( 18 | self, 19 | image_train_dir_path, 20 | 
annotations_path, 21 | tokenizer=None, 22 | transforms=None, 23 | seed=123, 24 | is_train=True, 25 | dataset_name='coco', 26 | image_val_dir_path=None, 27 | ): 28 | self.image_train_dir_path = image_train_dir_path 29 | self.image_val_dir_path = image_val_dir_path 30 | self.annotations = [] 31 | self.is_train = is_train 32 | self.dataset_name = dataset_name 33 | self.seed = seed 34 | random.seed(self.seed) 35 | full_annotations = json.load(open(annotations_path)) 36 | self.tokenizer = tokenizer 37 | self.transforms = transforms 38 | print(len(full_annotations["images"]), len(full_annotations["annotations"])) 39 | self.id2path = {} 40 | self.id2caption = {} 41 | for i in range(len(full_annotations["images"])): 42 | self.id2path[full_annotations["images"][i]["id"]] = os.path.join( 43 | self.image_train_dir_path, full_annotations["images"][i]["file_name"]) 44 | self.image_ids = list(self.id2path.keys()) 45 | for i in range(len(full_annotations["annotations"])): 46 | image_id = full_annotations["annotations"][i]["image_id"] 47 | if image_id not in self.id2caption: 48 | self.id2caption[image_id] = [full_annotations["annotations"][i]['caption']] 49 | else: 50 | self.id2caption[image_id].append(full_annotations["annotations"][i]['caption']) 51 | 52 | def __len__(self): 53 | return len(self.image_ids) 54 | 55 | def __getitem__(self, idx): 56 | image = Image.open(self.id2path[self.image_ids[idx]]) 57 | image.load() 58 | caption = random.choice(self.id2caption[self.image_ids[idx]]) 59 | return { 60 | "image": image, 61 | "caption": caption, 62 | "image_id": self.image_ids[idx] 63 | } 64 | 65 | def get_caption_prompt(self, caption=None): 66 | return f"A photo of {caption if caption is not None else ''}" 67 | 68 | def collator(self, samples): 69 | images = torch.stack([self.transforms(s['image']) for s in samples], dim=0) 70 | text = [self.get_caption_prompt(s['caption']) for s in samples] 71 | text_tensors, attention_mask = self.tokenizer(text) 72 | return images, 
(text_tensors, attention_mask) 73 | 74 | 75 | class VQADataset(Dataset): 76 | def __init__( 77 | self, image_dir_path, question_path, annotations_path, tokenizer=None, transforms=None, seed=123, is_train=True, dataset_name='vqav2' 78 | ): 79 | self.questions = json.load(open(question_path, "r"))["questions"] 80 | if annotations_path is not None: 81 | self.answers = json.load(open(annotations_path, "r"))["annotations"] 82 | else: 83 | self.answers = None 84 | self.image_dir_path = image_dir_path 85 | self.is_train = is_train 86 | self.dataset_name = dataset_name 87 | # self.img_coco_split = "train2014" 88 | self.tokenizer = tokenizer 89 | self.transforms = transforms 90 | self.seed = seed 91 | random.seed(self.seed) 92 | if self.dataset_name in {"vqav2", "ok_vqa"}: 93 | self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] 94 | assert self.img_coco_split in {"train2014", "val2014", "test2015"} 95 | 96 | def __len__(self): 97 | return len(self.questions) 98 | 99 | def get_img_path(self, question): 100 | if self.dataset_name in {"vqav2", "ok_vqa"}: 101 | return os.path.join( 102 | self.image_dir_path, 103 | f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" 104 | if self.is_train 105 | else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg", 106 | ) 107 | elif self.dataset_name == "vizwiz": 108 | return os.path.join(self.image_dir_path, question["image_id"]) 109 | elif self.dataset_name == "textvqa": 110 | return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") 111 | else: 112 | raise Exception(f"Unknown VQA dataset {self.dataset_name}") 113 | 114 | def __getitem__(self, idx): 115 | question = self.questions[idx] 116 | img_path = self.get_img_path(question) 117 | image = Image.open(img_path) 118 | # image.load() 119 | results = { 120 | "image": image, 121 | "question": question["question"], 122 | "question_id": question["question_id"], 123 | } 124 | if self.answers is not None: 125 | answers = self.answers[idx] 126 | 
results["answers"] = [a["answer"] for a in answers["answers"]] 127 | return results 128 | 129 | def get_vqa_prompt(self, question, answer=None): 130 | return f"Question:{question} Short answer:{answer if answer is not None else ''}" 131 | 132 | def get_vqa_ques_prompt(self, question): 133 | return f"Question:{question} Short answer:" 134 | 135 | def collator(self, samples): 136 | images = torch.stack([self.transforms(s['image']) for s in samples], dim=0) 137 | text = [self.get_vqa_prompt(s['question'], random.choice(s['answers'])) for s in samples] 138 | text_tensors, attention_mask = self.tokenizer(text) 139 | B, T = attention_mask.shape 140 | ques = [self.get_vqa_ques_prompt(s['question']) for s in samples] 141 | _, ques_mask = self.tokenizer(ques) 142 | ques_len = ques_mask.sum(dim=1).unsqueeze(-1).expand(B, T) 143 | answer_mask = torch.ones_like(attention_mask) 144 | indices = torch.arange(answer_mask.shape[-1]).unsqueeze(0).expand(B, T) 145 | index_mask = indices < ques_len 146 | answer_mask.masked_fill_(index_mask, 0) 147 | answer_mask = answer_mask * attention_mask # both mask for attention and question 148 | return images, (text_tensors, attention_mask), answer_mask -------------------------------------------------------------------------------- /robot_flamingo/models/factory.py: -------------------------------------------------------------------------------- 1 | from logging import debug 2 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 3 | import open_clip 4 | from typing import Optional 5 | from robot_flamingo.models.flamingo_bc import BCFlamingo 6 | from robot_flamingo.models.flamingo_mpt import MPTFlamingo 7 | from open_flamingo.src.flamingo_lm import FlamingoLMMixin 8 | from open_flamingo.src.utils import extend_instance 9 | from open_flamingo.src.factory import _infer_decoder_layers_attr_name 10 | import torch 11 | 12 | clip_path = "/mnt/bn/yueyang/archive/clip" 13 | mpt_dict = { 14 | "mpt_dolly_3b": { 15 | 
"lang_encoder_path": "/mnt/bn/yueyang/archive/mpt-1b-redpajama-200b-dolly", 16 | "tokenizer_path": "/mnt/bn/yueyang/archive/mpt-1b-redpajama-200b-dolly", 17 | "cross_attn_every_n_layers": 1, 18 | "openflamingo_checkpoint": "/mnt/bn/yueyang/archive/OpenFlamingo-3B-vitl-mpt1b-langinstruct.pt" 19 | }, 20 | "mpt_9b": { 21 | "lang_encoder_path": "/mnt/bn/yueyang/archive/mpt-7b", 22 | "tokenizer_path": "/mnt/bn/yueyang/archive/mpt-7b", 23 | "cross_attn_every_n_layers": 4, 24 | "openflamingo_checkpoint":"/mnt/bn/yueyang/archive/OpenFlamingo-9B-vitl-mpt7b.pt" 25 | }, 26 | } 27 | 28 | 29 | 30 | def get_transforms( 31 | clip_vision_encoder_path: str = "ViT-L-14", 32 | clip_vision_encoder_pretrained: str = "openai", 33 | tokenizer_path: str = "path_to/llama-7b-hf-jxu124", 34 | use_local_files: bool = False, 35 | ): 36 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 37 | clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained 38 | ) 39 | 40 | text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 41 | # add Flamingo special tokens to the tokenizer 42 | text_tokenizer.add_special_tokens( 43 | {"additional_special_tokens": ["<|endofchunk|>", ""]} 44 | ) 45 | if text_tokenizer.pad_token is None: 46 | # Issue: GPT models don't have a pad token, which we use to 47 | # modify labels for the loss. 
48 | text_tokenizer.add_special_tokens({"pad_token": ""}) 49 | 50 | return image_processor, text_tokenizer 51 | 52 | 53 | def create_model_and_transforms( 54 | clip_vision_encoder_path: str, 55 | clip_vision_encoder_pretrained: str, 56 | lang_encoder_path: str, 57 | tokenizer_path: str, 58 | cross_attn_every_n_layers: int = 1, 59 | use_local_files: bool = False, 60 | decoder_layers_attr_name: str = None, 61 | # this is the window size sampled from the episode 62 | window_size: int = 32, 63 | freeze_embed: bool = False, 64 | train_params = -1, 65 | use_gripper=False, 66 | use_state=False, 67 | last_action=False, 68 | fusion_mode='', 69 | pad_length=-1, 70 | debug=False, 71 | sep_resampler=False, 72 | sep_lm_head=False, 73 | unfreeze_vit=False, 74 | return_feature=False, 75 | multi_step_action=1, 76 | llm_name='llama_9b', 77 | pooling='max', 78 | residual=False, 79 | tcp_rel=False, 80 | replan=-1, 81 | decoder_type='lstm', 82 | hidden_size=None, 83 | freeze_sampler=False, 84 | fwd_pred=False, 85 | fwd_pred_hand=False, 86 | no_image_patch=False, 87 | global_latent=1, 88 | refresh=-1, 89 | head_type='deterministic', 90 | **flamingo_kwargs, 91 | ): 92 | """ 93 | Initialize a Flamingo model from a pretrained vision encoder and language encoder. 94 | Appends special tokens to the tokenizer and freezes backbones. 95 | 96 | Args: 97 | clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") 98 | clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") 99 | lang_encoder_path (str): path to pretrained language encoder 100 | tokenizer_path (str): path to pretrained tokenizer 101 | cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. 102 | use_local_files (bool, optional): whether to use local files. Defaults to False. 103 | decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. 
104 | Returns: 105 | Flamingo: Flamingo model from pretrained vision and language encoders 106 | Image processor: Pipeline to preprocess input images 107 | Tokenizer: A tokenizer for the language model 108 | """ 109 | vision_encoder, _, image_processor = open_clip.create_model_and_transforms( 110 | clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained, 111 | cache_dir=clip_path, 112 | ) 113 | # set the vision encoder to output the visual features 114 | vision_encoder.visual.output_tokens = True 115 | 116 | text_tokenizer = AutoTokenizer.from_pretrained( 117 | tokenizer_path, local_files_only=use_local_files 118 | ) 119 | # add Flamingo special tokens to the tokenizer 120 | text_tokenizer.add_special_tokens( 121 | {"additional_special_tokens": ["<|endofchunk|>", ""]} 122 | ) 123 | if text_tokenizer.pad_token is None: 124 | # Issue: GPT models don't have a pad token, which we use to 125 | # modify labels for the loss. 126 | text_tokenizer.add_special_tokens({"pad_token": ""}) 127 | if debug: 128 | # Load the local checkpoint into a model instance. 129 | lang_encoder = AutoModelForCausalLM.from_pretrained(lang_encoder_path, ignore_keys=["config"], trust_remote_code=True) 130 | # Set the `init_weights` parameter to `False` to prevent the model from loading the pretrained weights. 
131 | lang_encoder.init_weights(False) 132 | else: 133 | print(lang_encoder_path) 134 | lang_encoder = AutoModelForCausalLM.from_pretrained( 135 | lang_encoder_path, local_files_only=use_local_files, trust_remote_code=True 136 | ) 137 | 138 | # hacks for MPT-1B, which doesn't have a get_input_embeddings method 139 | if "mpt-1b-redpajama-200b" in lang_encoder_path: 140 | 141 | class EmbeddingFnMixin: 142 | def get_input_embeddings(self): 143 | return self.transformer.wte 144 | 145 | def set_input_embeddings(self, new_embeddings): 146 | self.transformer.wte = new_embeddings 147 | extend_instance(lang_encoder, EmbeddingFnMixin) 148 | 149 | print( 150 | f"MPT with {sum(p.numel() for p in lang_encoder.parameters())/1e6:.2f}M parameters" 151 | ) 152 | 153 | # extend MPT to Mixin (add cross-attention layers to a language model) 154 | extend_instance(lang_encoder, FlamingoLMMixin) 155 | 156 | if decoder_layers_attr_name is None: 157 | decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) 158 | lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) 159 | lang_encoder.resize_token_embeddings(len(text_tokenizer)) 160 | 161 | if 'llama' in llm_name: 162 | Model_fn = BCFlamingo 163 | elif 'mpt' in llm_name: 164 | Model_fn = MPTFlamingo 165 | else: 166 | raise NotImplementedError 167 | 168 | model = Model_fn( 169 | vision_encoder, 170 | lang_encoder, 171 | text_tokenizer.encode("<|endofchunk|>")[-1], 172 | text_tokenizer.encode("")[-1], 173 | vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][ 174 | "width" 175 | ], 176 | cross_attn_every_n_layers=cross_attn_every_n_layers, 177 | window_size=window_size, 178 | use_gripper=use_gripper, 179 | use_state=use_state, 180 | fusion_mode=fusion_mode, 181 | last_action=last_action, 182 | pad_length=pad_length, 183 | sep_resampler=sep_resampler, 184 | sep_lm_head=sep_lm_head, 185 | return_feature=return_feature, 186 | multi_step_action=multi_step_action, 187 | llm=llm_name, 188 | 
pooling=pooling, 189 | residual=residual, 190 | tcp_rel=tcp_rel, 191 | replan=replan, 192 | decoder_type=decoder_type, 193 | head_type=head_type, 194 | hidden_size=hidden_size, 195 | refresh=refresh, 196 | fwd_pred=fwd_pred, 197 | fwd_pred_hand=fwd_pred_hand, 198 | no_image_patch=no_image_patch, 199 | global_latent=global_latent, 200 | **flamingo_kwargs, 201 | ) 202 | 203 | # Freeze all parameters 204 | model.requires_grad_(False) 205 | assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 206 | 207 | # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings 208 | # model.perceiver.requires_grad_(True) 209 | if train_params == -1: 210 | model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) 211 | print(f'{len(model.lang_encoder.gated_cross_attn_layers)=}') 212 | model.perceiver.requires_grad_(True) 213 | else: 214 | param_per_layer = 140 215 | layer_num = int(train_params / param_per_layer + 0.5) 216 | cnt = 0 217 | for ix in range(len(model.lang_encoder.gated_cross_attn_layers)-1, -1, -1): 218 | if cnt >= layer_num: 219 | break 220 | if model.lang_encoder.gated_cross_attn_layers[ix] is not None: 221 | model.lang_encoder.gated_cross_attn_layers[ix].requires_grad_(True) 222 | cnt += 1 223 | if freeze_sampler: 224 | model.perceiver.requires_grad_(False) 225 | if not freeze_embed: 226 | model.lang_encoder.get_input_embeddings().requires_grad_(True) 227 | model.lang_encoder.lm_head.requires_grad_(True) 228 | 229 | if model.sep_lm_head: 230 | model.lm_head.requires_grad_(True) 231 | if model.use_diff: 232 | model.diffusion_model.requires_grad_(True) 233 | if unfreeze_vit: 234 | model.vision_encoder.requires_grad_(True) 235 | if len(model.lm_exits) > 0: 236 | model.lm_exit_modules.requires_grad_(True) 237 | model.extra_exit.requires_grad_(True) 238 | # # Unfreeze the action head 239 | # model.action_head.requires_grad_(True) 240 | 241 | if torch.distributed.get_rank() == 0: 242 | print( 243 | f"Vision Enocder with {sum(p.numel() 
for p in vision_encoder.parameters())/1e6:.2f}M parameters" 244 | ) 245 | print( 246 | f"Vision Perciver with {sum(p.numel() for p in model.perceiver.parameters())/1e6:.2f}M parameters" 247 | ) 248 | print( 249 | f"{model.early_exit_layer+1}-layer LLM with {sum(p.numel() for p in model.lang_encoder.parameters())/1e6:.2f}M parameters" 250 | ) 251 | print( 252 | f"One Action head with {sum(p.numel() for p in model.lm_head.parameters())/1e6:.2f}M parameters" 253 | ) 254 | if hasattr(model, 'extra_exit'): 255 | print( 256 | f"Extra Action head with {sum(p.numel() for p in model.extra_exit.parameters())/1e6:.2f}M parameters" 257 | ) 258 | 259 | print( 260 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters" 261 | ) 262 | print( 263 | f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M trainable parameters" 264 | ) 265 | if hasattr(model, 'extra_exit'): 266 | print(model.extra_exit) 267 | 268 | 269 | return model, image_processor, text_tokenizer 270 | -------------------------------------------------------------------------------- /robot_flamingo/models/normalizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Dict, Callable 2 | 3 | import unittest 4 | import zarr 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | def dict_apply( 10 | x: Dict[str, torch.Tensor], 11 | func: Callable[[torch.Tensor], torch.Tensor] 12 | ) -> Dict[str, torch.Tensor]: 13 | result = dict() 14 | for key, value in x.items(): 15 | if isinstance(value, dict): 16 | result[key] = dict_apply(value, func) 17 | else: 18 | result[key] = func(value) 19 | return result 20 | 21 | class DictOfTensorMixin(nn.Module): 22 | def __init__(self, params_dict=None): 23 | super().__init__() 24 | if params_dict is None: 25 | params_dict = nn.ParameterDict() 26 | self.params_dict = params_dict 27 | 28 | @property 29 | def 
device(self): 30 | return next(iter(self.parameters())).device 31 | 32 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 33 | def dfs_add(dest, keys, value: torch.Tensor): 34 | if len(keys) == 1: 35 | dest[keys[0]] = value 36 | return 37 | 38 | if keys[0] not in dest: 39 | dest[keys[0]] = nn.ParameterDict() 40 | dfs_add(dest[keys[0]], keys[1:], value) 41 | 42 | def load_dict(state_dict, prefix): 43 | out_dict = nn.ParameterDict() 44 | for key, value in state_dict.items(): 45 | value: torch.Tensor 46 | if key.startswith(prefix): 47 | param_keys = key[len(prefix):].split('.')[1:] 48 | # if len(param_keys) == 0: 49 | # import pdb; pdb.set_trace() 50 | dfs_add(out_dict, param_keys, value.clone()) 51 | return out_dict 52 | 53 | self.params_dict = load_dict(state_dict, prefix + 'params_dict') 54 | self.params_dict.requires_grad_(False) 55 | return 56 | 57 | class LinearNormalizer(DictOfTensorMixin): 58 | avaliable_modes = ['limits', 'gaussian'] 59 | 60 | @torch.no_grad() 61 | def fit(self, 62 | data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array], 63 | last_n_dims=1, 64 | dtype=torch.float32, 65 | mode='limits', 66 | output_max=1., 67 | output_min=-1., 68 | range_eps=1e-4, 69 | fit_offset=True): 70 | if isinstance(data, dict): 71 | for key, value in data.items(): 72 | self.params_dict[key] = _fit(value, 73 | last_n_dims=last_n_dims, 74 | dtype=dtype, 75 | mode=mode, 76 | output_max=output_max, 77 | output_min=output_min, 78 | range_eps=range_eps, 79 | fit_offset=fit_offset) 80 | else: 81 | self.params_dict['_default'] = _fit(data, 82 | last_n_dims=last_n_dims, 83 | dtype=dtype, 84 | mode=mode, 85 | output_max=output_max, 86 | output_min=output_min, 87 | range_eps=range_eps, 88 | fit_offset=fit_offset) 89 | 90 | def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: 91 | return self.normalize(x) 92 | 93 | def __getitem__(self, key: str): 94 | return 
SingleFieldLinearNormalizer(self.params_dict[key]) 95 | 96 | def __setitem__(self, key: str , value: 'SingleFieldLinearNormalizer'): 97 | self.params_dict[key] = value.params_dict 98 | 99 | def _normalize_impl(self, x, forward=True): 100 | if isinstance(x, dict): 101 | result = dict() 102 | for key, value in x.items(): 103 | params = self.params_dict[key] 104 | result[key] = _normalize(value, params, forward=forward) 105 | return result 106 | else: 107 | if '_default' not in self.params_dict: 108 | raise RuntimeError("Not initialized") 109 | params = self.params_dict['_default'] 110 | return _normalize(x, params, forward=forward) 111 | 112 | def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: 113 | return self._normalize_impl(x, forward=True) 114 | 115 | def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: 116 | return self._normalize_impl(x, forward=False) 117 | 118 | def get_input_stats(self) -> Dict: 119 | if len(self.params_dict) == 0: 120 | raise RuntimeError("Not initialized") 121 | if len(self.params_dict) == 1 and '_default' in self.params_dict: 122 | return self.params_dict['_default']['input_stats'] 123 | 124 | result = dict() 125 | for key, value in self.params_dict.items(): 126 | if key != '_default': 127 | result[key] = value['input_stats'] 128 | return result 129 | 130 | 131 | def get_output_stats(self, key='_default'): 132 | input_stats = self.get_input_stats() 133 | if 'min' in input_stats: 134 | # no dict 135 | return dict_apply(input_stats, self.normalize) 136 | 137 | result = dict() 138 | for key, group in input_stats.items(): 139 | this_dict = dict() 140 | for name, value in group.items(): 141 | this_dict[name] = self.normalize({key:value})[key] 142 | result[key] = this_dict 143 | return result 144 | 145 | 146 | class SingleFieldLinearNormalizer(DictOfTensorMixin): 147 | avaliable_modes = ['limits', 'gaussian'] 148 | 149 | @torch.no_grad() 150 | def fit(self, 151 | data: Union[torch.Tensor, 
np.ndarray, zarr.Array], 152 | last_n_dims=1, 153 | dtype=torch.float32, 154 | mode='limits', 155 | output_max=1., 156 | output_min=-1., 157 | range_eps=1e-4, 158 | fit_offset=True): 159 | self.params_dict = _fit(data, 160 | last_n_dims=last_n_dims, 161 | dtype=dtype, 162 | mode=mode, 163 | output_max=output_max, 164 | output_min=output_min, 165 | range_eps=range_eps, 166 | fit_offset=fit_offset) 167 | 168 | @classmethod 169 | def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs): 170 | obj = cls() 171 | obj.fit(data, **kwargs) 172 | return obj 173 | 174 | @classmethod 175 | def create_manual(cls, 176 | scale: Union[torch.Tensor, np.ndarray], 177 | offset: Union[torch.Tensor, np.ndarray], 178 | input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]]): 179 | def to_tensor(x): 180 | if not isinstance(x, torch.Tensor): 181 | x = torch.from_numpy(x) 182 | x = x.flatten() 183 | return x 184 | 185 | # check 186 | for x in [offset] + list(input_stats_dict.values()): 187 | assert x.shape == scale.shape 188 | assert x.dtype == scale.dtype 189 | 190 | params_dict = nn.ParameterDict({ 191 | 'scale': to_tensor(scale), 192 | 'offset': to_tensor(offset), 193 | 'input_stats': nn.ParameterDict( 194 | dict_apply(input_stats_dict, to_tensor)) 195 | }) 196 | return cls(params_dict) 197 | 198 | @classmethod 199 | def create_identity(cls, dtype=torch.float32): 200 | scale = torch.tensor([1], dtype=dtype) 201 | offset = torch.tensor([0], dtype=dtype) 202 | input_stats_dict = { 203 | 'min': torch.tensor([-1], dtype=dtype), 204 | 'max': torch.tensor([1], dtype=dtype), 205 | 'mean': torch.tensor([0], dtype=dtype), 206 | 'std': torch.tensor([1], dtype=dtype) 207 | } 208 | return cls.create_manual(scale, offset, input_stats_dict) 209 | 210 | def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 211 | return _normalize(x, self.params_dict, forward=True) 212 | 213 | def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 214 
| return _normalize(x, self.params_dict, forward=False) 215 | 216 | def get_input_stats(self): 217 | return self.params_dict['input_stats'] 218 | 219 | def get_output_stats(self): 220 | return dict_apply(self.params_dict['input_stats'], self.normalize) 221 | 222 | def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 223 | return self.normalize(x) 224 | 225 | 226 | 227 | def _fit(data: Union[torch.Tensor, np.ndarray, zarr.Array], 228 | last_n_dims=1, 229 | dtype=torch.float32, 230 | mode='limits', 231 | output_max=1., 232 | output_min=-1., 233 | range_eps=1e-4, 234 | fit_offset=True): 235 | assert mode in ['limits', 'gaussian'] 236 | assert last_n_dims >= 0 237 | assert output_max > output_min 238 | 239 | # convert data to torch and type 240 | if isinstance(data, zarr.Array): 241 | data = data[:] 242 | if isinstance(data, np.ndarray): 243 | data = torch.from_numpy(data) 244 | if dtype is not None: 245 | data = data.type(dtype) 246 | 247 | # convert shape 248 | dim = 1 249 | if last_n_dims > 0: 250 | dim = np.prod(data.shape[-last_n_dims:]) 251 | data = data.reshape(-1,dim) 252 | 253 | # compute input stats min max mean std 254 | input_min, _ = data.min(axis=0) 255 | input_max, _ = data.max(axis=0) 256 | input_mean = data.mean(axis=0) 257 | input_std = data.std(axis=0) 258 | 259 | # compute scale and offset 260 | if mode == 'limits': 261 | if fit_offset: 262 | # unit scale 263 | input_range = input_max - input_min 264 | ignore_dim = input_range < range_eps 265 | input_range[ignore_dim] = output_max - output_min 266 | scale = (output_max - output_min) / input_range 267 | offset = output_min - scale * input_min 268 | offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] 269 | # ignore dims scaled to mean of output max and min 270 | else: 271 | # use this when data is pre-zero-centered. 
272 | assert output_max > 0 273 | assert output_min < 0 274 | # unit abs 275 | output_abs = min(abs(output_min), abs(output_max)) 276 | input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max)) 277 | ignore_dim = input_abs < range_eps 278 | input_abs[ignore_dim] = output_abs 279 | # don't scale constant channels 280 | scale = output_abs / input_abs 281 | offset = torch.zeros_like(input_mean) 282 | elif mode == 'gaussian': 283 | ignore_dim = input_std < range_eps 284 | scale = input_std.clone() 285 | scale[ignore_dim] = 1 286 | scale = 1 / scale 287 | 288 | if fit_offset: 289 | offset = - input_mean * scale 290 | else: 291 | offset = torch.zeros_like(input_mean) 292 | 293 | # save 294 | this_params = nn.ParameterDict({ 295 | 'scale': scale, 296 | 'offset': offset, 297 | 'input_stats': nn.ParameterDict({ 298 | 'min': input_min, 299 | 'max': input_max, 300 | 'mean': input_mean, 301 | 'std': input_std 302 | }) 303 | }) 304 | for p in this_params.parameters(): 305 | p.requires_grad_(False) 306 | return this_params 307 | 308 | 309 | def _normalize(x, params, forward=True): 310 | assert 'scale' in params 311 | if isinstance(x, np.ndarray): 312 | x = torch.from_numpy(x) 313 | x = x.to(device=scale.device, dtype=scale.dtype) 314 | scale = params['scale'].to(device=x.device, dtype=x.dtype) 315 | offset = params['offset'].to(device=x.device, dtype=x.dtype) 316 | # x = x.to(device=scale.device, dtype=scale.dtype) 317 | src_shape = x.shape 318 | x = x.reshape(-1, scale.shape[0]) 319 | if forward: 320 | x = x * scale + offset 321 | else: 322 | x = (x - offset) / scale 323 | x = x.reshape(src_shape) 324 | return x 325 | 326 | 327 | def test(): 328 | data = torch.zeros((100,10,9,2)).uniform_() 329 | data[...,0,0] = 0 330 | 331 | normalizer = SingleFieldLinearNormalizer() 332 | normalizer.fit(data, mode='limits', last_n_dims=2) 333 | datan = normalizer.normalize(data) 334 | assert datan.shape == data.shape 335 | assert np.allclose(datan.max(), 1.) 
336 | assert np.allclose(datan.min(), -1.) 337 | dataun = normalizer.unnormalize(datan) 338 | assert torch.allclose(data, dataun, atol=1e-7) 339 | 340 | input_stats = normalizer.get_input_stats() 341 | output_stats = normalizer.get_output_stats() 342 | 343 | normalizer = SingleFieldLinearNormalizer() 344 | normalizer.fit(data, mode='limits', last_n_dims=1, fit_offset=False) 345 | datan = normalizer.normalize(data) 346 | assert datan.shape == data.shape 347 | assert np.allclose(datan.max(), 1., atol=1e-3) 348 | assert np.allclose(datan.min(), 0., atol=1e-3) 349 | dataun = normalizer.unnormalize(datan) 350 | assert torch.allclose(data, dataun, atol=1e-7) 351 | 352 | data = torch.zeros((100,10,9,2)).uniform_() 353 | normalizer = SingleFieldLinearNormalizer() 354 | normalizer.fit(data, mode='gaussian', last_n_dims=0) 355 | datan = normalizer.normalize(data) 356 | assert datan.shape == data.shape 357 | assert np.allclose(datan.mean(), 0., atol=1e-3) 358 | assert np.allclose(datan.std(), 1., atol=1e-3) 359 | dataun = normalizer.unnormalize(datan) 360 | assert torch.allclose(data, dataun, atol=1e-7) 361 | 362 | 363 | # dict 364 | data = torch.zeros((100,10,9,2)).uniform_() 365 | data[...,0,0] = 0 366 | 367 | normalizer = LinearNormalizer() 368 | normalizer.fit(data, mode='limits', last_n_dims=2) 369 | datan = normalizer.normalize(data) 370 | assert datan.shape == data.shape 371 | assert np.allclose(datan.max(), 1.) 372 | assert np.allclose(datan.min(), -1.) 
373 | dataun = normalizer.unnormalize(datan) 374 | assert torch.allclose(data, dataun, atol=1e-7) 375 | 376 | input_stats = normalizer.get_input_stats() 377 | output_stats = normalizer.get_output_stats() 378 | 379 | data = { 380 | 'obs': torch.zeros((1000,128,9,2)).uniform_() * 512, 381 | 'action': torch.zeros((1000,128,2)).uniform_() * 512 382 | } 383 | normalizer = LinearNormalizer() 384 | normalizer.fit(data) 385 | datan = normalizer.normalize(data) 386 | dataun = normalizer.unnormalize(datan) 387 | for key in data: 388 | assert torch.allclose(data[key], dataun[key], atol=1e-4) 389 | 390 | input_stats = normalizer.get_input_stats() 391 | output_stats = normalizer.get_output_stats() 392 | 393 | state_dict = normalizer.state_dict() 394 | n = LinearNormalizer() 395 | n.load_state_dict(state_dict) 396 | datan = n.normalize(data) 397 | dataun = n.unnormalize(datan) 398 | for key in data: 399 | assert torch.allclose(data[key], dataun[key], atol=1e-4) -------------------------------------------------------------------------------- /robot_flamingo/pt_eval_ckpts.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export EVALUTION_ROOT=$(pwd) 4 | 5 | # !!! 
Set for your own path 6 | # calvin_dataset_path='YOUR_PATH/calvin/dataset/task_D_D' 7 | calvin_dataset_path='/mnt/bn/yueyang/archive/calvin/dataset/task_D_D' 8 | # calvin_conf_path 9 | # calvin_conf_path="YOUR_PATH/calvin/calvin_models/conf" 10 | calvin_conf_path="/mnt/bn/yueyang/archive/calvin/calvin_models/conf" 11 | 12 | use_gripper=1 13 | use_state=0 14 | 15 | evaluate_from_checkpoint=$1 16 | log_file=$2 17 | window_size=$3 18 | node_num=$4 19 | amp=$5 20 | exit_ratio=${6} 21 | num_seq=${7} 22 | max_layer=${8} 23 | diverse_inst=${9} 24 | precision=${10} 25 | 26 | export MESA_GL_VERSION_OVERRIDE=4.1 27 | echo logging to ${log_file} 28 | 29 | script=eval_calvin.py 30 | echo "EVAL IN LONG HORIZON MODE" 31 | 32 | PORT=$((RANDOM % 16383 + 49152)) 33 | 34 | torchrun --nnodes=1 --nproc_per_node=${node_num} --master_port=$PORT robot_flamingo/eval/$script \ 35 | --precision ${precision} \ 36 | --use_gripper \ 37 | --diverse_inst ${diverse_inst} \ 38 | --window_size ${window_size} \ 39 | --run_name DeeR \ 40 | --calvin_dataset ${calvin_dataset_path} \ 41 | --cross_attn_every_n_layers 4 \ 42 | --evaluate_from_checkpoint ${evaluate_from_checkpoint} \ 43 | --calvin_conf_path ${calvin_conf_path} \ 44 | --amp ${amp} \ 45 | --exit_ratio ${exit_ratio} \ 46 | --max_layer ${max_layer} \ 47 | --num_seq ${num_seq} \ 48 | --validation_set \ 49 | --workers 1 > ${log_file} 2>&1 50 | -------------------------------------------------------------------------------- /robot_flamingo/pt_run_gripper_post_ws_12_traj_aug_mpt_dolly_3b.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PATH=$PATH:path/to/robot-flamingo/robot_flamingo 4 | export PYTHONPATH=$PYTHONPATH:path/to/robot-flamingo/robot_flamingo 5 | 6 | # dataset path 7 | calvin_dataset_path='path/to/calvin_data/task_ABCD_D' 8 | # language model path 9 | lm_path='path/to/mpt-1b-dolly' 10 | # tokenizer path 11 | tokenizer_path='path/to/mpt-1b-dolly' 12 | # openflamingo 
ckpt path 13 | openflamingo_checkpoint='path/to/OpenFlamingo-3B-vitl-mpt-1b-dolly/checkpoint.pt' 14 | 15 | subfix=`date "+%Y%m%d-%H%M"` 16 | log_file="logs/training_"${subfix}".log" 17 | source /mnt/bn/robotics/resources/anaconda3_arnold/bin/activate calvin_mpt 18 | #python3 -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 --master_port=6042 robot_flamingo/train/train_calvin.py \ 19 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=6042 robot_flamingo/train/train_calvin.py \ 20 | --report_to_wandb \ 21 | --llm_name mpt_dolly_3b \ 22 | --traj_cons \ 23 | --use_gripper \ 24 | --fusion_mode post \ 25 | --rgb_pad 10 \ 26 | --gripper_pad 4 \ 27 | --precision fp32 \ 28 | --num_epochs 5 \ 29 | --gradient_accumulation_steps 1 \ 30 | --batch_size_calvin 6 \ 31 | --run_name RobotFlamingoDBG \ 32 | --calvin_dataset ${calvin_dataset_path} \ 33 | --lm_path ${lm_path} \ 34 | --tokenizer_path ${tokenizer_path} \ 35 | --openflamingo_checkpoint ${openflamingo_checkpoint} \ 36 | --cross_attn_every_n_layers 4 \ 37 | --dataset_resampled \ 38 | --loss_multiplier_calvin 1.0 \ 39 | --workers 1 \ 40 | --lr_scheduler constant \ 41 | --warmup_steps 5000 \ 42 | --learning_rate 1e-4 \ 43 | --save_every_iter 10000 \ 44 | --from_scratch \ 45 | --window_size 12 > ${log_file} 2>&1 46 | -------------------------------------------------------------------------------- /robot_flamingo/pt_run_gripper_post_ws_12_traj_aug_mpt_dolly_3b_co_train.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | scp -r .cache/clip ~/.cache/ 3 | export PATH=$PATH:path/to/robot-flamingo/robot_flamingo 4 | export PYTHONPATH=$PYTHONPATH:path/to/robot-flamingo/robot_flamingo 5 | 6 | # dataset path 7 | calvin_dataset_path='/mnt/bn/robotics/manipulation_data/calvin_data/task_ABCD_D' 8 | # language model path 9 | lm_path='path/to/mpt-7b' 10 | # tokenizer path 11 | tokenizer_path='path/to/mpt-7b' 12 | # openflamingo ckpt path 13 | 
openflamingo_checkpoint='path/to/OpenFlamingo-9B-vitl-mpt7b/checkpoint.pt' 14 | 15 | subfix=`date "+%Y%m%d-%H%M"` 16 | log_file="logs/training_"${subfix}".log" 17 | source /mnt/bn/robotics/resources/anaconda3_arnold/bin/activate calvin_mpt 18 | #python3 -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 --master_port=6042 robot_flamingo/train/train_calvin.py \ 19 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=6042 robot_flamingo/train/train_calvin.py \ 20 | --report_to_wandb \ 21 | --cotrain \ 22 | --llm_name mpt_dolly_3b \ 23 | --traj_cons \ 24 | --use_gripper \ 25 | --fusion_mode post \ 26 | --rgb_pad 10 \ 27 | --gripper_pad 4 \ 28 | --precision fp32 \ 29 | --num_epochs 5 \ 30 | --gradient_accumulation_steps 1 \ 31 | --batch_size_calvin 6 \ 32 | --run_name RobotFlamingoDBGCotrain \ 33 | --calvin_dataset ${calvin_dataset_path} \ 34 | --lm_path ${lm_path} \ 35 | --tokenizer_path ${tokenizer_path} \ 36 | --openflamingo_checkpoint ${openflamingo_checkpoint} \ 37 | --cross_attn_every_n_layers 4 \ 38 | --dataset_resampled \ 39 | --loss_multiplier_calvin 1.0 \ 40 | --workers 1 \ 41 | --lr_scheduler constant \ 42 | --warmup_steps 5000 \ 43 | --learning_rate 1e-4 \ 44 | --from_scratch \ 45 | --window_size 12 > ${log_file} 2>&1 46 | -------------------------------------------------------------------------------- /robot_flamingo/thresholds.bash: -------------------------------------------------------------------------------- 1 | # load exist threshold 2 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=12345 robot_flamingo/eval/eval_calvin.py \ 3 | --precision fp16 \ 4 | --use_gripper \ 5 | --window_size 12 \ 6 | --fusion_mode post \ 7 | --run_name RobotFlamingoDBG \ 8 | --calvin_dataset /mnt/bn/yueyang/archive/calvin/dataset/task_D_D \ 9 | --validation_set \ 10 | --data_percent 0.1 \ 11 | --load_threshold 1 \ 12 | --cross_attn_every_n_layers 4 \ 13 | --evaluate_from_checkpoint 
RobotFlamingo_task_ABCD_D-exit-strategy/stg=post_3+1_layer_11_multie_intv=2_extrae_nodth_reg_mlpdrp=0.5_layerwise_lstmdrp=0.4_aug_10_4_traj_cons_ws_12_mpt_dolly_3b_3.pth \ 14 | --calvin_conf_path /mnt/bn/yueyang/archive/calvin/calvin_models/conf \ 15 | --eval_exit_mode dynamic \ 16 | --exit_ratio 1.0 \ 17 | --value_type action \ 18 | --threshold_type L2 --exit_dist exp --max_layer 12 \ 19 | --num_seq 56 \ 20 | --workers 1 > log_solving_threshold_ablation/ABC_D_solve_on_valD_1.5 2>&1 21 | 22 | # original: solve threshold on D val set 23 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=12345 robot_flamingo/eval/eval_calvin.py \ 24 | --precision fp16 \ 25 | --use_gripper \ 26 | --window_size 12 \ 27 | --fusion_mode post \ 28 | --run_name RobotFlamingoDBG \ 29 | --calvin_dataset /mnt/bn/yueyang/archive/calvin/dataset/task_D_D \ 30 | --validation_set \ 31 | --data_percent 0.1 \ 32 | --load_threshold 0 \ 33 | --cross_attn_every_n_layers 4 \ 34 | --evaluate_from_checkpoint RobotFlamingo_task_ABCD_D-exit-strategy/stg=post_3+1_layer_11_multie_intv=2_extrae_nodth_reg_mlpdrp=0.5_layerwise_lstmdrp=0.4_aug_10_4_traj_cons_ws_12_mpt_dolly_3b_3.pth \ 35 | --calvin_conf_path /mnt/bn/yueyang/archive/calvin/calvin_models/conf \ 36 | --eval_exit_mode dynamic \ 37 | --exit_ratio 1.0 \ 38 | --value_type action \ 39 | --threshold_type L2 --exit_dist exp --max_layer 12 \ 40 | --num_seq 56 \ 41 | --workers 1 > log_solving_threshold_ablation/ABC_D_solve_on_valD_1.5 2>&1 42 | 43 | # solve threshold on ABC training set 44 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=12345 robot_flamingo/eval/eval_calvin.py \ 45 | --precision fp16 \ 46 | --use_gripper \ 47 | --window_size 12 \ 48 | --fusion_mode post \ 49 | --run_name RobotFlamingoDBG \ 50 | --calvin_dataset /mnt/bn/yueyang/archive/calvin/dataset/task_ABC_D \ 51 | --data_percent 0.001 \ 52 | --load_threshold 0 \ 53 | --cross_attn_every_n_layers 4 \ 54 | --evaluate_from_checkpoint 
RobotFlamingo_task_ABC_D-exit-strategy/stg=post_4+4_layer_11_multie_intv=2_extrae_nodth_reg_aug_10_4_traj_cons_ws_12_mpt_dolly_3b_4.pth \ 55 | --calvin_conf_path /mnt/bn/yueyang/archive/calvin/calvin_models/conf \ 56 | --eval_exit_mode dynamic \ 57 | --exit_ratio 1.0 \ 58 | --value_type action \ 59 | --threshold_type L2 --exit_dist exp --max_layer 12 \ 60 | --num_seq 224 \ 61 | --workers 1 > log_solving_threshold_ablation/ABC_D_solve_on_train_0.001_alpha=1.0_seq224 2>&1 62 | 63 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=12346 robot_flamingo/eval/eval_calvin.py \ 64 | --precision fp16 \ 65 | --use_gripper \ 66 | --window_size 12 \ 67 | --fusion_mode post \ 68 | --run_name RobotFlamingoDBG \ 69 | --calvin_dataset /mnt/bn/yueyang/archive/calvin/dataset/task_ABC_D \ 70 | --data_percent 0.01 \ 71 | --load_threshold 0 \ 72 | --cross_attn_every_n_layers 4 \ 73 | --evaluate_from_checkpoint RobotFlamingo_task_ABC_D-exit-strategy/stg=post_4+4_layer_11_multie_intv=2_extrae_nodth_reg_aug_10_4_traj_cons_ws_12_mpt_dolly_3b_4.pth \ 74 | --calvin_conf_path /mnt/bn/yueyang/archive/calvin/calvin_models/conf \ 75 | --eval_exit_mode dynamic \ 76 | --exit_ratio 1.5 \ 77 | --value_type action \ 78 | --threshold_type L2 --exit_dist exp --max_layer 12 \ 79 | --num_seq 56 \ 80 | --workers 1 > log_solving_threshold_ablation/ABC_D_solve_on_trainABC_0.01_alpha=1.5 2>&1 81 | 82 | # solve threshold on D training set 83 | -------------------------------------------------------------------------------- /robot_flamingo/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | # from pytorch3d.transforms import ( 5 | # euler_angles_to_matrix, 6 | # matrix_to_euler_angles, 7 | # matrix_to_quaternion, 8 | # quaternion_to_matrix, 9 | # ) 10 | import torch 11 | from torch.cuda.amp import autocast 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def world_to_tcp_frame(action, robot_obs): 17 | with 
autocast(dtype=torch.float32): 18 | flag = False 19 | if len(action.shape) == 4: 20 | flag = True 21 | b, s, f, _ = action.shape 22 | action = action.view(b, s*f, -1) 23 | robot_obs = robot_obs.view(b, s*f, -1) 24 | b, s, _ = action.shape 25 | world_T_tcp = euler_angles_to_matrix(robot_obs[..., 3:6], convention="XYZ").float().view(-1, 3, 3) 26 | tcp_T_world = torch.inverse(world_T_tcp) 27 | pos_w_rel = action[..., :3].view(-1, 3, 1) 28 | pos_tcp_rel = tcp_T_world @ pos_w_rel 29 | # downscaling is necessary here to get pseudo infinitesimal rotation 30 | orn_w_rel = action[..., 3:6] * 0.01 31 | world_T_tcp_new = ( 32 | euler_angles_to_matrix(robot_obs[..., 3:6] + orn_w_rel, convention="XYZ").float().view(-1, 3, 3) 33 | ) 34 | tcp_new_T_tcp_old = torch.inverse(world_T_tcp_new) @ world_T_tcp 35 | orn_tcp_rel = matrix_to_euler_angles(tcp_new_T_tcp_old, convention="XYZ").float() 36 | orn_tcp_rel = torch.where(orn_tcp_rel < -np.pi, orn_tcp_rel + 2 * np.pi, orn_tcp_rel) 37 | orn_tcp_rel = torch.where(orn_tcp_rel > np.pi, orn_tcp_rel - 2 * np.pi, orn_tcp_rel) 38 | # upscaling again 39 | orn_tcp_rel *= 100 40 | action_tcp = torch.cat([pos_tcp_rel.view(b, s, -1), orn_tcp_rel.view(b, s, -1), action[..., -1:]], dim=-1) 41 | if flag: 42 | action_tcp = action_tcp.view(b, s, -1, action_tcp.shape[-1]) 43 | assert not torch.any(action_tcp.isnan()) 44 | return action_tcp 45 | 46 | 47 | def tcp_to_world_frame(action, robot_obs): 48 | with autocast(dtype=torch.float32): 49 | flag = False 50 | if len(action.shape) == 4: 51 | flag = True 52 | b, s, f, _ = action.shape 53 | action = action.view(b, s*f, -1) 54 | robot_obs = robot_obs.view(b, s*f, -1) 55 | b, s, _ = action.shape 56 | world_T_tcp = euler_angles_to_matrix(robot_obs[..., 3:6], convention="XYZ").float().view(-1, 3, 3) 57 | pos_tcp_rel = action[..., :3].view(-1, 3, 1) 58 | pos_w_rel = world_T_tcp @ pos_tcp_rel 59 | # downscaling is necessary here to get pseudo infinitesimal rotation 60 | orn_tcp_rel = action[..., 3:6] * 0.01 61 
| tcp_new_T_tcp_old = euler_angles_to_matrix(orn_tcp_rel, convention="XYZ").float().view(-1, 3, 3) 62 | world_T_tcp_new = world_T_tcp @ torch.inverse(tcp_new_T_tcp_old) 63 | 64 | orn_w_new = matrix_to_euler_angles(world_T_tcp_new, convention="XYZ").float() 65 | if torch.any(orn_w_new.isnan()): 66 | logger.warning("NaN value in euler angles.") 67 | orn_w_new = matrix_to_euler_angles( 68 | quaternion_to_matrix(matrix_to_quaternion(world_T_tcp_new)), convention="XYZ" 69 | ).float() 70 | orn_w_rel = orn_w_new - robot_obs[..., 3:6].view(-1, 3) 71 | orn_w_rel = torch.where(orn_w_rel < -np.pi, orn_w_rel + 2 * np.pi, orn_w_rel) 72 | orn_w_rel = torch.where(orn_w_rel > np.pi, orn_w_rel - 2 * np.pi, orn_w_rel) 73 | # upscaling again 74 | orn_w_rel *= 100 75 | action_w = torch.cat([pos_w_rel.view(b, s, -1), orn_w_rel.view(b, s, -1), action[..., -1:]], dim=-1) 76 | if flag: 77 | action_w = action_w.view(b, s, -1, action_w.shape[-1]) 78 | assert not torch.any(action_w.isnan()) 79 | return action_w 80 | 81 | if __name__ == "__main__": 82 | action = torch.randn((4, 5, 3, 7)) 83 | robot_obs = torch.randn((4, 5, 3, 7)) 84 | print(world_to_tcp_frame(action, robot_obs)) --------------------------------------------------------------------------------