├── templates ├── mobile_device.txt ├── narrative.txt ├── plan_template.txt ├── face.txt ├── multi_modal_event.txt ├── commonsense-goal.txt ├── commonsense-experience.txt ├── commonsense-relationship.txt ├── commonsense-characteristic.txt ├── commonsense-routine.txt ├── persona.txt ├── dialogue_summary.txt ├── first_stark_dialogue.txt └── next_stark_dialogue.txt ├── assets └── stark_mcu_overview.PNG ├── photomaker ├── __init__.py └── model.py ├── runner ├── __init__.py ├── alignment_runner.py ├── summarizer_runner.py ├── narrative_runner.py ├── face_runner.py ├── album_runner.py ├── event_runner.py ├── base_runner.py ├── commonsense_runner.py ├── dialogue_runner.py └── persona_runner.py ├── prepare_image_db ├── download_data.sh ├── build_index.sh └── build_embedding.sh ├── utils ├── etc_utils.py └── persona_utils.py ├── LICENSE ├── postprocess_final_dataset.py ├── execute_web_search.py ├── generate_stark_dialogue.py ├── execute_sdxl.py ├── execute_retrieval.py ├── generate_face_image.py ├── execute_photomaker.py ├── scripts └── run_mcu.sh ├── README.md ├── make_final_dataset.py └── plan_runner.py /templates/mobile_device.txt: -------------------------------------------------------------------------------- 1 | {sentence} 2 | 3 | Image descriptions stored on {name}'s mobile device: 4 | 1. -------------------------------------------------------------------------------- /assets/stark_mcu_overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/passing2961/Stark/HEAD/assets/stark_mcu_overview.PNG -------------------------------------------------------------------------------- /templates/narrative.txt: -------------------------------------------------------------------------------- 1 | {sentence} 2 | 3 | Rewrite this sentence with more specific details in two or three sentences: -------------------------------------------------------------------------------- /templates/plan_template.txt: -------------------------------------------------------------------------------- 1 | Name: {name} 2 | Gender: {gender} 3 | Age: {age} 4 | Image Description: {image_description} 5 | Module: -------------------------------------------------------------------------------- /templates/face.txt: -------------------------------------------------------------------------------- 1 | Profile Information: 2 | - Age: {age} 3 | - Gender: {gender} 4 | - Nationality: {nationality} 5 | 6 | Human Description: -------------------------------------------------------------------------------- /templates/multi_modal_event.txt: -------------------------------------------------------------------------------- 1 | {name}'s initial personal event: {event} 2 | 3 | Given the {name}'s initial personal event, generate the temporal event graph containing more than five events. 4 | Temporal Event Graph: -------------------------------------------------------------------------------- /photomaker/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PhotoMakerIDEncoder 2 | from .pipeline import PhotoMakerStableDiffusionXLPipeline 3 | 4 | __all__ = [ 5 | "PhotoMakerIDEncoder", 6 | "PhotoMakerStableDiffusionXLPipeline", 7 | ] -------------------------------------------------------------------------------- /templates/commonsense-goal.txt: -------------------------------------------------------------------------------- 1 | {demo_sent} {persona_attr} I plan . 
2 | 3 | Generate the most appropriate sentence for "" in the given sentence. You must provide the answer corresponding to "". 4 | : -------------------------------------------------------------------------------- /templates/commonsense-experience.txt: -------------------------------------------------------------------------------- 1 | I . Now, {demo_sent} {persona_attr} 2 | 3 | Generate the most appropriate sentence for "" in the given sentence. You must provide the answer corresponding to "". 4 | : -------------------------------------------------------------------------------- /templates/commonsense-relationship.txt: -------------------------------------------------------------------------------- 1 | {demo_sent} {persona_attr} So, I . 2 | 3 | Generate the most appropriate sentence for "" in the given sentence. You must provide the answer corresponding to "". 4 | : -------------------------------------------------------------------------------- /templates/commonsense-characteristic.txt: -------------------------------------------------------------------------------- 1 | {demo_sent} {persona_attr} I . 2 | 3 | Generate the most appropriate sentence for "" in the given sentence. You must provide the answer corresponding to "". 4 | : -------------------------------------------------------------------------------- /templates/commonsense-routine.txt: -------------------------------------------------------------------------------- 1 | {demo_sent} {persona_attr} I regularly . 2 | 3 | Generate the most appropriate sentence for "" in the given sentence. You must provide the answer corresponding to "". 4 | : -------------------------------------------------------------------------------- /templates/persona.txt: -------------------------------------------------------------------------------- 1 | Profile Information: 2 | - Age: {age} 3 | - Gender: {gender} 4 | - Birthplace: {birthplace} 5 | - Residence: {residence} 6 | 7 | Persona Category: {target_persona_category} 8 | Persona Entity Key: {target_persona_entity} 9 | Persona Sentences: 10 | 1. -------------------------------------------------------------------------------- /templates/dialogue_summary.txt: -------------------------------------------------------------------------------- 1 | The current time and date are {current_date}. {name} and AI assistant talked today and had the following conversation: 2 | 3 | {dialogue} 4 | 5 | Summarize the conversation between {name} and AI assistant so far. Include key details and include time references wherever possible. 
6 | 7 | Summarization: -------------------------------------------------------------------------------- /runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .persona_runner import PersonaRunner 2 | from .commonsense_runner import CommonsenseRunner 3 | from .narrative_runner import NarrativeRunner 4 | from .event_runner import EventRunner 5 | from .dialogue_runner import DialogueRunner 6 | from .album_runner import AlbumRunner 7 | from .face_runner import FaceRunner 8 | from .summarizer_runner import SummarizerRunner -------------------------------------------------------------------------------- /prepare_image_db/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://storage.googleapis.com/conceptual_12m/cc12m.tsv 4 | 5 | sed -i '1s/^/url\tcaption\n/' cc12m.tsv 6 | 7 | img2dataset --url_list cc12m.tsv --input_format "tsv" \ 8 | --url_col "url" --caption_col "caption" --output_format webdataset \ 9 | --output_folder cc12m --processes_count 16 --thread_count 64 --image_size 256 \ 10 | --enable_wandb False -------------------------------------------------------------------------------- /utils/etc_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def load_json(datadir: str): 5 | with open(datadir, 'r') as f: 6 | return json.load(f) 7 | 8 | def load_jsonl(datadir: str): 9 | output = [] 10 | with open(datadir) as f: 11 | for line in f.readlines(): 12 | output.append(json.loads(line)) 13 | return output 14 | 15 | def load_txt(datadir: str): 16 | with open(datadir, 'r') as f: 17 | return f.read() -------------------------------------------------------------------------------- /templates/first_stark_dialogue.txt: -------------------------------------------------------------------------------- 1 | {name}'s Profile Information: 2 | - Age: {age} 3 | - Gender: {gender} 4 | - Birthplace: {birthplace} 5 | - Residence: {residence} 6 | 7 | Existing image descriptions in {name}'s mobile device: {mobile_device} 8 | 9 | The topic of the conversation between the AI assistant and {name} on {date} today is as follows. 10 | - Topic on {date}: {event} 11 | 12 | Generate a long, in-depth conversation with multiple turns based on the given {name}'s profile information and the current topic of conversation. 
13 | -------------------------------------------------------------------------------- /prepare_image_db/build_index.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMAGE_DATASETS=("cc12m""ai2d" "chartqa" "mathvision" "coco_train2017" "gqa" "ocr_vqa" "redcaps12m" "textvqa" "vg") 4 | 5 | for IMAGE_DATASET in "${IMAGE_DATASETS[@]}" 6 | do 7 | autofaiss build_index \ 8 | --embeddings="embeddings_folder/${IMAGE_DATASET}/img_emb" \ 9 | --index_path="index_folder/${IMAGE_DATASET}/knn.index" \ 10 | --index_infos_path="index_folder/${IMAGE_DATASET}/infos.json" \ 11 | --metric_type="ip" \ 12 | --max_index_query_time_ms=10 \ 13 | --max_index_memory_usage="16GB" 14 | done -------------------------------------------------------------------------------- /templates/next_stark_dialogue.txt: -------------------------------------------------------------------------------- 1 | {name}'s Profile Information: 2 | - Age: {age} 3 | - Gender: {gender} 4 | - Birthplace: {birthplace} 5 | - Residence: {residence} 6 | 7 | Existing image descriptions in {name}'s mobile device: {mobile_device} 8 | 9 | The topics of the conversation the user had with AI assistant by date are as follows: 10 | {history_event} 11 | 12 | {time_interval} later from the {last_date}, on {date} today, {name} has gone through a new experience, and based on this experience, {name} and the AI assistant engage in a conversation today. The new experience {name} went through and the topic of conversation with the AI assistant are as follows. 13 | - {name}'s Experience: {experience} 14 | - Topic on {date}: {event} 15 | 16 | Generate a long, in-depth conversation with multiple turns based on the given {name}'s profile information, the last topic of conversation, the experience and the current topic of conversation. 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Young-Jun Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /runner/alignment_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import pyarrow.parquet as pq 5 | import pandas as pd 6 | from joblib import Parallel, delayed 7 | from joblib_progress import joblib_progress 8 | 9 | from .base_runner import BaseRunner, console 10 | 11 | 12 | def _rebuild_cc12m_dataset(filename): 13 | if '.jpg' not in filename: 14 | return None 15 | else: 16 | return filename 17 | 18 | class AlignmentRunner(BaseRunner): 19 | def __init__(self, args): 20 | super().__init__(args) 21 | 22 | self._load_image_file() 23 | 24 | def _flatten_dataset(self, dataset): 25 | return [ 26 | ele for ele in dataset if ele is not None 27 | ] 28 | 29 | def system_msg(self): 30 | return "You are a helpful assistant." 31 | 32 | def prompt_prefix(self): 33 | return "persona-attr" 34 | 35 | def _load_image_file(self): 36 | cc12m_path = '/home/yjlee/workspace/ICCV2023/pipeline/data_collection/image_files/cc12m' 37 | 38 | subdir = os.listdir(cc12m_path)[:100] 39 | cc12m_filenames = [] 40 | for _dir in tqdm(subdir, total=len(subdir)): 41 | if '.parquet' in _dir or '.json' in _dir: 42 | continue 43 | 44 | for ele in os.listdir(os.path.join(cc12m_path, _dir)): 45 | cc12m_filenames.append(os.path.join(cc12m_path, _dir, ele)) 46 | 47 | console.log('[{}] # of CC12M dataset: {}'.format(self.__class__.__name__, len(cc12m_filenames))) 48 | 49 | with joblib_progress("Loading CC12M dataset...", total=len(cc12m_filenames)): 50 | cc12m_dataset = Parallel(n_jobs=32)(delayed(_rebuild_cc12m_dataset)(filename) for filename in cc12m_filenames) 51 | 52 | cc12m_filenames = self._flatten_dataset(cc12m_dataset) 53 | 54 | console.log('[{}] # of CC12M dataset: {}'.format(self.__class__.__name__, len(cc12m_filenames))) 55 | 56 | 57 | def run(self): 58 | return None -------------------------------------------------------------------------------- /prepare_image_db/build_embedding.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | clip-retrieval inference \ 5 | --input_dataset "data/textvqa/train_images" \ 6 | --output_folder embeddings_folder/textvqa \ 7 | --clip_model ViT-L/14@336px 8 | 9 | clip-retrieval inference \ 10 | --input_dataset "data/vg/VG_100K" \ 11 | --output_folder embeddings_folder/vg \ 12 | --clip_model ViT-L/14@336px 13 | 14 | clip-retrieval inference \ 15 | --input_dataset "data/ocr_vqa/images" \ 16 | --output_folder embeddings_folder/ocr_vqa \ 17 | --clip_model ViT-L/14@336px 18 | 19 | clip-retrieval inference \ 20 | --input_dataset "data/gqa/images" \ 21 | --output_folder embeddings_folder/gqa \ 22 | --clip_model ViT-L/14@336px 23 | 24 | clip-retrieval inference \ 25 | --input_dataset "data/coco/train2017" \ 26 | --output_folder embeddings_folder/coco_train2017 \ 27 | --clip_model ViT-L/14@336px 28 | 29 | clip-retrieval inference \ 30 | --input_dataset "cc12m/{00000..01242}.tar" \ 31 | --output_folder embeddings_folder/cc12m \ 32 | --input_format webdataset \ 33 | --clip_model ViT-L/14@336px \ 34 | --enable_metadata True \ 35 | --output_partition_count 1243 36 | 37 | clip-retrieval inference \ 38 | --input_dataset "data/redcaps/redcaps12m_shards/{00000..00180}.tar" \ 39 | --output_folder embeddings_folder/redcaps12m \ 40 | --input_format webdataset \ 41 | --output_partition_count 181 \ 42 | --clip_model ViT-L/14@336px \ 43 | --enable_metadata True 44 | 45 | 46 | 
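# Note: the cc12m run above reads the tar shards produced by img2dataset in
# download_data.sh, and the redcaps12m shards are assumed to be prepared the
# same way; the brace patterns ({00000..01242}.tar, {00000..00180}.tar) and the
# matching --output_partition_count values should equal the number of shards
# actually present, so adjust them together if your crawl produced a different
# count. The remaining datasets below are plain image folders. Once all
# embeddings exist under embeddings_folder/, build_index.sh turns each img_emb
# directory into the FAISS knn.index that execute_retrieval.py queries.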
clip-retrieval inference \ 47 | --input_dataset "data/ai2d/images" \ 48 | --output_folder embeddings_folder/ai2d \ 49 | --clip_model ViT-L/14@336px 50 | 51 | clip-retrieval inference \ 52 | --input_dataset "data/ChartQA Dataset/train/png" \ 53 | --output_folder embeddings_folder/chartqa \ 54 | --clip_model ViT-L/14@336px 55 | 56 | clip-retrieval inference \ 57 | --input_dataset "data/mathvision/images" \ 58 | --output_folder embeddings_folder/mathvision \ 59 | --clip_model ViT-L/14@336px -------------------------------------------------------------------------------- /postprocess_final_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import random 5 | from tqdm import tqdm 6 | 7 | 8 | def load_json(datadir): 9 | with open(datadir, 'r', encoding='utf-8') as f: 10 | return json.load(f) 11 | 12 | def load_jsonl(datadir: str): 13 | output = [] 14 | with open(datadir, 'r', encoding='utf-8') as f: 15 | for line in f.readlines(): 16 | output.append(json.loads(line)) 17 | return output 18 | 19 | def dump_json_output(outputs, file_name): 20 | with open(file_name, 'w', encoding='utf-8') as f: 21 | json.dump(outputs, f, ensure_ascii=False, indent='\t') 22 | 23 | def dump_jsonl_output(outputs, file_name=None): 24 | f = open(file_name, 'w', encoding='utf-8') 25 | for output in outputs: 26 | f.write(json.dumps(output) + '\n') 27 | f.close() 28 | 29 | 30 | def load_dataset(): 31 | all_stark = [] 32 | for persona_seed_num in range(0, 1): 33 | stark = load_json(os.path.join(f'./Stark/stark_{persona_seed_num}.json')) 34 | 35 | all_stark.extend(stark) 36 | 37 | return all_stark 38 | 39 | def process_dialog(dialog): 40 | dialogue = eval(dialog) 41 | cnt = 0 42 | new_cnt = 0 43 | redialog = [] 44 | for item in dialogue: 45 | utter_id = item['utter_id'] 46 | speaker = item['speaker'] 47 | utter = item['utter'] 48 | sharing_info = item['sharing_info'] 49 | 50 | cp_item = copy.deepcopy(item) 51 | if utter == '' and len(sharing_info) == 0: 52 | cp_item['utter'] = '' 53 | 54 | redialog.append(cp_item) 55 | 56 | return redialog 57 | 58 | def process_dataset(dataset): 59 | 60 | re_dataset = [] 61 | for instance in dataset: 62 | uuid = instance['unique_id'] 63 | name = instance['name'] 64 | 65 | episode = instance['episode'] 66 | 67 | cp_instance = copy.deepcopy(instance) 68 | 69 | re_epi = [] 70 | for idx, session in enumerate(episode): 71 | session_dialog = session[f'session{idx+1}:dialogue'] 72 | p_dialog = process_dialog(session_dialog) 73 | 74 | cp_session = copy.deepcopy(session) 75 | cp_session[f'session{idx+1}:dialogue'] = p_dialog 76 | re_epi.append(cp_session) 77 | 78 | cp_instance['episode'] = re_epi 79 | re_dataset.append(cp_instance) 80 | 81 | return re_dataset 82 | #return all_count 83 | 84 | if __name__ == '__main__': 85 | 86 | dataset = load_dataset() 87 | print(len(dataset)) 88 | save_dir = 'Stark/post-process' 89 | os.makedirs(save_dir, exist_ok=True) 90 | 91 | processed_dataset = process_dataset(dataset) 92 | print(len(processed_dataset)) 93 | dump_json_output(processed_dataset, os.path.join(save_dir, 'stark_0.json')) -------------------------------------------------------------------------------- /runner/summarizer_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import copy 4 | import random 5 | from collections import defaultdict 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | from .base_runner 
import BaseRunner, console 12 | from utils.etc_utils import load_jsonl, load_txt, load_json 13 | 14 | 15 | class SummarizerRunner(BaseRunner): 16 | def __init__(self, args): 17 | super().__init__(args) 18 | 19 | self.args = args 20 | 21 | self.save_dir = os.path.join(self.output_base_dir, 'dialogue-summary', f'persona_seed:{args.persona_seed_num}') 22 | 23 | os.makedirs(self.save_dir, exist_ok=True) 24 | self.last_save_chunk_idx_file = os.path.join(self.save_dir, 'last_save_chunk_idx.txt') 25 | 26 | self._load_prompt_template() 27 | 28 | def _load_prompt_template(self): 29 | self.template = load_txt('./templates/dialogue_summary.txt') 30 | 31 | @property 32 | def system_msg(self): 33 | return "Your job is to summarize the given conversation." 34 | 35 | @property 36 | def prompt_prefix(self): 37 | #if self.args.image_alignment_target == 'mobile-device-image': 38 | return "dialogue-summary" 39 | 40 | def convert_flatten_dialogue(self, dialogue): 41 | 42 | flatten_dialogue = [] 43 | for instance in dialogue: 44 | print(instance) 45 | spk = instance['speaker'] 46 | utter = instance['utterance'] 47 | 48 | if len(instance['sharing_info']) != 0: 49 | image_desc = instance['sharing_info']['image_description'] 50 | flatten_dialogue.append(f'{spk}: [Sharing Image of {image_desc}]') 51 | else: 52 | flatten_dialogue.append(f'{spk}: {utter}') 53 | 54 | return '\n'.join(flatten_dialogue) 55 | 56 | def prepare_prompt(self): 57 | try: 58 | results = load_jsonl(os.path.join(self.output_base_dir, 'dialogue', f'persona_seed:{self.args.persona_seed_num}', f'session_num:{self.args.target_session_num}', 'final_output.jsonl')) 59 | console.log('[{}] # of Total results: {}'.format(self.__class__.__name__, len(results))) 60 | except FileNotFoundError as e: 61 | return [] 62 | 63 | if self.args.debug: 64 | try: 65 | results = random.sample(results, self.args.debug_sample_num) 66 | except ValueError as e: 67 | results = results 68 | 69 | prompts = [] 70 | for instance in tqdm(results, total=len(results)): 71 | print(instance.keys()) 72 | print(instance['dialogue:date']) 73 | current_date = instance['dialogue:date'] 74 | name = instance['name'] 75 | 76 | print(instance['dialogue:last_date']) 77 | print(instance['dialogue:history_event']) 78 | print(instance['session_number']) 79 | flatten_dialogue = self.convert_flatten_dialogue(instance['parsed_dialogue_generation']) 80 | 81 | prompt = self.template.format(current_date=current_date, dialogue=flatten_dialogue, name=name) 82 | print(prompt) 83 | assert False 84 | 85 | def parse_and_filter(self): 86 | return None -------------------------------------------------------------------------------- /execute_web_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import copy 5 | import time 6 | import torch 7 | import random 8 | import argparse 9 | import warnings 10 | from tqdm import tqdm 11 | import pandas as pd 12 | from pathlib import Path 13 | 14 | import uuid 15 | import requests 16 | import concurrent.futures 17 | from icrawler.builtin import BingImageCrawler 18 | 19 | 20 | # Suppress warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | 25 | MODULE_MAPPER = { 26 | 't2i': 'sdxl-lightning', 27 | 'p-t2i': 'photomaker', 28 | 'web': 'bing', 29 | 'retrieval': 'image_db' 30 | } 31 | 32 | def load_json(path): 33 | with open(path, 'r', encoding='utf-8') as f: 34 | return json.load(f) 35 | 36 | def search_image(target_instance, SAVE_PATH): 37 | image_uuid = 
target_instance['image_uuid'] 38 | bing_crawler = BingImageCrawler(downloader_threads=8, storage={"root_dir": os.path.join(SAVE_PATH, image_uuid)}) 39 | 40 | image_desc = target_instance['image_description'] 41 | bing_crawler.crawl(keyword=image_desc, offset=0, max_num=10, filters=None) 42 | 43 | cp_instance = copy.deepcopy(target_instance) 44 | cp_instance['image_save_path'] = os.path.join(SAVE_PATH, image_uuid) 45 | return cp_instance 46 | 47 | def batch_images(dataset, SAVE_PATH): 48 | target_dataset, non_target_dataset = [], [] 49 | for instance in tqdm(dataset, total=len(dataset)): 50 | module = instance['image_alignment_module'] 51 | model_id = MODULE_MAPPER[module] 52 | 53 | image_uuid = instance['image_uuid'] 54 | 55 | 56 | if model_id == 'bing': 57 | if os.path.exists(os.path.join(SAVE_PATH, image_uuid)): 58 | non_target_dataset.append(instance) 59 | continue 60 | target_dataset.append(instance) 61 | else: 62 | non_target_dataset.append(instance) 63 | 64 | print('# of total dataset:', len(dataset)) 65 | print('# of target dataset:', len(target_dataset)) 66 | print('# of non-target dataset:', len(non_target_dataset)) 67 | 68 | final_dataset = [] 69 | with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor: 70 | futures = [] 71 | 72 | for instance in tqdm(target_dataset, total=len(target_dataset)): 73 | cp_instance = copy.deepcopy(instance) 74 | 75 | future = executor.submit(search_image, cp_instance, SAVE_PATH) 76 | futures.append(future) 77 | 78 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 79 | ret = future.result() 80 | final_dataset.append(ret) 81 | 82 | return final_dataset + non_target_dataset 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--start-idx', type=int) 87 | parser.add_argument('--end-idx', type=int) 88 | args = parser.parse_args() 89 | 90 | for persona_seed_num in range(args.start_idx, args.end_idx): 91 | 92 | dataset = load_json(f'curated_stark/planner-parsed-openai/stark_{persona_seed_num}.json') 93 | SAVE_PATH = f'generated_image/plan-and-execute/web_searcher/stark_{persona_seed_num}' 94 | os.makedirs(SAVE_PATH, exist_ok=True) 95 | 96 | generations = batch_images(dataset, SAVE_PATH) 97 | 98 | data_save_path = 'curated_stark/plan-and-execute/web_searcher' 99 | os.makedirs(data_save_path, exist_ok=True) 100 | 101 | with open(os.path.join(data_save_path, f'stark_{persona_seed_num}.json'), 'w', encoding='utf-8') as f: 102 | json.dump(generations, f, ensure_ascii=False, indent='\t') 103 | -------------------------------------------------------------------------------- /generate_stark_dialogue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | 5 | from runner import ( 6 | PersonaRunner, 7 | CommonsenseRunner, 8 | NarrativeRunner, 9 | AlignmentRunner, 10 | EventRunner, 11 | DialogueRunner, 12 | #ImageRunner, 13 | AlbumRunner, 14 | AlbumImageRunner, 15 | FaceRunner, 16 | FaceImageRunner 17 | ) 18 | 19 | 20 | RUNNER_MAP = { 21 | 'persona-attr': PersonaRunner, 22 | 'commonsense': CommonsenseRunner, 23 | 'narrative': NarrativeRunner, 24 | 'event': EventRunner, 25 | 'dialogue': DialogueRunner, 26 | 'album': AlbumRunner, 27 | 'face': FaceRunner, 28 | } 29 | 30 | def main(args): 31 | random.seed(42) 32 | 33 | runner = RUNNER_MAP[args.runner_name](args) 34 | runner.run() 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser(description='arguments for 
generating multi-modal dialogues using LLM') 38 | parser.add_argument('--run-id', 39 | type=str, 40 | default='vanilla', 41 | help='the name of the directory where the output will be dumped') 42 | parser.add_argument('--model', 43 | type=str, 44 | default='gpt-3.5-turbo-1106', 45 | help='which LLM to use') 46 | parser.add_argument('--temperature', 47 | type=float, 48 | default=0.9, 49 | help="control randomness: lowering results in less random completion") 50 | parser.add_argument('--top-p', 51 | type=float, 52 | default=0.95, 53 | help="nucleus sampling") 54 | parser.add_argument('--frequency-penalty', 55 | type=float, 56 | default=1.0, 57 | help="decreases the model's likelihood to repeat the same line verbatim") 58 | parser.add_argument('--presence-penalty', 59 | type=float, 60 | default=0.6, 61 | help="increases the model's likelihood to talk about new topics") 62 | parser.add_argument('--max-tokens', 63 | type=int, 64 | default=1024, 65 | help='maximum number of tokens to generate') 66 | parser.add_argument('--split', 67 | type=str, 68 | default=None, 69 | help='Specify the dataset split (i.e., train, validation, test).') 70 | parser.add_argument('--runner-name', 71 | type=str, 72 | default=None, 73 | help='Specify the runner name (e.g., persona-attribute)') 74 | parser.add_argument('--do-parse-filter', 75 | action='store_true', 76 | help='do parsing and filtering based on llm-generated results') 77 | parser.add_argument('--diffusion-model-id', 78 | type=str, 79 | default=None, 80 | help='Specify the diffusion model.') 81 | parser.add_argument('--cache-dir', 82 | type=str, 83 | default=None, 84 | help='Cache dir for downloading pre-trained diffusion model.') 85 | parser.add_argument('--debug', 86 | action='store_true', 87 | help='do debugging for generating small number of sampels.') 88 | parser.add_argument('--debug-sample-num', 89 | type=int, 90 | default=None, 91 | help="Number of sample for debug.") 92 | parser.add_argument('--shard-num', 93 | type=int, 94 | default=200, 95 | help='Number of sharded files.') 96 | parser.add_argument('--persona-seed-num', 97 | type=int, 98 | default=None, 99 | help="Persona seed number.") 100 | parser.add_argument('--target-session-num', 101 | type=int, 102 | default=None, 103 | help="Target dialogue session number.") 104 | args = parser.parse_args() 105 | main(args) -------------------------------------------------------------------------------- /execute_sdxl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import copy 5 | import time 6 | import torch 7 | import random 8 | import argparse 9 | import warnings 10 | from tqdm import tqdm 11 | import pandas as pd 12 | from pathlib import Path 13 | import concurrent.futures 14 | 15 | import torch 16 | from accelerate import PartialState 17 | from diffusers import ( 18 | StableDiffusionXLPipeline, 19 | StableDiffusionPipeline, 20 | UNet2DConditionModel, 21 | EulerDiscreteScheduler, 22 | PixArtAlphaPipeline, 23 | DiffusionPipeline, 24 | DDIMScheduler 25 | ) 26 | from diffusers.utils import load_image 27 | from safetensors.torch import load_file 28 | from huggingface_hub import hf_hub_download 29 | from accelerate.utils import gather_object 30 | from photomaker import PhotoMakerStableDiffusionXLPipeline 31 | 32 | # Suppress warnings 33 | warnings.filterwarnings("ignore") 34 | 35 | MODULE_MAPPER = { 36 | 't2i': 'sdxl-lightning', 37 | 'p-t2i': 'photomaker', 38 | 'web': 'bing', 39 | 'retrieval': 'image_db' 40 | } 41 | 42 | 
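# Each planner-parsed instance carries an 'image_alignment_module' key
# ('t2i', 'p-t2i', 'web', or 'retrieval') that MODULE_MAPPER resolves to the
# executor responsible for it. This script only generates images for instances
# routed to 'sdxl-lightning'; every other instance is passed through unchanged,
# so the sibling executors (execute_web_search.py, execute_retrieval.py, and
# presumably execute_photomaker.py) can be run separately over the same
# curated_stark/planner-parsed-openai/stark_{n}.json files.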
def load_json(path): 43 | with open(path, 'r', encoding='utf-8') as f: 44 | return json.load(f) 45 | 46 | def load_sdxl_model(cache_dir: str, device: str): 47 | """Load the SDXL Lightning diffusion model.""" 48 | base = "stabilityai/stable-diffusion-xl-base-1.0" 49 | repo = "ByteDance/SDXL-Lightning" 50 | ckpt = "sdxl_lightning_8step_unet.safetensors" 51 | unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16) 52 | unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device)) 53 | pipe = StableDiffusionXLPipeline.from_pretrained( 54 | base, unet=unet, torch_dtype=torch.float16, cache_dir=cache_dir, variant="fp16" 55 | ) 56 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 57 | return pipe.to(device) 58 | 59 | 60 | cache_dir = './pretrained_diffusion_model' 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--start-idx', type=int) 64 | parser.add_argument('--end-idx', type=int) 65 | args = parser.parse_args() 66 | 67 | model = load_sdxl_model(cache_dir, 'cuda:2') 68 | 69 | @torch.inference_mode() 70 | def generate_image(target_instance, SAVE_PATH): 71 | 72 | target_image_uuid = target_instance['image_uuid'] 73 | 74 | images = model( 75 | target_instance['image_description'], 76 | num_inference_steps=8, 77 | guidance_scale=0, 78 | num_images_per_prompt=1 79 | ).images 80 | 81 | save_paths = [] 82 | for idx, image in enumerate(images): 83 | save_paths.append(os.path.join(SAVE_PATH, f'{idx}:{target_image_uuid}.png')) 84 | image.save(os.path.join(SAVE_PATH, f'{idx}:{target_image_uuid}.png')) 85 | 86 | cp_instance = copy.deepcopy(target_instance) 87 | cp_instance['image_save_paths'] = save_paths 88 | return cp_instance 89 | 90 | 91 | def batch_images(dataset, SAVE_PATH): 92 | 93 | target_dataset, non_target_dataset = [], [] 94 | for instance in tqdm(dataset, total=len(dataset)): 95 | module = instance['image_alignment_module'] 96 | model_id = MODULE_MAPPER[module] 97 | 98 | if model_id == 'sdxl-lightning': 99 | if os.path.exists(os.path.join(SAVE_PATH, '0:{}.png'.format(instance['image_uuid']))): 100 | non_target_dataset.append(instance) 101 | continue 102 | target_dataset.append(instance) 103 | else: 104 | non_target_dataset.append(instance) 105 | 106 | print('# of total dataset:', len(dataset)) 107 | print('# of target dataset:', len(target_dataset)) 108 | print('# of non-target dataset:', len(non_target_dataset)) 109 | 110 | completions_per_process = [] 111 | for batch in tqdm(target_dataset, total=len(target_dataset)): 112 | 113 | result = generate_image(batch, SAVE_PATH) 114 | completions_per_process.append(result) 115 | 116 | print('# of final dataset:', len(completions_per_process) + len(non_target_dataset)) 117 | return completions_per_process + non_target_dataset 118 | 119 | if __name__ == '__main__': 120 | 121 | for persona_seed_num in range(args.start_idx, args.end_idx): 122 | dataset = load_json(f'curated_stark/planner-parsed-openai/stark_{persona_seed_num}.json') 123 | SAVE_PATH = f'generated_image/plan-and-execute/sdxl/stark_{persona_seed_num}' 124 | os.makedirs(SAVE_PATH, exist_ok=True) 125 | 126 | generations = batch_images(dataset, SAVE_PATH) 127 | 128 | data_save_path = 'curated_stark/plan-and-execute/sdxl' 129 | os.makedirs(data_save_path, exist_ok=True) 130 | 131 | with open(os.path.join(data_save_path, f'stark_{persona_seed_num}.json'), 'w', encoding='utf-8') as f: 132 | json.dump(generations, f, ensure_ascii=False, indent='\t') 133 | 
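A minimal invocation sketch for the plan-and-execute stage implemented by the script above and its siblings, assuming the planner-parsed files already exist under curated_stark/planner-parsed-openai/ and that execute_photomaker.py accepts the same --start-idx/--end-idx range as the other executors (an assumption; its contents are not shown here). The CUDA device strings are hard-coded inside the scripts (e.g. 'cuda:2' in execute_sdxl.py, 'cuda:0' in execute_retrieval.py), so edit them to match the local machine before running:

# Run each executor over persona seeds 0..4; every script reads
# curated_stark/planner-parsed-openai/stark_{n}.json, processes only the
# instances routed to its own module, and writes its results under
# curated_stark/plan-and-execute/<module>/stark_{n}.json.
python execute_sdxl.py --start-idx 0 --end-idx 5
python execute_web_search.py --start-idx 0 --end-idx 5
python execute_retrieval.py --start-idx 0 --end-idx 5
python execute_photomaker.py --start-idx 0 --end-idx 5   # assumed to take the same flags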
-------------------------------------------------------------------------------- /photomaker/model.py: -------------------------------------------------------------------------------- 1 | # Merge image encoder and fuse module to create an ID Encoder 2 | # send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection 7 | from transformers.models.clip.configuration_clip import CLIPVisionConfig 8 | from transformers import PretrainedConfig 9 | 10 | VISION_CONFIG_DICT = { 11 | "hidden_size": 1024, 12 | "intermediate_size": 4096, 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "patch_size": 14, 16 | "projection_dim": 768 17 | } 18 | 19 | class MLP(nn.Module): 20 | def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): 21 | super().__init__() 22 | if use_residual: 23 | assert in_dim == out_dim 24 | self.layernorm = nn.LayerNorm(in_dim) 25 | self.fc1 = nn.Linear(in_dim, hidden_dim) 26 | self.fc2 = nn.Linear(hidden_dim, out_dim) 27 | self.use_residual = use_residual 28 | self.act_fn = nn.GELU() 29 | 30 | def forward(self, x): 31 | residual = x 32 | x = self.layernorm(x) 33 | x = self.fc1(x) 34 | x = self.act_fn(x) 35 | x = self.fc2(x) 36 | if self.use_residual: 37 | x = x + residual 38 | return x 39 | 40 | 41 | class FuseModule(nn.Module): 42 | def __init__(self, embed_dim): 43 | super().__init__() 44 | self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) 45 | self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) 46 | self.layer_norm = nn.LayerNorm(embed_dim) 47 | 48 | def fuse_fn(self, prompt_embeds, id_embeds): 49 | stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) 50 | stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds 51 | stacked_id_embeds = self.mlp2(stacked_id_embeds) 52 | stacked_id_embeds = self.layer_norm(stacked_id_embeds) 53 | return stacked_id_embeds 54 | 55 | def forward( 56 | self, 57 | prompt_embeds, 58 | id_embeds, 59 | class_tokens_mask, 60 | ) -> torch.Tensor: 61 | # id_embeds shape: [b, max_num_inputs, 1, 2048] 62 | id_embeds = id_embeds.to(prompt_embeds.dtype) 63 | num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case 64 | batch_size, max_num_inputs = id_embeds.shape[:2] 65 | # seq_length: 77 66 | seq_length = prompt_embeds.shape[1] 67 | # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] 68 | flat_id_embeds = id_embeds.view( 69 | -1, id_embeds.shape[-2], id_embeds.shape[-1] 70 | ) 71 | # valid_id_mask [b*max_num_inputs] 72 | valid_id_mask = ( 73 | torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] 74 | < num_inputs[:, None] 75 | ) 76 | valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] 77 | 78 | prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) 79 | class_tokens_mask = class_tokens_mask.view(-1) 80 | valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) 81 | # slice out the image token embeddings 82 | image_token_embeds = prompt_embeds[class_tokens_mask] 83 | stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) 84 | assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" 85 | prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) 86 | updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, 
-1) 87 | return updated_prompt_embeds 88 | 89 | class PhotoMakerIDEncoder(CLIPVisionModelWithProjection): 90 | def __init__(self): 91 | super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT)) 92 | self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) 93 | self.fuse_module = FuseModule(2048) 94 | 95 | def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): 96 | b, num_inputs, c, h, w = id_pixel_values.shape 97 | id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) 98 | 99 | shared_id_embeds = self.vision_model(id_pixel_values)[1] 100 | id_embeds = self.visual_projection(shared_id_embeds) 101 | id_embeds_2 = self.visual_projection_2(shared_id_embeds) 102 | 103 | id_embeds = id_embeds.view(b, num_inputs, 1, -1) 104 | id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) 105 | 106 | id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) 107 | updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) 108 | 109 | return updated_prompt_embeds 110 | 111 | 112 | if __name__ == "__main__": 113 | PhotoMakerIDEncoder() -------------------------------------------------------------------------------- /runner/narrative_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | from tqdm import tqdm 5 | from glob import glob 6 | from collections import defaultdict 7 | 8 | from names_dataset import NameDataset 9 | 10 | from .base_runner import BaseRunner, console 11 | from utils.etc_utils import load_jsonl, load_txt 12 | from utils.persona_utils import COUNTRY_ALPHA_LIST 13 | 14 | 15 | class NarrativeRunner(BaseRunner): 16 | def __init__(self, args): 17 | super().__init__(args) 18 | 19 | self.save_dir = os.path.join(self.output_base_dir, 'narrative', f'persona_seed:{args.persona_seed_num}') 20 | os.makedirs(self.save_dir, exist_ok=True) 21 | 22 | self.nd = NameDataset() 23 | 24 | self._load_prompt_template() 25 | self._load_universe_name_list() 26 | self.last_save_chunk_idx_file = os.path.join(self.save_dir, 'last_save_chunk_idx.txt') 27 | 28 | @property 29 | def system_msg(self): 30 | return "You are a helpful assistant." 31 | 32 | @property 33 | def prompt_prefix(self): 34 | return "narrative" 35 | 36 | def _load_persona_commonsense_knowledge(self): 37 | our_peacok = load_jsonl(os.path.join(self.output_base_dir, 'commonsense-knowledge', 'final_output_0.jsonl')) 38 | 39 | self.persona_CK = our_peacok 40 | console.log('[{}] Done Loading Persona Commonsense Knowledge..'.format(self.__class__.__name__)) 41 | 42 | def _load_universe_name_list(self): 43 | all_names = defaultdict(dict) 44 | for country_alpha2_code in COUNTRY_ALPHA_LIST: 45 | top_names = self.nd.get_top_names(n=1000, country_alpha2=country_alpha2_code) 46 | 47 | male_names = top_names[country_alpha2_code]['M'] 48 | female_names = top_names[country_alpha2_code]['F'] 49 | 50 | all_names[country_alpha2_code] = { 51 | 'Male': male_names, 52 | 'Female': female_names, 53 | 'Non-binary': male_names + female_names 54 | } 55 | 56 | self.name_group = all_names 57 | 58 | def _load_prompt_template(self): 59 | self.sentence_form_template = { 60 | 'routine': 'My name is {name}. {demo_sent} {persona_attr} I regularly {commonsense}.', 61 | 'characteristic': 'My name is {name}. {demo_sent} {persona_attr} I {commonsense}.', 62 | 'experience': 'My name is {name}. I {commonsense}. Now, {demo_sent} {persona_attr}', 63 | 'goal': 'My name is {name}. 
{demo_sent} {persona_attr} I plan {commonsense}.', 64 | 'relationship': 'My name is {name}. {demo_sent} {persona_attr} So, I {commonsense}.' 65 | } 66 | 67 | self.sentence_to_narrative_template = load_txt('./templates/narrative.txt') 68 | 69 | def prepare_prompt(self): 70 | 71 | persona_CK = load_jsonl(os.path.join(self.output_base_dir, 'commonsense', f'persona_seed:{self.args.persona_seed_num}', 'final_output.jsonl')) 72 | console.log('[{}] # of Total persona commonsense: {}'.format(self.__class__.__name__, len(persona_CK))) 73 | 74 | if self.args.debug: 75 | persona_CK = random.sample(persona_CK, self.args.debug_sample_num) 76 | 77 | prompts = [] 78 | for ck in tqdm(persona_CK, total=len(persona_CK)): 79 | 80 | persona_attr = ck['persona-attr:sent'] 81 | commonsense = ck['parsed_commonsense_generation'] 82 | relation = ck['commonsense_relation'] 83 | 84 | age = ck['age'] 85 | gender = ck['gender'] 86 | #nationality = ck['nationality'] 87 | birthplace = ck['birthplace'] 88 | residence = ck['residence'] 89 | 90 | birthplace_alpha2_code = ck['birthplace_alpha2_code'] 91 | sampled_name = random.sample(self.name_group[birthplace_alpha2_code][gender], 1)[0] 92 | 93 | demo_sent = "I am a {}-year-old {}. I was born in {}, I currently reside in {}.".format( 94 | #sampled_name, 95 | age, 96 | gender.lower(), 97 | birthplace, residence 98 | ) 99 | 100 | sentence_form = self.sentence_form_template[relation].format( 101 | demo_sent=demo_sent, 102 | persona_attr=persona_attr, 103 | commonsense=commonsense, 104 | name=sampled_name 105 | ) 106 | 107 | prompt = self.sentence_to_narrative_template.format( 108 | sentence=f'{sentence_form}', 109 | ) 110 | 111 | cp_instance = copy.deepcopy(ck) 112 | cp_instance[f'{self.prompt_prefix}_sentence_form'] = sentence_form 113 | cp_instance[f'{self.prompt_prefix}_prompt'] = prompt 114 | cp_instance['name'] = sampled_name 115 | 116 | prompts.append(cp_instance) 117 | 118 | return prompts 119 | 120 | def parse_and_filter(self, generations): 121 | self.dump_output(generations, os.path.join(self.save_dir, 'final_output.jsonl')) 122 | 123 | def _generate_initial_narrative(self, prompts, prompt_prefix=None): 124 | return self.interact(prompts, prompt_prefix=prompt_prefix) -------------------------------------------------------------------------------- /execute_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import copy 5 | import time 6 | import torch 7 | import random 8 | import argparse 9 | import warnings 10 | from tqdm import tqdm 11 | import pandas as pd 12 | from pathlib import Path 13 | 14 | import clip 15 | import faiss 16 | import yaml 17 | import uuid 18 | import requests 19 | import concurrent.futures 20 | 21 | # Suppress warnings 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | 26 | IMAGE_DB_SOURCE_NAMES = ["cc12m", "redcaps12m", "mathvision", "chartqa", "ai2d"] #, "gqa", "ocr_vqa", "textvqa", "vg"] 27 | IMAGE_SEARCH_NUM = 5 28 | 29 | def load_json(path): 30 | with open(path, 'r', encoding='utf-8') as f: 31 | return json.load(f) 32 | 33 | class ImageDBLoader: 34 | """ 35 | Load image list and image indices. 
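    load_image_db() returns a dict with two entries, each keyed by the dataset
    names in IMAGE_DB_SOURCE_NAMES: 'image_list' maps a dataset to the list of
    image paths read from its metadata parquet files, and 'image_indices' maps
    a dataset to the FAISS index loaded from index_folder/<dataset>/knn.index.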
36 | """ 37 | @staticmethod 38 | def load_image_db(): 39 | image_list = dict() 40 | image_indices = dict() 41 | for image_dataset in IMAGE_DB_SOURCE_NAMES: 42 | data_dir = Path(f'../Sonny-PM/prepare_image_db/embeddings_folder/{image_dataset}/metadata') 43 | df = pd.concat( 44 | pd.read_parquet(parquet_file) 45 | for parquet_file in data_dir.glob('*.parquet') 46 | ) 47 | 48 | ind = faiss.read_index(f'../Sonny-PM/prepare_image_db/index_folder/{image_dataset}/knn.index') 49 | image_list[image_dataset] = df['image_path'].tolist() 50 | image_indices[image_dataset] = ind 51 | 52 | return { 53 | 'image_list': image_list, 54 | 'image_indices': image_indices 55 | } 56 | 57 | image_db = ImageDBLoader.load_image_db() 58 | print(f'Load Image DB Done!') 59 | image_mapper = { 60 | 'cc12m': load_json('../Sonny-PM/prepare_image_db/image_mapper/cc12m.json'), 61 | 'redcaps12m': load_json('../Sonny-PM/prepare_image_db/image_mapper/redcaps12m.json') 62 | } 63 | print(f'Load image mapper Done!') 64 | 65 | MODULE_MAPPER = { 66 | 't2i': 'sdxl-lightning', 67 | 'p-t2i': 'photomaker', 68 | 'web': 'bing', 69 | 'retrieval': 'image_db' 70 | } 71 | 72 | def load_openai_clip_model(device: str): 73 | model, _ = clip.load('ViT-L/14@336px', device=device, jit=False) 74 | return model 75 | 76 | model = load_openai_clip_model('cuda:0') 77 | 78 | @torch.no_grad 79 | def retrieve_image(instance, persona_seed_num): 80 | image_desc = instance['image_description'] 81 | image_uuid = '{}:{}'.format(persona_seed_num, instance['image_uuid']) 82 | 83 | desc_tokens = clip.tokenize(image_desc, truncate=True) 84 | desc_feats = model.encode_text(desc_tokens.to('cuda:0')) 85 | desc_feats /= desc_feats.norm(dim=-1, keepdim=True) 86 | desc_embeds = desc_feats.cpu().detach().numpy().astype('float32') 87 | #print("Done get embedding") 88 | 89 | image_search_result = dict() 90 | for src_name in IMAGE_DB_SOURCE_NAMES: 91 | D, I = image_db['image_indices'][src_name].search(desc_embeds, IMAGE_SEARCH_NUM) 92 | 93 | tmp_result = [] 94 | for item_D, item_I in zip(D[0], I[0]): 95 | if src_name in ['redcaps12m', 'cc12m']: 96 | target_mapper = image_mapper[src_name] 97 | 98 | tmp_result.append({ 99 | 'image_path_from_db': image_db['image_list'][src_name][item_I], #target_mapper[image_db['image_list'][src_name][item_I]], 100 | 'clip_score': str(item_D) 101 | }) 102 | else: 103 | tmp_result.append({ 104 | 'image_path_from_db': image_db['image_list'][src_name][item_I], 105 | 'clip_score': str(item_D) 106 | }) 107 | image_search_result[src_name] = tmp_result 108 | 109 | return image_search_result 110 | 111 | def batch_images(dataset, persona_seed_num): 112 | 113 | target_dataset, non_target_dataset = [], [] 114 | for instance in tqdm(dataset, total=len(dataset)): 115 | module = instance['image_alignment_module'] 116 | model_id = MODULE_MAPPER[module] 117 | if model_id == 'image_db': 118 | 119 | target_dataset.append(instance) 120 | else: 121 | non_target_dataset.append(instance) 122 | 123 | final_dataset = [] 124 | for target_instance in tqdm(target_dataset, total=len(target_dataset)): 125 | retrieved_results = retrieve_image(target_instance, persona_seed_num) 126 | cp_instance = copy.deepcopy(target_instance) 127 | cp_instance['db_searched_results'] = retrieved_results 128 | final_dataset.append(cp_instance) 129 | 130 | return final_dataset + non_target_dataset 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument('--start-idx', type=int) 135 | parser.add_argument('--end-idx', type=int) 136 | args = 
parser.parse_args() 137 | 138 | for persona_seed_num in range(args.start_idx, args.end_idx): 139 | 140 | dataset = load_json(f'curated_stark/planner-parsed-openai/stark_{persona_seed_num}.json') 141 | generations = batch_images(dataset, persona_seed_num) 142 | 143 | data_save_path = 'curated_stark/plan-and-execute/image_db' 144 | os.makedirs(data_save_path, exist_ok=True) 145 | 146 | with open(os.path.join(data_save_path, f'stark_{persona_seed_num}.json'), 'w', encoding='utf-8') as f: 147 | json.dump(generations, f, ensure_ascii=False, indent='\t') 148 | 149 | -------------------------------------------------------------------------------- /runner/face_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import copy 4 | import random 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from .base_runner import BaseRunner, console 9 | from utils.etc_utils import load_jsonl, load_txt 10 | 11 | 12 | SYSTEM_MESSAGE = """Given the profile information, your job is to generate a detailed description of a human that includes specific details such as background, tops, bottoms, hair, and shoes. 13 | 14 | For example, 15 | 16 | Profile Information: 17 | - Age: 32 18 | - Gender: Woman 19 | - Nationality: South Korea 20 | 21 | Human Description: A full-body shot, an Asian adult female, fit, small road with trees, straight red above-chest hair, normal-length, white and long sleeve cotton shirt, short plaid skirt in pleated shape, cotton backpack, socks, black leather oxford shoes.""" 22 | 23 | STYLE_TYPES = [ 24 | "full-body", "upper body", "portrait", "headshot", "nearly full-body" 25 | ] 26 | 27 | HUMAN_ATTR_KEY_ORDER = { 28 | 'body shape': ['body shape'], 29 | 'background': ['background'], 30 | 'hair': ['hair style', 'hair color', 'hair length'], 31 | 'special clothings': ['sleeve length', 'type'], 32 | 'one-piece outfits': ['shoulder exposure level', 'length', 'collar shape', 'sleeve length', 'material', 'pattern', 'type'], 33 | 'tops': ['graphic', 'color', 'collar shape', 'top length', 'sleeve length', 'material', 'pattern', 'type'], 34 | 'coats': ['graphic', 'color', 'collar shape', 'coat length', 'material', 'pattern', 'type'], 35 | 'bottoms': ['graphic', 'color', 'bottom shape', 'length', 'material', 'pattern', 'type'], 36 | 'shoes': ['color', 'boots length', 'material', 'pattern', 'type'], 37 | 'bags': ['material', 'type'], 38 | 'hats': ['material', 'type'], 39 | 'belts': ['material'], 40 | 'scarf': ['material', 'pattern'], 41 | 'headband': ['material', 'pattern'], 42 | 'headscarf': ['material', 'pattern'], 43 | 'veil': ['material', 'pattern'], 44 | 'socks': ['material', 'pattern'], 45 | 'ties': ['material', 'pattern'] 46 | } 47 | 48 | class FaceRunner(BaseRunner): 49 | def __init__(self, args): 50 | super().__init__(args) 51 | 52 | self.save_dir = os.path.join(self.output_base_dir, 'face', f'persona_seed:{args.persona_seed_num}') 53 | os.makedirs(self.save_dir, exist_ok=True) 54 | 55 | self._load_human_attribute_pool() 56 | self._load_prompt_template() 57 | self.last_save_chunk_idx_file = os.path.join(self.save_dir, 'last_save_chunk_idx.txt') 58 | 59 | @property 60 | def system_msg(self): 61 | return SYSTEM_MESSAGE 62 | 63 | @property 64 | def prompt_prefix(self): 65 | return "face" 66 | 67 | def _load_prompt_template(self): 68 | self.face_template = load_txt('./templates/face.txt') 69 | 70 | def _load_human_attribute_pool(self): 71 | self.human_attribute_pool = 
load_jsonl(os.path.join('./datasets/cosmic/human_attribute_pool.jsonl')) 72 | 73 | def make_human_attribute_sentence(self, human_attr, age, gender, birthplace): 74 | style_type = random.sample(STYLE_TYPES, 1)[0] 75 | gender = gender.lower() 76 | 77 | template = f'A {style_type} shot, a {age}-years-old {gender} from {birthplace},' #-years-old 78 | for category, attribute in human_attr.items(): 79 | if category == 'face': 80 | continue 81 | if category == 'overall-style': 82 | continue 83 | 84 | if category == 'hair': 85 | assert 'wears' not in attribute.keys() 86 | order_keys = HUMAN_ATTR_KEY_ORDER[category] 87 | for order_key in order_keys: 88 | try: 89 | template += ' {}'.format(attribute[order_key]) 90 | except KeyError as e: 91 | continue 92 | #for k, v in attribute.items(): 93 | # template += f' {v}' 94 | 95 | if category == 'hair': 96 | template += ' hair' 97 | 98 | if category == 'belts': 99 | template += ' belt' 100 | 101 | if category == 'scarf': 102 | template += ' scarf' 103 | 104 | if category == 'headband': 105 | template += ' headband' 106 | 107 | if category == 'headscarf': 108 | template += ' headscarf' 109 | if category == 'veil': 110 | template += ' veil' 111 | if category == 'socks': 112 | template += ' socks' 113 | if category == 'ties': 114 | template += ' tie' 115 | 116 | template += ',' 117 | 118 | template = template[:-1] + '.' 119 | 120 | return template 121 | 122 | def prepare_prompt(self): 123 | persona_sentence = load_jsonl(os.path.join(self.output_base_dir, 'persona-attr', f'final_output_{self.args.persona_seed_num}.jsonl')) 124 | console.log('[{}] # of Total persona sentence: {}'.format(self.__class__.__name__, len(persona_sentence))) 125 | 126 | if self.args.debug: 127 | persona_sentence = random.sample(persona_sentence, self.args.debug_sample_num) 128 | 129 | sample_num = len(persona_sentence) 130 | sampled_human_attribute = random.sample(self.human_attribute_pool, sample_num) 131 | assert len(persona_sentence) == len(sampled_human_attribute) 132 | 133 | prompts = [] 134 | for idx, instance in enumerate(tqdm(persona_sentence, total=len(persona_sentence))): 135 | human_attr = sampled_human_attribute[idx] 136 | 137 | human_attr_sent = self.make_human_attribute_sentence( 138 | human_attr, 139 | instance['age'], instance['gender'], instance['birthplace'] 140 | ) 141 | 142 | cp_instance = copy.deepcopy(instance) 143 | cp_instance[f'{self.prompt_prefix}_prompt'] = human_attr_sent 144 | 145 | prompts.append(cp_instance) 146 | 147 | return prompts 148 | 149 | def parse_and_filter(self, generations): 150 | self.dump_output(generations, os.path.join(self.save_dir, 'final_output.jsonl')) 151 | 152 | def _generate_album(self, prompts, prompt_prefix=None): 153 | return self.interact(prompts, prompt_prefix=prompt_prefix) -------------------------------------------------------------------------------- /runner/album_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import copy 4 | import random 5 | from tqdm import tqdm 6 | from glob import glob 7 | from collections import defaultdict 8 | 9 | from .base_runner import BaseRunner, console 10 | from utils.etc_utils import load_jsonl, load_txt 11 | 12 | 13 | SYSTEM_MESSAGE = """Given the sentence related to a person's daily life, your task is to generate five image descriptions that could be stored on the person's mobile device, along with corresponding image categories. You should use the format " (Category: )". 
The image category may include selfies, past memories, screenshots, landmarks, animals, art, celebrities, nature, and food. 14 | 15 | For example, 16 | 17 | My name is Tom. I am a 32-year-old man. I was born in the USA and currently reside there. I have a strong interest in basketball. I played basketball in middle school, but now I work as a chatbot developer at a startup. I enjoy watching the NBA because I love basketball. 18 | 19 | Image descriptions stored on Tom's mobile device: 20 | 1. A photo of a young Tom playing basketball in a middle school gymnasium (Category: Past Memory, Sport) 21 | 2. A selfie of Tom smiling at the Golden State Warriors' arena during a game (Category: Selfie, Sport) 22 | 3. A screenshot of chatbot development code using Python (Category: Screenshot, Computer, Software) 23 | 4. A picture of Tom enjoying a night out with coworkers at a local pub (Category: Social Networking, Food, Drink) 24 | 5. A photo of Tom meeting a famous NBA player at a basketball event (Category: Celebrity, Sport)""" 25 | 26 | class AlbumRunner(BaseRunner): 27 | def __init__(self, args): 28 | super().__init__(args) 29 | 30 | self.save_dir = os.path.join(self.output_base_dir, 'mobile-device', f'persona_seed:{args.persona_seed_num}') 31 | os.makedirs(self.save_dir, exist_ok=True) 32 | 33 | self._load_prompt_template() 34 | self.last_save_chunk_idx_file = os.path.join(self.save_dir, 'last_save_chunk_idx.txt') 35 | 36 | @property 37 | def system_msg(self): 38 | return SYSTEM_MESSAGE 39 | 40 | @property 41 | def prompt_prefix(self): 42 | return "mobile-device" 43 | 44 | def _load_prompt_template(self): 45 | self.album_template = load_txt('./templates/mobile_device.txt') 46 | 47 | def prepare_prompt(self): 48 | 49 | event_graph = load_jsonl(os.path.join(self.output_base_dir, 'event-graph', f'persona_seed:{self.args.persona_seed_num}', 'final_output.jsonl')) 50 | console.log('[{}] # of Total event graph: {}'.format(self.__class__.__name__, len(event_graph))) 51 | 52 | prompts = [] 53 | for instance in tqdm(event_graph, total=len(event_graph)): 54 | narrative = instance['narrative_generation'] 55 | 56 | _prompt = self.album_template.format( 57 | name=instance['name'], 58 | #age=instance['demographic:age'], 59 | #gender=instance['demographic:gender'], 60 | #nationality=instance['demographic:nationality'], 61 | sentence=instance['narrative_generation'] 62 | ) 63 | 64 | cp_instance = copy.deepcopy(instance) 65 | cp_instance[f'{self.prompt_prefix}_prompt'] = _prompt 66 | prompts.append(cp_instance) 67 | 68 | return prompts 69 | 70 | def parse_and_filter(self, generations): 71 | 72 | stat = defaultdict(int) 73 | stat['total_num'] = len(generations) * 5 74 | 75 | results = [] 76 | regex_parsed_results, regex_discard_results = [], [] 77 | for generation in tqdm(generations, total=len(generations)): 78 | parsed_results, discard_results = self._parse_mobile_device_generation(generation[f'{self.prompt_prefix}_generation']) 79 | if len(parsed_results) == 0: 80 | for discard_result in discard_results: 81 | cp_generation = copy.deepcopy(generation) 82 | cp_generation['regex:discard_result'] = discard_result 83 | regex_discard_results.append(cp_generation) 84 | continue 85 | 86 | cp_instance = copy.deepcopy(generation) 87 | cp_instance[f'parsed_{self.prompt_prefix}_generation'] = parsed_results 88 | results.append(cp_instance) 89 | 90 | for parsed_result in parsed_results: 91 | cp_generation = copy.deepcopy(generation) 92 | for k, v in parsed_result.items(): 93 | cp_generation[f'{self.prompt_prefix}:{k}'] = 
parsed_result[k] 94 | regex_parsed_results.append(cp_generation) 95 | 96 | for discard_result in discard_results: 97 | cp_generation = copy.deepcopy(generation) 98 | cp_generation['regex:discard_result'] = discard_result 99 | regex_discard_results.append(cp_generation) 100 | 101 | stat['regex:parsed_result'] = len(regex_parsed_results) 102 | stat['regex:discard_result'] = len(regex_discard_results) 103 | 104 | self.dump_output(results, os.path.join(self.save_dir, 'final_output.jsonl')) 105 | self.dump_output(regex_parsed_results, os.path.join(self.save_dir, 'regex_parsed_output.jsonl')) 106 | self.dump_output(regex_discard_results, os.path.join(self.save_dir, 'regex_discard_output.jsonl')) 107 | self.dump_report(stat, os.path.join(self.save_dir, 'report_output.txt')) 108 | 109 | def _parse_mobile_device_generation(self, generation): 110 | # First, split the generation based on the number prefix (e.g., 1., 2.) 111 | delims = [f'\n{i}. ' for i in range(1, 6)] + [f'\n{i}.' for i in range(1, 6)] 112 | splitted_generation = re.split('|'.join(delims), generation) 113 | 114 | # Second, extract the persona-related information using the regex pattern 115 | pattern = '(?P.*) [\(|\[]Category: (?P.*)[\)|\]]' # [] case should be possible 116 | compiled_regex = re.compile(pattern) 117 | 118 | parsed_results = [] 119 | discard = [] 120 | for generation in splitted_generation: 121 | matched = compiled_regex.match(generation) 122 | 123 | if matched: 124 | parsed_results.append(matched.groupdict()) 125 | else: 126 | discard.append(generation) 127 | 128 | return parsed_results, discard -------------------------------------------------------------------------------- /generate_face_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import random 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from diffusers import ( 10 | StableDiffusionXLPipeline, 11 | StableDiffusionPipeline, 12 | UNet2DConditionModel, 13 | EulerDiscreteScheduler, 14 | PixArtAlphaPipeline, 15 | DiffusionPipeline, 16 | DDIMScheduler 17 | ) 18 | from safetensors.torch import load_file 19 | from huggingface_hub import hf_hub_download 20 | 21 | 22 | IMAGE_DB_SOURCE_NAMES = ["coco_train2017", "gqa", "ocr_vqa", "textvqa", "vg"] 23 | IMAGE_SEARCH_NUM = 5 24 | 25 | STYLE_TYPES = [ 26 | "full-body", "upper body", "portrait", "headshot", "nearly full-body" 27 | ] 28 | 29 | HUMAN_ATTR_KEY_ORDER = { 30 | 'body shape': ['body shape'], 31 | 'background': ['background'], 32 | 'hair': ['hair style', 'hair color', 'hair length'], 33 | 'special clothings': ['sleeve length', 'type'], 34 | 'one-piece outfits': ['shoulder exposure level', 'length', 'collar shape', 'sleeve length', 'material', 'pattern', 'type'], 35 | 'tops': ['graphic', 'color', 'collar shape', 'top length', 'sleeve length', 'material', 'pattern', 'type'], 36 | 'coats': ['graphic', 'color', 'collar shape', 'coat length', 'material', 'pattern', 'type'], 37 | 'bottoms': ['graphic', 'color', 'bottom shape', 'length', 'material', 'pattern', 'type'], 38 | 'shoes': ['color', 'boots length', 'material', 'pattern', 'type'], 39 | 'bags': ['material', 'type'], 40 | 'hats': ['material', 'type'], 41 | 'belts': ['material'], 42 | 'scarf': ['material', 'pattern'], 43 | 'headband': ['material', 'pattern'], 44 | 'headscarf': ['material', 'pattern'], 45 | 'veil': ['material', 'pattern'], 46 | 'socks': ['material', 'pattern'], 47 | 'ties': ['material', 'pattern'] 48 | } 49 | 50 | 51 | def 
load_json(datadir): 52 | with open(datadir, 'r', encoding='utf-8') as f: 53 | return json.load(f) 54 | 55 | def load_jsonl(datadir: str): 56 | output = [] 57 | with open(datadir, 'r', encoding='utf-8') as f: 58 | for line in f.readlines(): 59 | output.append(json.loads(line)) 60 | return output 61 | 62 | def dump_json_output(outputs, file_name): 63 | with open(file_name, 'w', encoding='utf-8') as f: 64 | json.dump(outputs, f, ensure_ascii=False, indent='\t') 65 | 66 | def dump_jsonl_output(outputs, file_name=None): 67 | f = open(file_name, 'w', encoding='utf-8') 68 | for output in outputs: 69 | f.write(json.dumps(output) + '\n') 70 | f.close() 71 | 72 | def make_human_attribute_sentence(human_attr, age, gender, birthplace): 73 | style_type = random.choice(STYLE_TYPES) 74 | gender = gender.lower() 75 | template_parts = [f'A {style_type} shot, a {age}-years-old {gender} from {birthplace},'] 76 | 77 | excluded_categories = {'face', 'overall-style'} 78 | suffix_map = { 79 | 'hair': ' hair', 80 | 'belts': ' belt', 81 | 'scarf': ' scarf', 82 | 'headband': ' headband', 83 | 'headscarf': ' headscarf', 84 | 'veil': ' veil', 85 | 'socks': ' socks', 86 | 'ties': ' tie' 87 | } 88 | 89 | for category, attribute in human_attr.items(): 90 | if category in excluded_categories: 91 | continue 92 | 93 | if category == 'hair' and 'wears' in attribute: 94 | continue 95 | 96 | order_keys = HUMAN_ATTR_KEY_ORDER.get(category, []) 97 | for order_key in order_keys: 98 | attr_value = attribute.get(order_key) 99 | if attr_value: 100 | template_parts.append(f' {attr_value}') 101 | 102 | suffix = suffix_map.get(category) 103 | if suffix: 104 | template_parts[-1] += suffix 105 | 106 | template_parts[-1] += ',' 107 | 108 | template = ''.join(template_parts).rstrip(',') + '.' 109 | 110 | return template 111 | 112 | def load_human_attribute_pool(): 113 | return load_jsonl(os.path.join('./datasets/cosmic/human_attribute_pool.jsonl')) 114 | 115 | def load_sdxl_diffusion_model(cache_dir: str, device: str): 116 | base = "stabilityai/stable-diffusion-xl-base-1.0" 117 | repo = "ByteDance/SDXL-Lightning" 118 | ckpt = "sdxl_lightning_8step_unet.safetensors" 119 | 120 | unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16) 121 | unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device)) 122 | pipe = StableDiffusionXLPipeline.from_pretrained( 123 | base, unet=unet, torch_dtype=torch.float16, cache_dir=cache_dir, variant="fp16" 124 | ).to(device) 125 | 126 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 127 | return pipe 128 | 129 | def process_face(stark, human_attribute_pool, pipeline, persona_seed_num, save_dir): 130 | total_id_keys = list(set([ele['unique_id'] for ele in stark])) 131 | total_id_num = len(total_id_keys) 132 | sampled_human_attribute = random.sample(human_attribute_pool, total_id_num) 133 | id2human_attr = {key: sampled_human_attribute[i] for i, key in enumerate(total_id_keys)} 134 | 135 | final_results = [] 136 | for idx, instance in enumerate(tqdm(stark, total=len(stark))): 137 | uuid = instance['unique_id'] 138 | age = instance['age'] 139 | gender = instance['gender'] 140 | birthplace = instance['birthplace'] 141 | 142 | human_attr = id2human_attr[uuid] 143 | human_attr_sent = make_human_attribute_sentence(human_attr, age, gender, birthplace) 144 | 145 | face_image = pipeline(human_attr_sent, num_inference_steps=8, guidance_scale=0).images[0] 146 | face_image.save(os.path.join(save_dir, f'{uuid}.png')) 
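# NOTE: the attribute sentence above serves as the text prompt for this persona's face image.
# SDXL-Lightning is a step-distilled SDXL variant (the 8-step UNet checkpoint loaded in
# load_sdxl_diffusion_model), so the pipeline is sampled with only 8 inference steps and
# guidance_scale=0 (classifier-free guidance disabled), as the distilled checkpoint expects.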
147 | 148 | cp_instance = copy.deepcopy(instance) 149 | cp_instance['face_description'] = human_attr_sent 150 | cp_instance['face_image_path'] = os.path.join(save_dir, f'{uuid}.png') 151 | final_results.append(cp_instance) 152 | 153 | return final_results 154 | 155 | if __name__ == '__main__': 156 | 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument('--start-idx', type=int) 159 | parser.add_argument('--end-idx', type=int) 160 | parser.add_argument('--device', type=str) 161 | args = parser.parse_args() 162 | 163 | human_attribute_pool = load_human_attribute_pool() 164 | pipeline = load_sdxl_diffusion_model('./pretrained_diffusion_model', args.device) 165 | 166 | for persona_seed_num in range(args.start_idx, args.end_idx): 167 | stark = load_json(os.path.join(f'./Stark/post-process/stark_{persona_seed_num}.json')) 168 | 169 | save_dir = f'generated_image/human-face/stark_{persona_seed_num}' 170 | os.makedirs(save_dir, exist_ok=True) 171 | 172 | processed_results = process_face(stark, human_attribute_pool, pipeline, persona_seed_num, save_dir) 173 | 174 | curated_save_dir = f'curated_stark/human-face' 175 | os.makedirs(curated_save_dir, exist_ok=True) 176 | 177 | dump_json_output(processed_results, os.path.join(curated_save_dir, f'stark_{persona_seed_num}.json')) -------------------------------------------------------------------------------- /execute_photomaker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import re 4 | import json 5 | import copy 6 | import time 7 | import torch 8 | import random 9 | import argparse 10 | import warnings 11 | from tqdm import tqdm 12 | import pandas as pd 13 | from pathlib import Path 14 | import concurrent.futures 15 | 16 | import torch 17 | from accelerate import PartialState 18 | from diffusers import ( 19 | StableDiffusionXLPipeline, 20 | StableDiffusionPipeline, 21 | UNet2DConditionModel, 22 | EulerDiscreteScheduler, 23 | PixArtAlphaPipeline, 24 | DiffusionPipeline, 25 | DDIMScheduler 26 | ) 27 | from diffusers.utils import load_image 28 | from safetensors.torch import load_file 29 | from huggingface_hub import hf_hub_download 30 | from accelerate.utils import gather_object 31 | from photomaker import PhotoMakerStableDiffusionXLPipeline 32 | 33 | 34 | 35 | # Suppress warnings 36 | warnings.filterwarnings("ignore") 37 | 38 | MODULE_MAPPER = { 39 | 't2i': 'sdxl-lightning', 40 | 'p-t2i': 'photomaker', 41 | 'web': 'bing', 42 | 'retrieval': 'image_db' 43 | } 44 | 45 | def load_json(path): 46 | with open(path, 'r', encoding='utf-8') as f: 47 | return json.load(f) 48 | 49 | def load_sdxl_model(cache_dir: str, device: str): 50 | """Load the SDXL Lightning diffusion model.""" 51 | base = "stabilityai/stable-diffusion-xl-base-1.0" 52 | repo = "ByteDance/SDXL-Lightning" 53 | ckpt = "sdxl_lightning_8step_unet.safetensors" 54 | unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16) 55 | unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device)) 56 | pipe = StableDiffusionXLPipeline.from_pretrained( 57 | base, unet=unet, torch_dtype=torch.float16, cache_dir=cache_dir, variant="fp16" 58 | ) 59 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 60 | return pipe.to(device) 61 | 62 | distributed_state = PartialState() 63 | 64 | def memory_optimization(): 65 | # memory deallocation 66 | gc.collect() 67 | 68 | # removing cache 69 | torch.cuda.empty_cache() 70 | 71 | def 
load_photomaker_model(cache_dir: str): 72 | base_model_path = 'SG161222/RealVisXL_V3.0' 73 | 74 | photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker", cache_dir=cache_dir, filename="photomaker-v1.bin", repo_type="model") 75 | 76 | pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained( 77 | base_model_path, 78 | torch_dtype=torch.bfloat16, 79 | use_safetensors=True, 80 | variant="fp16", 81 | ).to(distributed_state.device) 82 | 83 | pipe.load_photomaker_adapter( 84 | os.path.dirname(photomaker_ckpt), 85 | subfolder="", 86 | weight_name=os.path.basename(photomaker_ckpt), 87 | trigger_word="img" 88 | ) 89 | pipe.id_encoder.to(distributed_state.device) #device) 90 | 91 | pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) 92 | pipe.fuse_lora() 93 | 94 | return pipe 95 | 96 | cache_dir = './pretrained_diffusion_model' 97 | 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('--start-idx', type=int) 100 | parser.add_argument('--end-idx', type=int) 101 | args = parser.parse_args() 102 | 103 | model = load_photomaker_model(cache_dir) 104 | 105 | negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" 106 | generator = torch.Generator(device=distributed_state.device).manual_seed(42) 107 | 108 | @torch.inference_mode() 109 | def generate_image(target_instance, SAVE_PATH): 110 | 111 | target_image_uuid = target_instance['image_uuid'] 112 | 113 | 114 | num_steps = 50 115 | style_strength_ratio = 20 116 | start_merge_step = int(float(style_strength_ratio) / 100 * num_steps) 117 | if start_merge_step > 30: 118 | start_merge_step = 30 119 | 120 | input_id_path = target_instance['face_image_path'] 121 | input_id_images = [load_image(input_id_path)] 122 | images = model( 123 | prompt=target_instance['modified_image_description'], 124 | input_id_images=input_id_images, 125 | negative_prompt=negative_prompt, 126 | num_images_per_prompt=1, 127 | num_inference_steps=num_steps, 128 | start_merge_step=start_merge_step, 129 | generator=generator, 130 | #guidance_scale=5, 131 | ).images 132 | 133 | save_paths = [] 134 | for idx, image in enumerate(images): 135 | save_paths.append(os.path.join(SAVE_PATH, f'{idx}:{target_image_uuid}.png')) 136 | image.save(os.path.join(SAVE_PATH, f'{idx}:{target_image_uuid}.png')) 137 | 138 | cp_instance = copy.deepcopy(target_instance) 139 | cp_instance['image_save_paths'] = save_paths 140 | return cp_instance 141 | 142 | 143 | def batch_images(dataset, SAVE_PATH): 144 | 145 | target_dataset, non_target_dataset = [], [] 146 | for instance in tqdm(dataset, total=len(dataset)): 147 | module = instance['image_alignment_module'] 148 | model_id = MODULE_MAPPER[module] 149 | 150 | if model_id == 'photomaker': 151 | if os.path.exists(os.path.join(SAVE_PATH, '0:{}.png'.format(instance['image_uuid']))): 152 | non_target_dataset.append(instance) 153 | continue 154 | target_dataset.append(instance) 155 | else: 156 | non_target_dataset.append(instance) 157 | 158 | print('# of total dataset:', len(dataset)) 159 | print('# of target dataset:', len(target_dataset)) 160 | print('# of non-target dataset:', len(non_target_dataset)) 161 | 162 | completions_per_process = [] 163 | with distributed_state.split_between_processes(target_dataset) as batched_prompts: 164 | for batch in tqdm(batched_prompts, total=len(batched_prompts)): 165 | 166 | memory_optimization() 167 | result = 
generate_image(batch, SAVE_PATH) 168 | completions_per_process.append(result) 169 | 170 | completions_gather = gather_object(completions_per_process) 171 | completions = completions_gather[: len(target_dataset)] 172 | 173 | memory_optimization() 174 | 175 | print('# of final dataset:', len(completions) + len(non_target_dataset)) 176 | return completions + non_target_dataset 177 | 178 | if __name__ == '__main__': 179 | 180 | for persona_seed_num in range(args.start_idx, args.end_idx): 181 | dataset = load_json(f'curated_stark/planner-parsed-openai/stark_{persona_seed_num}.json') 182 | SAVE_PATH = f'generated_image/plan-and-execute/generator/stark_{persona_seed_num}' 183 | os.makedirs(SAVE_PATH, exist_ok=True) 184 | 185 | generations = batch_images(dataset, SAVE_PATH) 186 | 187 | data_save_path = 'curated_stark/plan-and-execute/generator' 188 | os.makedirs(data_save_path, exist_ok=True) 189 | 190 | with open(os.path.join(data_save_path, f'stark_{persona_seed_num}.json'), 'w', encoding='utf-8') as f: 191 | json.dump(generations, f, ensure_ascii=False, indent='\t') 192 | -------------------------------------------------------------------------------- /scripts/run_mcu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python generate_sonny_dataset.py --run-id sonny_v1 \ 5 | --diffusion-model-id jialuliluka/selma-xl \ 6 | --runner-name image \ 7 | --cache-dir pretrained_diffusion_model \ 8 | --model gpt-3.5-turbo-0125 \ 9 | --persona-seed-num demo 10 | 11 | stabilityai/stable-diffusion-xl-base-1.0 \ 12 | 13 | python generate_sonny_dataset.py --run-id sonny_v4 \ 14 | --model gpt-3.5-turbo-0125 \ 15 | --temperature 0.9 \ 16 | --top-p 1.0 \ 17 | --frequency-penalty .0 \ 18 | --presence-penalty 0. 
\ 19 | --max-tokens 4096 \ 20 | --runner-name dialogue \ 21 | --persona-seed-num 0 \ 22 | --do-parse-filter 23 | 24 | --debug \ 25 | --debug-sample-num 10 \ 26 | --shard-num 1 \ 27 | 28 | python generate_sonny_dataset.py --run-id sonny_v4 \ 29 | --model gpt-3.5-turbo-0125 \ 30 | --temperature 0.9 \ 31 | --top-p 1.0 \ 32 | --frequency-penalty 0.4 \ 33 | --presence-penalty 0.4 \ 34 | --max-tokens 1024 \ 35 | --runner-name face-image \ 36 | --debug \ 37 | --debug-sample-num 100 \ 38 | --shard-num 1 \ 39 | --persona-seed-num 10 \ 40 | --do-parse-filter 41 | 42 | python generate_sonny_dataset.py --run-id sonny_v2 \ 43 | --diffusion-model-id jialuliluka/selma-xl \ 44 | --runner-name album-image \ 45 | --cache-dir pretrained_diffusion_model \ 46 | --model gpt-3.5-turbo-0125 \ 47 | --do-parse-filter 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | python generate_stark_dialogue.py --run-id stark_v1 \ 58 | --model gpt-3.5-turbo-0125 \ 59 | --temperature 0.9 \ 60 | --top-p 1.0 \ 61 | --frequency-penalty 0.0 \ 62 | --presence-penalty 0.0 \ 63 | --max-tokens 2048 \ 64 | --runner-name persona-attr \ 65 | --shard-num 1 \ 66 | --debug \ 67 | --debug-sample-num 10 68 | 69 | python generate_stark_dialogue.py --run-id stark_v1 \ 70 | --model gpt-3.5-turbo-0125 \ 71 | --temperature 0.9 \ 72 | --top-p 1.0 \ 73 | --frequency-penalty 0.0 \ 74 | --presence-penalty 0.0 \ 75 | --max-tokens 2048 \ 76 | --runner-name persona-attr \ 77 | --shard-num 1 \ 78 | --do-parse-filter 79 | 80 | 81 | python generate_stark_dialogue.py --run-id stark_v1 \ 82 | --model gpt-3.5-turbo-0125 \ 83 | --temperature 0.9 \ 84 | --top-p 1.0 \ 85 | --frequency-penalty 0.0 \ 86 | --presence-penalty 0.0 \ 87 | --max-tokens 2048 \ 88 | --runner-name face \ 89 | --persona-seed-num 0 \ 90 | --shard-num 1 91 | 92 | python generate_stark_dialogue.py --run-id stark_v1 \ 93 | --model gpt-3.5-turbo-0125 \ 94 | --temperature 0.9 \ 95 | --top-p 1.0 \ 96 | --frequency-penalty 0.0 \ 97 | --presence-penalty 0.0 \ 98 | --max-tokens 2048 \ 99 | --runner-name face \ 100 | --persona-seed-num 0 \ 101 | --shard-num 1 \ 102 | --do-parse-filter 103 | 104 | 105 | 106 | python generate_stark_dialogue.py --run-id stark_v1 \ 107 | --model gpt-3.5-turbo-0125 \ 108 | --temperature 0.9 \ 109 | --top-p 1.0 \ 110 | --frequency-penalty .0 \ 111 | --presence-penalty .0 \ 112 | --max-tokens 1024 \ 113 | --runner-name commonsense \ 114 | --persona-seed-num 0 \ 115 | --debug \ 116 | --debug-sample-num 20 117 | 118 | python generate_stark_dialogue.py --run-id stark_v1 \ 119 | --model gpt-3.5-turbo-0125 \ 120 | --temperature 0.9 \ 121 | --top-p 1.0 \ 122 | --frequency-penalty .0 \ 123 | --presence-penalty .0 \ 124 | --max-tokens 1024 \ 125 | --runner-name commonsense \ 126 | --persona-seed-num 0 \ 127 | --do-parse-filter 128 | 129 | 130 | 131 | 132 | python generate_stark_dialogue.py --run-id stark_v1 \ 133 | --model gpt-3.5-turbo-0125 \ 134 | --temperature 0.9 \ 135 | --top-p 0.95 \ 136 | --frequency-penalty 1.0 \ 137 | --presence-penalty 0.6 \ 138 | --max-tokens 2048 \ 139 | --runner-name narrative \ 140 | --debug \ 141 | --debug-sample-num 20 \ 142 | --persona-seed-num 0 143 | 144 | python generate_stark_dialogue.py --run-id stark_v1 \ 145 | --model gpt-3.5-turbo-0125 \ 146 | --temperature 0.9 \ 147 | --top-p 0.95 \ 148 | --frequency-penalty 1.0 \ 149 | --presence-penalty 0.6 \ 150 | --max-tokens 2048 \ 151 | --runner-name narrative \ 152 | --debug \ 153 | --debug-sample-num 20 \ 154 | --persona-seed-num 0 \ 155 | --do-parse-filter 156 | 157 | 158 | 159 | python 
generate_stark_dialogue.py --run-id stark_v1 \ 160 | --model gpt-3.5-turbo-0125 \ 161 | --temperature 0.9 \ 162 | --top-p 1.0 \ 163 | --frequency-penalty 0. \ 164 | --presence-penalty 0. \ 165 | --max-tokens 4096 \ 166 | --runner-name event \ 167 | --debug \ 168 | --debug-sample-num 20 \ 169 | --persona-seed-num 0 170 | 171 | python generate_stark_dialogue.py --run-id stark_v1 \ 172 | --model gpt-3.5-turbo-0125 \ 173 | --temperature 0.9 \ 174 | --top-p 1.0 \ 175 | --frequency-penalty 0. \ 176 | --presence-penalty 0. \ 177 | --max-tokens 4096 \ 178 | --runner-name event \ 179 | --debug \ 180 | --debug-sample-num 20 \ 181 | --persona-seed-num 0 \ 182 | --do-parse-filter 183 | 184 | 185 | 186 | 187 | python generate_stark_dialogue.py --run-id stark_v1 \ 188 | --model gpt-3.5-turbo-0125 \ 189 | --temperature 0.9 \ 190 | --top-p 1.0 \ 191 | --frequency-penalty 0.0 \ 192 | --presence-penalty 0.0 \ 193 | --max-tokens 1024 \ 194 | --runner-name album \ 195 | --debug \ 196 | --debug-sample-num 20 \ 197 | --persona-seed-num 0 198 | 199 | python generate_stark_dialogue.py --run-id stark_v1 \ 200 | --model gpt-3.5-turbo-0125 \ 201 | --temperature 0.9 \ 202 | --top-p 1.0 \ 203 | --frequency-penalty 0.0 \ 204 | --presence-penalty 0.0 \ 205 | --max-tokens 1024 \ 206 | --runner-name album \ 207 | --debug \ 208 | --debug-sample-num 20 \ 209 | --persona-seed-num 0 \ 210 | --do-parse-filter 211 | 212 | 213 | for session_num in {1..6}; do 214 | python generate_stark_dialogue.py --run-id stark_v1 \ 215 | --model gpt-3.5-turbo-0125 \ 216 | --temperature 0.9 \ 217 | --top-p 1.0 \ 218 | --frequency-penalty 0.0 \ 219 | --presence-penalty 0.0 \ 220 | --max-tokens 4096 \ 221 | --runner-name dialogue \ 222 | --debug \ 223 | --debug-sample-num 20 \ 224 | --persona-seed-num 0 \ 225 | --target-session-num "$session_num" \ 226 | 227 | python generate_stark_dialogue.py --run-id stark_v1 \ 228 | --model gpt-3.5-turbo-0125 \ 229 | --temperature 0.9 \ 230 | --top-p 1.0 \ 231 | --frequency-penalty 0.0 \ 232 | --presence-penalty 0.0 \ 233 | --max-tokens 4096 \ 234 | --runner-name dialogue \ 235 | --debug \ 236 | --debug-sample-num 20 \ 237 | --persona-seed-num 0 \ 238 | --target-session-num "$session_num" \ 239 | --do-parse-filter 240 | done 241 | 242 | python make_final_dataset.py 243 | 244 | python postprocess_final_dataset.py 245 | 246 | python generate_face_image.py \ 247 | --start-idx 0 \ 248 | --end-idx 1 \ 249 | --device cuda:0 250 | 251 | python plan_runner.py \ 252 | --start-idx 0 \ 253 | --end-idx 1 254 | 255 | python plan_runner.py \ 256 | --start-idx 0 \ 257 | --end-idx 1 \ 258 | --do-planner 259 | 260 | python execute_photomaker.py \ 261 | --start-idx 0 \ 262 | --end-idx 1 263 | 264 | python execute_sdxl.py \ 265 | --start-idx 0 \ 266 | --end-idx 1 267 | 268 | python execute_retrieval.py \ 269 | --start-idx 0 \ 270 | --end-idx 1 271 | 272 | python execute_web_search.py \ 273 | --start-idx 0 \ 274 | --end-idx 1 -------------------------------------------------------------------------------- /runner/event_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | from glob import glob 7 | 8 | from .base_runner import BaseRunner, console 9 | from utils.etc_utils import load_jsonl, load_txt 10 | 11 | 12 | SYSTEM_MESSAGE = """You should generate a temporal event graph composed of daily events occuring in a person's life. The temporal event graph contains nodes and edges. 
Each node represents a daily event which is written in two or three sentences. Each edge represents the causal relationship between two nodes (events), i.e., a past event -> current event. The current event is determined by how much time has passed since the past event and what personal experiences the person had during that period. You must generate the temporal event graph following the guidelines below. 13 | 14 | [Guideline] 15 | - The graph is represented in the form of a json list. 16 | - Each entry is a python dictionary containing the following keys: "id", "event", "date", "caused_by". 17 | - The "id" field contains a unique identifier for the current event. 18 | - The "event" field contains a description of the current event. 19 | - The "date" field contains a specific date of the current event and is represented in the form of "%Y.%m.%d". 20 | - The "caused_by" field represents the edge (i.e., a past event) and is represented in the form of a python dictionary containing the following keys: "caused_by:id", "caused_by:time_interval", "caused_by:experience_op", "caused_by:experience". 21 | - The "caused_by:id" field contains an "id" of the past event that has caused the current event. 22 | - The "caused_by:time_interval" field contains a time interval between the past event and the current event. 23 | - The "caused_by:experience_op" field contains an episodic experience operation. 24 | - The "caused_by:experience" field contains a short description of the added or updated episodic experience. 25 | - The unit of time interval is ["hour", "day", "week", "month", "year"]. 26 | - The selected time interval should be formatted as "