├── assets
│   └── banner.png
├── .gitmodules
├── setup.py
├── evaporate
│   ├── run.sh
│   ├── evaluate_synthetic_utils.py
│   ├── retrieval.py
│   ├── weak_supervision
│   │   ├── ws_utils.py
│   │   ├── pgm.py
│   │   ├── run_ws.py
│   │   └── binary_deps.py
│   ├── prompts.py
│   ├── schema_identification.py
│   ├── utils.py
│   ├── evaluate_profiler.py
│   ├── main.py
│   ├── configs.py
│   ├── profiler_utils.py
│   ├── run_profiler.py
│   └── evaluate_synthetic.py
├── .gitignore
├── README.md
├── example.ipynb
└── demo.ipynb

/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/evaporate/HEAD/assets/banner.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "metal-evap"]
2 | 	path = metal-evap
3 | 	url = git@github.com:simran-arora/metal-evap.git
4 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | _REQUIRED = [
4 |     "tqdm",
5 |     "openai",
6 |     "manifest-ml",
7 |     "pandas",
8 |     "snorkel",
9 |     "cvxpy",
10 |     "bs4",
11 |     "snorkel-metal",
12 |     "tensorboardX",
13 |     "numpy == 1.20.3",
14 |     "networkx == 2.3"
15 | ]
16 | 
17 | setup(
18 |     name="evaporate",
19 |     version="0.0.1",
20 |     description="evaporating data lakes with foundation models",
21 |     author="simran brandon sabri avanika andrew immanuel chris",
22 |     packages=["evaporate"],
23 |     install_requires=_REQUIRED,
24 | )
25 | 
--------------------------------------------------------------------------------
/evaporate/run.sh:
--------------------------------------------------------------------------------
1 | keys=# INSERT YOUR API KEY(S) HERE
2 | 
3 | # evaporate code: closed IE
4 | python run_profiler.py \
5 |     --data_lake fda_510ks \
6 |     --num_attr_to_cascade 50 \
7 |     --num_top_k_scripts 10 \
8 |     --train_size 10 \
9 |     --combiner_mode ws \
10 |     --use_dynamic_backoff \
11 |     --KEYS ${keys} \
12 |     --data_dir /data/fda_510ks/data/evaporate/fda-ai-pmas/510k \
13 |     --base_data_dir /data/evaporate/data/fda_510ks \
14 |     --gold_extractions_file /data/evaporate/data/fda_510ks/table.json
15 | # evaporate code: open IE
16 | python run_profiler.py \
17 |     --data_lake fda_510ks \
18 |     --num_attr_to_cascade 50 \
19 |     --num_top_k_scripts 10 \
20 |     --train_size 10 \
21 |     --combiner_mode ws \
22 |     --use_dynamic_backoff \
23 |     --KEYS ${keys} \
24 |     --do_end_to_end \
25 |     --data_dir /data/fda_510ks/data/evaporate/fda-ai-pmas/510k \
26 |     --base_data_dir /data/evaporate/data/fda_510ks \
27 |     --gold_extractions_file /data/evaporate/data/fda_510ks/table.json
--------------------------------------------------------------------------------
/evaporate/evaluate_synthetic_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter, defaultdict
3 | 
4 | def text_f1(preds=[], golds=[], attribute=''):
5 |     """Compute average F1 of text spans.
6 |     Taken from SQuAD, without the prob threshold for no answer.
7 |     """
8 |     total_f1 = 0
9 |     total_recall = 0
10 |     total_prec = 0
11 |     f1s = []
12 |     for pred, gold in zip(preds, golds):
13 |         if isinstance(pred, list):
14 |             pred = ' '.join(pred)  # convert a list of spans to a single string
15 |         if isinstance(gold, list):
16 |             gold = ' '.join(gold)  # convert a list of spans to a single string
17 |         pred_toks = pred.split()
18 |         gold_toks = gold.split()
19 |         common = Counter(pred_toks) & Counter(gold_toks)
20 |         num_same = sum(common.values())
21 |         if len(gold_toks) == 0 or len(pred_toks) == 0:
22 |             # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
23 |             total_f1 += int(gold_toks == pred_toks)
24 |             f1s.append(int(gold_toks == pred_toks))
25 |         elif num_same == 0:
26 |             total_f1 += 0
27 |             f1s.append(0)
28 |         else:
29 |             precision = 1.0 * num_same / len(pred_toks)
30 |             recall = 1.0 * num_same / len(gold_toks)
31 |             f1 = (2 * precision * recall) / (precision + recall)
32 |             total_f1 += f1
33 |             total_recall += recall
34 |             total_prec += precision
35 |             f1s.append(f1)
36 |     f1_avg = total_f1 / len(golds)
37 |     f1_median = np.percentile(f1s, 50)
38 |     return f1_avg, f1_median
39 | 
40 | def get_file_attribute(attribute):
41 |     attribute = attribute.lower()
42 |     attribute = attribute.replace("/", "_").replace(")", "").replace("-", "_")
43 |     attribute = attribute.replace("(", "").replace(" ", "_")
44 |     if len(attribute) > 30:
45 |         attribute = attribute[:30]
46 |     return attribute
--------------------------------------------------------------------------------
/evaporate/retrieval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModel
3 | 
4 | 
5 | def mean_pooling(token_embeddings, mask):
6 |     token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
7 |     sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
8 |     return sentence_embeddings
9 | 
10 | def get_embeddings(sentences):
11 |     tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
12 |     model = AutoModel.from_pretrained('facebook/contriever')
13 | 
14 |     # NOTE: a hard-coded example list used to overwrite the `sentences` argument here,
15 |     # so every call embedded the same three strings; it is kept only as a comment:
16 |     #     "Where was Marie Curie born?",
17 |     #     "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
18 |     #     "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
19 | 
20 |     # Apply tokenizer
21 |     inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
22 | 
23 |     # Compute token embeddings
24 |     outputs = model(**inputs)
25 | 
26 |     # Mean pooling
27 | 
28 |     embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
29 |     return embeddings
30 | 
31 | def get_most_similarity(target_sentence, sentences):
32 |     target_embedding = get_embeddings([target_sentence])[0]
33 |     embeddings = get_embeddings(sentences)
34 |     max_similarity = torch.nn.functional.cosine_similarity(target_embedding, embeddings, dim = -1)
35 |     most_similar_sentence = max_similarity.argmax()
36 |     return sentences[most_similar_sentence]
37 | 
38 | # simple manual test
39 | if __name__ == "__main__":
40 |     target_sentence = "Where was Marie Curie born?"
41 |     sentences = [
42 |         "Where was Marie Curie born?",
43 |         "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
44 |         "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
45 | ] 46 | print(get_most_similarity(target_sentence, sentences)) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | .sqlite.cache 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaporate 2 | 3 |
4 |     <img src="assets/banner.png" alt="Evaporate diagram" />
5 | </div>
6 | 
7 | Code, datasets, and the extended writeup for the paper [Language Models Enable Simple Systems for Generating Structured Views of Heterogeneous Data Lakes](https://www.vldb.org/pvldb/vol17/p92-arora.pdf).
8 | 
9 | ## Setup
10 | 
11 | We encourage the use of conda environments:
12 | ```bash
13 | conda create --name evaporate python=3.8
14 | conda activate evaporate
15 | ```
16 | 
17 | Clone as follows:
18 | ```bash
19 | # Evaporate code
20 | git clone git@github.com:HazyResearch/evaporate.git
21 | cd evaporate
22 | pip install -e .
23 | 
24 | # Weak supervision code
25 | cd metal-evap
26 | git submodule init
27 | git submodule update
28 | pip install -e .
29 | 
30 | # Manifest (install from source if you want to modify the set of supported models; otherwise, ``setup.py`` installs ``manifest-ml``)
31 | git clone git@github.com:HazyResearch/manifest.git
32 | cd manifest
33 | pip install -e .
34 | ```
35 | 
36 | ## Datasets
37 | The data used in the paper is hosted on Hugging Face's datasets platform: https://huggingface.co/datasets/hazyresearch/evaporate.
38 | 
39 | To download the datasets, run the following commands in your terminal:
40 | ```bash
41 | git lfs install
42 | git clone https://huggingface.co/datasets/hazyresearch/evaporate
43 | ```
44 | 
45 | Or download them via Python:
46 | ```python
47 | from datasets import load_dataset
48 | dataset = load_dataset("hazyresearch/evaporate")
49 | ```
50 | 
51 | The code expects the data to be stored at ``/data/evaporate/``, as specified in the ``CONSTANTS`` dictionary in ``constants.py``, though this can be modified.
52 | 
53 | 
54 | ## Running the code
55 | Run closed IE and open IE using the following commands:
56 | 
57 | ```bash
58 | cd evaporate/ && bash run.sh
59 | ```
60 | 
61 | The ``keys`` in ``run.sh`` can be obtained by registering with the LLM provider. For instance, if you want to run inference with the OpenAI API models, create an account [here](https://openai.com/api/).
62 | 
63 | The script includes commands for both closed and open IE runs. To walk through the code, start from ``run_profiler.py``. For open IE, the code first uses ``schema_identification.py`` to generate a list of attributes for the schema. Next, the code iterates through this list and performs extraction using ``profiler.py``. As functions are generated in ``profiler.py``, ``evaluate_profiler.py`` is used to score the function outputs against the outputs of directly prompting the LM on the sample documents.
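
The same pipeline can also be driven from Python. The sketch below mirrors ``example.ipynb``; the data-lake name, paths, and keys are placeholders you must adapt to your setup (see the notebook and ``configs.py`` for the full set of ``set_profiler_args`` options).

```python
# Minimal sketch of the programmatic pipeline, mirroring example.ipynb.
# Paths and keys below are placeholders -- adapt them to your setup.
import sys
sys.path.append("./evaporate/")

from configs import set_profiler_args
from run_profiler import prerun_profiler, identify_attributes, get_attribute_function

profiler_args = set_profiler_args({
    "data_lake": "fda_510ks",
    "num_attr_to_cascade": 50,
    "num_top_k_scripts": 10,
    "train_size": 10,
    "combiner_mode": "ws",
    "do_end_to_end": True,          # True = open IE (identify the schema first)
    "use_dynamic_backoff": True,
    "KEYS": ["<YOUR_API_KEY>"],
    "data_dir": "/data/evaporate/fda-ai-pmas/510k/",
    "base_data_dir": "/data/evaporate/",
    "gold_extractions_file": "/data/evaporate/ground_truth/fda_510ks_gold_extractions.json",
})

# Load, chunk, and cache the sample documents.
data_dict = prerun_profiler(profiler_args)

# Open IE: propose schema attributes (schema_identification.py under the hood).
attributes, total_time, num_toks, eval_result = identify_attributes(
    profiler_args, data_dict, evaluation=True
)

# For each attribute, synthesize extraction functions and score them against
# direct LM extractions (profiler.py + evaluate_profiler.py).
functions, selected_keys = {}, {}
for attribute in attributes:
    functions[attribute], selected_keys[attribute], total_time, num_toks = get_attribute_function(
        profiler_args, data_dict, attribute
    )
```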
64 | 65 | 66 | ## Citation 67 | If you use this codebase, or otherwise found our work valuable, please cite: 68 | ``` 69 | @article{arora2023evaporate, 70 | title={Language Models Enable Simple Systems for Generating Structured Views of Heterogeneous Data Lakes}, 71 | author={Arora, Simran and Yang, Brandon and Eyuboglu, Sabri and Narayan, Avanika and Hojel, Andrew and Trummer, Immanuel and R\'e, Christopher}, 72 | journal={arXiv:2304.09433}, 73 | year={2023} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /evaporate/weak_supervision/ws_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | def get_probabilties(num_lfs, num_examples, predictions, label_name_to_int): 5 | 6 | lf_array = np.zeros((num_lfs, num_examples)) 7 | golds = [] 8 | 9 | # Collect golds and preds 10 | for i, (k, item) in enumerate(predictions.items()): 11 | preds = item['chosen_answers_lst'] 12 | preds_mapped = [] 13 | for p in preds: 14 | if p in label_name_to_int: 15 | preds_mapped.append(label_name_to_int[p]) 16 | else: 17 | preds_mapped.append(0) 18 | preds = preds_mapped.copy() 19 | for lf_num, p in zip(range(num_lfs), preds): 20 | lf_array[lf_num][i] = p 21 | gold = label_name_to_int[item['gold']] 22 | golds.append(gold) 23 | golds = np.array(golds) 24 | neg_indices, pos_indices = [np.where(golds == -1)[0], np.where(golds == 1)[0]] 25 | indices = { 26 | -1: neg_indices, 27 | 1: pos_indices 28 | } 29 | 30 | # [i, j, k] = Pr(prompt_i = j| y = k) 31 | # Accuracies 32 | lf_accuracies = [] 33 | for i in range(num_lfs): 34 | lf_accuracies.append(np.sum(golds == np.array(lf_array[i]))/num_examples) 35 | print(f"LF Accs: {lf_accuracies}") 36 | 37 | # [i, j, k] = Pr(prompt_i = j| y = k) 38 | classes = label_name_to_int.values() 39 | accs = np.zeros((num_lfs, len(classes), len(classes))) 40 | for p in range(num_lfs): 41 | for i in classes: 42 | for j in classes: 43 | j_idx = j 44 | if j == -1: 45 | j_idx = 0 46 | i_idx = i 47 | if i == -1: 48 | i_idx = 0 49 | accs[p, i_idx, j_idx] = len(np.where(lf_array[p, indices[i]] == j)[0]) / len(indices[i]) 50 | 51 | # Compute probabilities 52 | pos_probs = [] 53 | for i in range(num_lfs): 54 | sub_preds = lf_array[i][pos_indices] 55 | sub_golds = golds[pos_indices] 56 | pos_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(pos_indices)) 57 | print(f"Pos Probs: {pos_probs}") 58 | 59 | neg_probs = [] 60 | for i in range(num_lfs): 61 | sub_preds = lf_array[i][neg_indices] 62 | sub_golds = golds[neg_indices] 63 | neg_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(neg_indices)) 64 | print(f"Neg Probs: {neg_probs}\n\n") 65 | 66 | return lf_accuracies, accs, pos_probs, neg_probs, golds, indices 67 | 68 | 69 | """ Independence Assumption: take the product of probabilities as p(L1, L2, ..., LK | y) """ 70 | 71 | # Pr(y = 1 | lf votes) 72 | def get_cond_probs(votes, y, indices_train, golds_train, accs_train, num_lfs_test): 73 | prop_pos = len(indices_train[1])/len(golds_train) 74 | pr_y = prop_pos if y == 1 else 1 - prop_pos 75 | prod = pr_y 76 | for i in range(num_lfs_test): 77 | if y == -1: 78 | y = 0 79 | prod *= accs_train[i, y, votes[i]] 80 | return prod 81 | 82 | # Pr(y = 1 | lf votes) 83 | def get_probs(votes, indices_train, golds_train, acc_train, num_lfs_test): 84 | votes = [max(v, 0) for v in votes] 85 | numerator = get_cond_probs(votes, 1, indices_train, golds_train, acc_train, num_lfs_test) 86 | denominator = numerator + 
get_cond_probs(votes, -1, indices_train, golds_train, acc_train, num_lfs_test) 87 | return numerator / denominator 88 | 89 | 90 | def get_nb_accuracy(num_examples_test, num_lfs_test, predictions_test, label_name_to_int, golds_test, indices_train, golds_train, accs_train): 91 | output = np.zeros(num_examples_test) 92 | errors = 0 93 | for i, (k, item) in enumerate(predictions_test.items()): 94 | votes = item['chosen_answers_lst'] 95 | votes_mapped = [] 96 | for v in votes: 97 | if v in label_name_to_int: 98 | votes_mapped.append(label_name_to_int[v]) 99 | else: 100 | votes_mapped.append(0) 101 | votes = votes_mapped.copy() 102 | probs = np.round(get_probs(votes, indices_train, golds_train, accs_train, num_lfs_test)) 103 | output[i] = probs 104 | 105 | # Mean squared error 106 | g = golds_test[i] 107 | if golds_test[i] == -1: 108 | g = 0 109 | error = np.abs(output[i] - g)**2 110 | errors += error 111 | accuracy = 1 - (errors / num_examples_test) 112 | return accuracy, output 113 | 114 | 115 | def estimate_matrix(m, n, L): 116 | E_prod = np.zeros((m, m)) 117 | l_avg = np.zeros(m) 118 | for i in range(n): 119 | l = L[i, :] 120 | l_avg += l 121 | E_prod += np.outer(l, l) 122 | 123 | l_avg = l_avg/n 124 | E_prod = E_prod/n 125 | 126 | cov = E_prod - np.outer(l_avg, l_avg) 127 | 128 | return (E_prod, cov, l_avg) 129 | 130 | 131 | def get_vote_vectors(num_samples, num_lfs, predictions, label_name_to_int): 132 | vectors = np.zeros((num_samples, num_lfs+1), float) 133 | vectors_no_y = np.zeros((num_samples, num_lfs), float) 134 | labels_vector = np.zeros((num_samples, 1), float) 135 | for i, p in enumerate(predictions.values()): 136 | votes = p['chosen_answers_lst'] 137 | votes_mapped = [] 138 | for v in votes: 139 | if v in label_name_to_int: 140 | votes_mapped.append(label_name_to_int[v]) 141 | else: 142 | votes_mapped.append(0) 143 | votes = votes_mapped.copy() 144 | # votes = [max(v, 0) for v in votes] 145 | gold = p['gold'] 146 | gold = label_name_to_int[gold] 147 | vectors_no_y[i] = np.array(votes) 148 | vectors[i] = np.array(votes + [gold]) #- lf_accuracies_train 149 | labels_vector[i] = np.array([gold]) 150 | print(f"Shape: {vectors.shape}") 151 | print(f"Sample: {vectors[0]}") 152 | 153 | return vectors, vectors_no_y, labels_vector 154 | 155 | def get_feature_vector(vote_vectors, include_pairwise=False, include_singletons=True): 156 | feature_vectors = [] 157 | for votes in vote_vectors: 158 | if include_singletons: 159 | feature_vector = list(votes[:]) 160 | else: 161 | feature_vector = [] 162 | if include_pairwise: 163 | for subset in itertools.combinations(votes[:], 2): 164 | feature_vector.append(subset[0] * subset[1]) 165 | feature_vectors.append(feature_vector) 166 | X = np.matrix(feature_vectors) 167 | return X -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "UsageError: Line magic function `%autoreload` not found.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%autoreload 2\n", 18 | "%load_ext autoreload" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 13, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import time\n", 28 | "from tqdm import tqdm\n", 29 | "import sys\n", 30 | "sys.path.append(f\"./evaporate/\")\n", 31 | 
"from run_profiler import prerun_profiler, identify_attributes, get_attribute_function\n", 32 | "from configs import set_profiler_args" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "MANIFEST_URL = \"http://127.0.0.1:5000\" # please make sure that a local manifest session is running with your model at this address\n", 42 | "DATA_DIR = \"/var/cr05_data/sim_data/code/evaporate/data/\"" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Set Up" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "profiler_args= set_profiler_args({\n", 59 | " \"data_lake\": \"wiki_nba\", \n", 60 | " \"num_attr_to_cascade\": 50, \n", 61 | " \"num_top_k_scripts\": 10, \n", 62 | " \"train_size\": 10, \n", 63 | " \"combiner_mode\": \"mv\", \n", 64 | " \"do_end_to_end\": False,\n", 65 | " \"use_dynamic_backoff\": True, \n", 66 | " \"KEYS\": [\"\"], \n", 67 | " \"MODELS\":[\"mistralai/Mistral-7B-instruct-v0.1\", \"gpt-4\"],\n", 68 | " \"EXTRACTION_MODELS\": [\"mistralai/Mistral-7B-instruct-v0.1\"],\n", 69 | " \"GOLD_KEY\": \"gpt-4\",\n", 70 | " \"MODEL2URL\" : {\"mistralai/Mistral-7B-instruct-v0.1\": MANIFEST_URL},\n", 71 | " \"data_dir\": f\"{DATA_DIR}/documents/\", \n", 72 | " \"base_data_dir\": f\"{DATA_DIR}\", \n", 73 | " \"gold_extractions_file\": f\"{DATA_DIR}/data_fda_510ks_table.json\"}\n", 74 | ")\n", 75 | "print(profiler_args)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "data_dict = prerun_profiler(profiler_args)\n", 85 | "print(data_dict.keys())" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "print(profiler_args.overwrite_cache)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "# Get Schema Attribute" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "if profiler_args.do_end_to_end:\n", 111 | " attributes, total_time, num_toks, evaluation_result = identify_attributes(profiler_args, data_dict, evaluation=True)\n", 112 | " print(\"total_time: \", total_time, \"num_toks: \", num_toks, \"evaluation_result: \", evaluation_result)\n", 113 | "else:\n", 114 | " attributes = data_dict[\"gold_attributes\"]" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "attributes" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "Model Information\n", 131 | "Data Dir\n", 132 | "CONSTANTS = {\n", 133 | " \"fda_510ks\": {\n", 134 | " \"data_dir\": os.path.join(BASE_DATA_DIR, \"fda-ai-pmas/510k/\"),\n", 135 | " \"database_name\": \"fda_510ks\",\n", 136 | " \"cache_dir\": \".cache/fda_510ks/\",\n", 137 | " \"generative_index_path\": os.path.join(BASE_DATA_DIR, \"generative_indexes/fda_510ks/\"),\n", 138 | " \"gold_extractions_file\": os.path.join(BASE_DATA_DIR, \"ground_truth/fda_510ks_gold_extractions.json\"),\n", 139 | " \"topic\": \"fda 510k device premarket notifications\",\n", 140 | " },\n", 141 | " }" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# Get 
Attribute function" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "functions = {}\n", 158 | "selected_keys = {}\n", 159 | "for i, attribute in enumerate(attributes):\n", 160 | " print(f\"\\n\\nExtracting {attribute} ({i+1} / {len(attributes)})\")\n", 161 | " t0 = time.time()\n", 162 | " functions[attribute], selected_keys[attribute], total_time, num_toks= get_attribute_function(\n", 163 | " profiler_args, data_dict, attribute\n", 164 | " )\n", 165 | " print(functions[attribute])" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "# Evaluation" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "from evaporate.evaluate_synthetic import main as evaluate_synthetic_main" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "results = evaluate_synthetic_main(\n", 191 | " profiler_args.run_string, \n", 192 | " profiler_args, \n", 193 | " profiler_args, \n", 194 | " profiler_args.data_lake,\n", 195 | " gold_attributes=data_dict[\"gold_attributes\"], \n", 196 | " stage='extract'\n", 197 | ")\n", 198 | "print(results)" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3 (ipykernel)", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.8.18" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 4 223 | } 224 | -------------------------------------------------------------------------------- /evaporate/weak_supervision/pgm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import matplotlib.pyplot as plt 4 | import scipy.stats 5 | 6 | 7 | 8 | class Ising(): 9 | 10 | 11 | def __init__(self, m, potentials, thetas = None, vals = [-1, 1], ) -> None: 12 | self.m = m 13 | self.v = m + 1 # total number of vertices 14 | self.potentials = potentials 15 | 16 | self.vals = vals 17 | #TODO support values in 0, 1 18 | 19 | if thetas is not None: 20 | assert len(thetas) >= len(potentials), f"Need to specify at least {len(potentials)} theta parameters." 
21 | self.thetas = thetas 22 | else: 23 | self.thetas = np.random.rand(len(potentials)) 24 | 25 | self.support = np.array(list(map(list, itertools.product(vals, repeat=self.v)))) 26 | 27 | self._make_pdf() 28 | self._make_cdf() 29 | 30 | self._get_means() 31 | self._get_balance() 32 | self._get_accs() 33 | 34 | def _exponential_family(self, labels): 35 | x = 0.0 36 | for i in range(len(self.potentials)): 37 | x += self.thetas[i] * labels[self.potentials[i]].prod() 38 | 39 | return np.exp(x) 40 | 41 | def _make_pdf(self): 42 | p = np.zeros(len(self.support)) 43 | for i, labels in enumerate(self.support): 44 | p[i] = self._exponential_family(labels) 45 | 46 | self.z = sum(p) 47 | 48 | self.pdf = p/self.z 49 | 50 | def _make_cdf(self): 51 | self.cdf = np.cumsum(self.pdf) 52 | 53 | 54 | 55 | 56 | def joint_p(self, C, values): 57 | p = 0.0 58 | for k, labels in enumerate(self.support): 59 | flag = True 60 | for i in range(len(C)): 61 | prod = labels[C[i]].prod() 62 | if prod != values[i]: 63 | flag = False 64 | 65 | if flag == True: 66 | p += self.pdf[k] 67 | 68 | return p 69 | 70 | def expectation(self, C): 71 | return self.vals[0] * self.joint_p(C, self.vals[0] * np.ones(len(C))) + self.vals[1] * self.joint_p(C, self.vals[1] * np.ones(len(C))) 72 | 73 | def _get_means(self): 74 | self.means = np.zeros(self.m) 75 | for k in range(self.m): 76 | self.means[k] = self.expectation([[k]]) 77 | 78 | 79 | def _get_balance(self): 80 | self.balance = self.joint_p([[self.m]], [1]) 81 | 82 | # def _get_covariance(self): 83 | 84 | def _get_accs(self): 85 | """ 86 | self.accs[k, i, j] = Pr(lf_k = j | y = i) (i, j scaled to -1, 1 if needed) 87 | """ 88 | self.accs = np.zeros((self.m, 2, 2)) 89 | for k in range(self.m): 90 | self.accs[k, 1, 1] = self.joint_p([[k], [self.m]], [self.vals[1], self.vals[1]]) / self.balance 91 | self.accs[k, 0, 0] = self.joint_p([[k], [self.m]], [self.vals[0], self.vals[0]]) / (1 - self.balance) 92 | self.accs[k, 1, 0] = 1 - self.accs[k, 1, 1] 93 | self.accs[k, 0, 1] = 1 - self.accs[k, 0, 0] 94 | 95 | 96 | def sample(self): 97 | r = np.random.random_sample() 98 | smaller = np.where(self.cdf < r)[0] 99 | if len(smaller) == 0: 100 | i = 0 101 | else: 102 | i = smaller.max() + 1 103 | 104 | return self.support[i] 105 | 106 | def make_data(self, n, has_label = True): 107 | L = np.zeros((n, self.m)) 108 | gold = np.zeros(n) 109 | for i in range(n): 110 | l = self.sample() 111 | L[i, :] = l[:self.m] 112 | 113 | if has_label: 114 | gold[i] = l[self.m] 115 | 116 | return L.astype(int), gold.astype(int) 117 | 118 | 119 | 120 | def est_accs(m, vote, gold): 121 | # compute pr(lf | y) accuracies. 
Each prompt has 4 values (2x2) 122 | # we need to do this on the train/dev set 123 | classes = [0, 1] 124 | gold_idxs = [np.where(gold == -1)[0], np.where(gold == 1)[0]] 125 | 126 | accs = np.zeros((m, 2, 2)) # [i, j, k] = Pr(prompt_i = j| y = k) 127 | for p in range(m): 128 | for i in classes: 129 | for j in classes: 130 | accs[p, i, j] = len(np.where(vote[gold_idxs[i], p] == 2*j-1)[0]) / len(gold_idxs[i]) 131 | 132 | return accs 133 | 134 | def est_balance(gold, n): 135 | return len(np.where(gold == 1)[0]) / n 136 | 137 | # Pr(lf votes, y) 138 | def get_cond_probs(m, votes, y, accs, balance): 139 | pr_y = balance if y == 1 else 1 - balance 140 | prod = pr_y 141 | for i in range(m): 142 | prod *= accs[i, y, int(0.5*(votes[i] + 1))] # this assumes everything is independent 143 | return prod 144 | 145 | # Pr(y = 1 | lf votes) 146 | def get_probs(m, votes, accs, balance): 147 | pos = get_cond_probs(m, votes, 1, accs, balance) 148 | neg = get_cond_probs(m, votes, 0, accs, balance) 149 | 150 | if pos == 0: 151 | return 0 152 | else: 153 | return pos / (pos + neg) 154 | 155 | 156 | def pick_best_prompt(m, vote, gold, n): 157 | # overall accuracies Pr(lf_p = y) on test (we don't know these) 158 | overall_train_acc = np.zeros(m) 159 | for i in range(m): 160 | overall_train_acc[i] = len(np.where((vote[:, i] == gold) == True)[0])/n 161 | 162 | return overall_train_acc.argmax() 163 | 164 | 165 | def main(): 166 | 167 | # number of weak labels 168 | m = 5 169 | 170 | # total number of vertices 171 | v = m + 1 172 | 173 | # randomly parametrize exponential family to determine accuracies and correlations 174 | #theta = np.random.rand() 175 | #theta_cliques = (np.random.randint(0, 2, 5)*2 - 1)*theta 176 | #theta = np.random.rand() 177 | #theta_cliques = [1, 1, 1, 1, 1, 1, 1] 178 | thetas = np.random.rand(30) 179 | 180 | # all conditionally independent 181 | potentials = [[5], [0], [1], [4], [0, 5], [1, 5], [2, 5], [3, 5], [4, 5]] 182 | 183 | pgm = Ising(m, potentials, thetas) 184 | 185 | n_train = 10000 186 | vote_train, gold_train = pgm.make_data(n_train) 187 | 188 | n_test = 1000 189 | vote_test, gold_test = pgm.make_data(n_test) 190 | 191 | accs = est_accs(m, vote_train, gold_train) 192 | balance = est_balance(gold_train, n_train) 193 | 194 | nb_output = np.zeros(n_test) # naive bayes 195 | mv_output = np.zeros(n_test) 196 | 197 | nb_err = 0 198 | mv_err = 0 199 | 200 | for i in range(n_test): 201 | nb_output[i] = 2*np.round(get_probs(m, vote_test[i], accs, balance))-1 202 | if nb_output[i] != gold_test[i]: 203 | nb_err += 1 204 | 205 | 206 | # note: play around with MV tie breaking strategy 207 | if len(np.where(vote_test[i] == 1)[0]) >= m / 2: 208 | mv_output[i] = 1 209 | elif len(np.where(vote_test[i] == 1)[0]) < m / 2: 210 | mv_output[i] = -1 211 | else: 212 | mv_output[i] = 2*np.random.randint(0, 2)-1 213 | 214 | if mv_output[i] != gold_test[i]: 215 | mv_err += 1 216 | 217 | nb_acc = 1 - (nb_err / n_test) 218 | mv_acc = 1 - (mv_err / n_test) 219 | #fs_acc = 1 - (fs_err / n_test) 220 | 221 | best_prompt = pick_best_prompt(m, vote_train, gold_train, n_train) 222 | 223 | best_prompt_acc = len(np.where((vote_test[:, best_prompt] == gold_test) == True)[0]) / n_test 224 | 225 | print(f"Naive bayes: {nb_acc}") 226 | print(f"Best prompt: {best_prompt_acc}") 227 | print(f"Majority vote: {mv_acc}") 228 | 229 | 230 | if __name__ == "__main__": 231 | main() -------------------------------------------------------------------------------- /evaporate/weak_supervision/run_ws.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import json 4 | import sys 5 | import pickle 6 | import random 7 | import cvxpy as cp 8 | import scipy as sp 9 | from tqdm import tqdm 10 | from evaporate.weak_supervision.methods import Aggregator 11 | from metal.label_model import LabelModel 12 | from collections import defaultdict, Counter 13 | 14 | from evaporate.evaluate_synthetic import clean_comparison 15 | 16 | 17 | def get_data( 18 | all_votes, 19 | gold_extractions_file, 20 | attribute='', 21 | has_abstains=1.0, 22 | num_elts = 5, 23 | extraction_fraction_thresh=0.9, 24 | ): 25 | """ 26 | Load in dataset from task_name depending on where files are saved. 27 | 28 | - num_elts = number of ``choices'' to use in the multiple-choice setup. 29 | 30 | """ 31 | label_name_to_ints = [] 32 | has_abstains = has_abstains >= extraction_fraction_thresh 33 | 34 | try: 35 | with open(gold_extractions_file) as f: 36 | gold_extractions = json.load(f) 37 | except: 38 | with open(gold_extractions_file, "rb") as f: 39 | gold_extractions = pickle.load(f) 40 | 41 | test_votes = [] 42 | test_golds = [] 43 | total_abstains = [] 44 | average_unique_votes = [] 45 | missing_files = [] 46 | random.seed(0) 47 | for file, extractions in tqdm(all_votes.items()): 48 | if file not in gold_extractions: 49 | missing_files.append(file) 50 | continue 51 | 52 | extractions = [clean_comparison(e) for e in extractions] 53 | if has_abstains: 54 | extractions = [e if e else 'abstain' for e in extractions] 55 | 56 | unique_votes = Counter(extractions).most_common(num_elts) 57 | unique_votes = [i for i, _ in unique_votes if i != 'abstain'] 58 | average_unique_votes.append(len(unique_votes)) 59 | if len(unique_votes) < num_elts: 60 | missing_elts = num_elts - len(unique_votes) 61 | for elt_num in range(missing_elts): 62 | unique_votes.append(f"dummy{elt_num}") 63 | 64 | random.shuffle(unique_votes) 65 | label_name_to_int = {elt: j for j, elt in enumerate(unique_votes)} 66 | label_name_to_ints.append(label_name_to_int) 67 | 68 | test_votes.append(np.array( 69 | [label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in extractions] 70 | )) 71 | 72 | num_abstains = len([a for a in extractions if a not in label_name_to_int]) 73 | total_abstains.append(num_abstains) 74 | 75 | # golds are just for class balance purposes 76 | if attribute in gold_extractions[file]: 77 | gold = gold_extractions[file][attribute] 78 | elif clean_comparison(attribute) in gold_extractions[file]: 79 | gold = gold_extractions[file][clean_comparison(attribute)] 80 | else: 81 | gold = "" 82 | gold = clean_comparison(gold) 83 | if gold in label_name_to_int: 84 | test_golds.append(label_name_to_int[gold]) 85 | else: 86 | gold = random.sample(range(len(label_name_to_int)), 1) 87 | test_golds.append(gold[0]) 88 | 89 | test_votes = np.array(test_votes) 90 | test_gold = np.array(test_golds) 91 | 92 | test_votes = test_votes.astype(int) 93 | test_gold = test_gold.astype(int) 94 | 95 | print(f"Average abstains across documents: {np.mean(total_abstains)}") 96 | print(f"Average unique votes per document: {np.mean(average_unique_votes)}") 97 | 98 | return test_votes, test_gold, label_name_to_ints, missing_files 99 | 100 | 101 | def get_top_deps_from_inverse_sig(J, k): 102 | m = J.shape[0] 103 | deps = [] 104 | sorted_idxs = np.argsort(np.abs(J), axis=None) 105 | n = m*m 106 | idxs = sorted_idxs[-k:] 107 | for idx in idxs: 108 | i = int(np.floor(idx / m)) 109 | j = idx % m 110 | 
if (j, i) in deps: 111 | continue 112 | deps.append((i, j)) 113 | return deps 114 | 115 | 116 | def learn_structure(L): 117 | m = L.shape[1] 118 | n = float(np.shape(L)[0]) 119 | sigma_O = (np.dot(L.T,L))/(n-1) - np.outer(np.mean(L,axis=0), np.mean(L,axis=0)) 120 | 121 | #bad code 122 | O = 1/2*(sigma_O+sigma_O.T) 123 | O_root = np.real(sp.linalg.sqrtm(O)) 124 | 125 | # low-rank matrix 126 | L_cvx = cp.Variable([m,m], PSD=True) 127 | 128 | # sparse matrix 129 | S = cp.Variable([m,m], PSD=True) 130 | 131 | # S-L matrix 132 | R = cp.Variable([m,m], PSD=True) 133 | 134 | #reg params 135 | lam = 1/np.sqrt(m) 136 | gamma = 1e-8 137 | 138 | objective = cp.Minimize(0.5*(cp.norm(R @ O_root, 'fro')**2) - cp.trace(R) + lam*(gamma*cp.pnorm(S,1) + cp.norm(L_cvx, "nuc"))) 139 | constraints = [R == S - L_cvx, L_cvx>>0] 140 | 141 | prob = cp.Problem(objective, constraints) 142 | result = prob.solve(verbose=False, solver=cp.SCS) 143 | opt_error = prob.value 144 | 145 | #extract dependencies 146 | J_hat = S.value 147 | 148 | if J_hat is None: 149 | raise ValueError("CVXPY failed to solve the structured learning problem, use result without dependencies.") 150 | 151 | for i in range(m): 152 | J_hat[i, i] = 0 153 | return J_hat 154 | 155 | 156 | def learn_structure_multiclass(L, k): 157 | m = L.shape[1] 158 | J_hats = np.zeros((k, m, m)) 159 | for c in range(k): 160 | 161 | all_votes_c = np.where(L == c, 1, 0) 162 | J_hats[c] = learn_structure(all_votes_c) 163 | 164 | return J_hats 165 | 166 | 167 | def get_min_off_diagonal(J_hat): 168 | J_hat_copy = J_hat.copy() 169 | for i in range(len(J_hat_copy)): 170 | J_hat_copy[i, i] = np.inf 171 | return np.abs(J_hat_copy).min() 172 | 173 | 174 | def run_ws( 175 | all_votes, 176 | gold_extractions_file, 177 | symmetric=True, 178 | attribute='', 179 | has_abstains=1.0, 180 | extraction_fraction_thresh=0.9, 181 | ): 182 | test_votes, test_gold, label_name_to_ints, missing_files = get_data( 183 | all_votes, 184 | gold_extractions_file, 185 | attribute=attribute, 186 | has_abstains=has_abstains, 187 | extraction_fraction_thresh=extraction_fraction_thresh, 188 | ) 189 | 190 | classes = np.sort(np.unique(test_gold)) 191 | vote_classes = np.sort(np.unique(test_votes)) 192 | n_test, m = test_votes.shape 193 | k = len(classes) 194 | abstains = len(vote_classes) == len(classes) + 1 195 | print(f"Abstains: {abstains}") 196 | 197 | m = test_votes.shape[1] 198 | all_votes = test_votes 199 | 200 | label_model = LabelModel(k=k, seed=123) 201 | 202 | # scale to 0, 1, 2 (0 is abstain) 203 | test_votes_scaled = (test_votes + np.ones((n_test, m))).astype(int) 204 | test_gold_scaled = (test_gold + np.ones(n_test)).astype(int) 205 | all_votes_scaled = test_votes_scaled 206 | 207 | label_model.train_model( 208 | all_votes_scaled, 209 | Y_dev=test_gold_scaled, 210 | abstains=abstains, 211 | symmetric=symmetric, 212 | n_epochs=10000, 213 | log_train_every=1000, 214 | lr=0.00001 215 | ) 216 | 217 | print('Trained Label Model Metrics (No deps):') 218 | scores, preds = label_model.score( 219 | (test_votes_scaled, test_gold_scaled), 220 | metric=['accuracy','precision', 'recall', 'f1'] 221 | ) 222 | print(scores) 223 | all_votes_no_abstains = np.where(all_votes == -1, 0, all_votes) 224 | 225 | used_deps = False 226 | try: 227 | if len(classes) == 2: 228 | J_hat = learn_structure(all_votes_no_abstains) 229 | else: 230 | J_hats = learn_structure_multiclass(all_votes_no_abstains, len(classes)) 231 | J_hat = J_hats.mean(axis=0) 232 | 233 | # if values in J are all too large, then everything is 
connected / structure learning isn't learning the right thing. Don't model deps then 234 | min_entry = get_min_off_diagonal(J_hat) 235 | if min_entry < 1: 236 | deps = get_top_deps_from_inverse_sig(J_hat, 1) 237 | print("Recovered dependencies: ", deps) 238 | 239 | label_model.train_model( 240 | all_votes_scaled, 241 | Y_dev=test_gold_scaled, 242 | abstains=abstains, 243 | symmetric=symmetric, 244 | n_epochs=80000, 245 | log_train_every=1000, 246 | lr=0.000001, 247 | deps=deps 248 | ) 249 | print('Trained Label Model Metrics (with deps):') 250 | scores, preds = label_model.score( 251 | (test_votes_scaled, test_gold_scaled), 252 | metric=['accuracy', 'precision', 'recall', 'f1'] 253 | ) 254 | print(scores) 255 | used_deps = True 256 | except: 257 | print(f"Not modeling dependencies.") 258 | 259 | # convert the preds back 260 | mapped_preds = [] 261 | for label_name_to_int, pred in tqdm(zip(label_name_to_ints, preds)): 262 | int_to_label_name = {v:k for k, v in label_name_to_int.items()} 263 | try: 264 | pred = int_to_label_name[pred-1] 265 | except: 266 | pred = '' 267 | mapped_preds.append(pred) 268 | return mapped_preds, used_deps, missing_files 269 | 270 | 271 | if __name__ == "__main__": 272 | run_ws() 273 | -------------------------------------------------------------------------------- /evaporate/prompts.py: -------------------------------------------------------------------------------- 1 | ############################ SCHEMA ID PROMPTS ############################ 2 | SCHEMA_ID_PROMPTS = [ 3 | f"""Sample text: 4 |
• Monarch
Charles III 5 |
• Governor General
Mary Simon 6 | Provinces and Territories 7 | 20 | 21 | Question: List all relevant attributes about 'Canada' that are exactly mentioned in this sample text if any. 22 | Answer: 23 | - Monarch: Charles III 24 | - Governor General: Mary Simon 25 | - Provinces and Territories: Saskatchewan, Manitoba, Ontario, Quebec, New Brunswick, Prince Edward Island, Nova Scotia, Newfoundland and Labrador, Yukon, Nunavut, Northwest Territories 26 | 27 | ---- 28 | 29 | Sample text: 30 | Patient birth date: 1990-01-01 31 | Prescribed medication: aspirin, ibuprofen, acetaminophen 32 | Prescribed dosage: 1 tablet, 2 tablets, 3 tablets 33 | Doctor's name: Dr. Burns 34 | Date of discharge: 2020-01-01 35 | Hospital address: 123 Main Street, New York, NY 10001 36 | 37 | Question: List all relevant attributes about 'medications' that are exactly mentioned in this sample text if any. 38 | Answer: 39 | - Prescribed medication: aspirin, ibuprofen, acetaminophen 40 | - Prescribed dosage: 1 tablet, 2 tablets, 3 tablets 41 | 42 | ---- 43 | 44 | Sample text: 45 | {{chunk:}} 46 | 47 | Question: List all relevant attributes about '{{topic:}}' that are exactly mentioned in this sample text if any. 48 | Answer:""" 49 | ] 50 | 51 | 52 | ############################ PROMPTS FOR EXTRACTING A SPECIFIC FIELD BY DIRECTLY GIVING THE MODEL THE CONTEXT ############################ 53 | METADATA_EXTRACTION_WITH_LM = [ 54 | f"""Here is a file sample: 55 | 56 | Location 57 | Cupertino, CaliforniaSince 1987 58 | 59 | Question: Return the full "location" span of this sample if it exists, otherwise output []. 60 | Answer: ['Cupertino, California Since 1987'] 61 | 62 | ---- 63 | 64 | Here is a file sample: 65 | 66 | {{chunk:}} 67 | 68 | Question: Return the full "{{attribute:}}" span of this sample if it exists, otherwise output []. 69 | Answer:""", 70 | ] 71 | 72 | 73 | METADATA_EXTRACTION_WITH_LM_ZERO_SHOT = [ 74 | f"""Sample text: 75 | 76 | {{chunk:}} 77 | 78 | Question: What is the "{{attribute:}}" value in the text? 79 | Answer:""" 80 | ] 81 | 82 | EXTRA_PROMPT = [ 83 | f"""Here is a file sample: 84 | 85 | 86 | 87 | Question: Return the full "price" from this sample if it exists, otherwise output []. 88 | Answer: ['$550'] 89 | 90 | ---- 91 | 92 | Here is a file sample: 93 | 94 | {{chunk:}} 95 | 96 | Question: Return the full "{{attribute:}}" from this sample if it exists, otherwise output []. 97 | Answer:""", 98 | ] 99 | 100 | METADATA_EXTRACTION_WITH_LM_CONTEXT = [ 101 | f"""Here is a file sample: 102 | 103 | A. 510(k) Number: 104 | k143467 105 | 106 | Question: Return the full "510(k) Number" from this sample if it exists and the context around it, otherwise output []. 107 | Answer: [510(k) Number: k143467] 108 | 109 | ---- 110 | 111 | Here is a file sample: 112 | 113 | The iphone price increases a lot this there. Each iphone's price is as high as 1000$. 114 | 115 | Question: Return the full "price" from this sample if it exists and the context around it, otherwise output []. 116 | Answer: [Each iphone's price is as high as 1000$] 117 | 118 | ---- 119 | 120 | Here is a file sample: 121 | 122 | {{chunk:}} 123 | 124 | Question: Return the full "{{attribute:}}" from this sample if it exists and the context around it, otherwise output []. 125 | Answer:""", 126 | ] 127 | 128 | IS_VALID_ATTRIBUTE = [ 129 | f"""Question: Could "2014" be a "year" value in a "students" database? 130 | Answer: Yes 131 | 132 | ---- 133 | 134 | Question: Could "cupcake" be a "occupation" value in a "employee" database? 
135 | Answer: No 136 | 137 | ---- 138 | 139 | Question: Could "''" be a "animal" value in a "zoo" database? 140 | Answer: No 141 | 142 | ---- 143 | 144 | Question: Could "police officer" be a "occupation" value in a "employee" database? 145 | Answer: Yes 146 | 147 | ---- 148 | 149 | Question: Could "{{value:}}" be a "{{attr_str:}}" value in a {{topic:}} database? 150 | Answer:""" 151 | ] 152 | 153 | 154 | PICK_VALUE = [ 155 | f"""Examples: 156 | - 32 157 | - 2014 158 | - 99.4 159 | - 2012 160 | 161 | Question: Which example is a "year"? 162 | Answer: 2012, 2014 163 | 164 | ---- 165 | 166 | Examples: 167 | - police officer 168 | - occupation 169 | 170 | Question: Which example is a "occupation"? 171 | Answer: police officer 172 | 173 | ---- 174 | 175 | Examples: 176 | {{pred_str:}} 177 | 178 | Question: Which example is a "{{attribute:}}"? 179 | Answer:""" 180 | ] 181 | 182 | 183 | PICK_VALUE_CONTEXT = [ 184 | f"""Here are file samples: 185 | 186 | -The purpose for submission is to obtain substantial equivalence determination for the illumigene HSV 1&2 DNA Amplification Assay. 187 | -The purpose for submission of this document is not specified in the provided sample. 188 | -The purpose for submission of this file is not specified. 189 | 190 | Question: Extract "the purpose for submission" from the right sample , otherwise output []. 191 | Answer: to obtain substantial equivalence determination for the illumigene HSV 1&2 DNA Amplification Assay 192 | 193 | ---- 194 | 195 | Here are file samples: 196 | 197 | {{pred_str:}} 198 | 199 | Question: Return the full "{{attribute:}}" from this sample if it exists, otherwise output []. 200 | Answer:""", 201 | ] 202 | 203 | 204 | 205 | ############################## PROMPTS TO GENERATE FUNCTIONS THAT PARSE FOR A SPECIFIC FIELD ############################## 206 | METADATA_GENERATION_FOR_FIELDS = [ 207 | # base prompt 208 | f"""Here is a sample of text: 209 | 210 | {{chunk:}} 211 | 212 | 213 | Question: Write a python function to extract the entire "{{attribute:}}" field from text, but not any other metadata. Return the result as a list. 214 | 215 | 216 | import re 217 | 218 | def get_{{function_field:}}_field(text: str): 219 | \""" 220 | Function to extract the "{{attribute:}} field". 221 | \""" 222 | """, 223 | 224 | # prompt with flexible library imports 225 | f"""Here is a file sample: 226 | 227 | DESCRIPTION: This file answers the question, "How do I sort a dictionary by value?" 228 | DATES MODIFIED: The file was modified on the following dates: 229 | 2009-03-05T00:49:05 230 | 2019-04-07T00:22:14 231 | 2011-11-20T04:21:49 232 | USERS: The users who modified the file are: 233 | Jeff Jacobs 234 | Richard Smith 235 | Julia D'Angelo 236 | Rebecca Matthews 237 | FILE TYPE: This is a text file. 238 | 239 | Question: Write a python function called "get_dates_modified_field" to extract the "DATES MODIFIED" field from the text. Include any imports. 240 | 241 | import re 242 | 243 | def get_dates_modified_field(text: str): 244 | \""" 245 | Function to extract the dates modified. 246 | \""" 247 | parts= text.split("USERS")[0].split("DATES MODIFIED")[-1] 248 | pattern = r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}' 249 | return re.findall(pattern, text) 250 | 251 | ---- 252 | 253 | Here is a file sample: 254 | 255 | U.S. GDP Rose 2.9% in the Fourth Quarter After a Year of High Inflation - WSJ 256 | 257 | 258 | 259 | 260 | 261 | Question: Write a python function called "get_date_published_field" to extract the "datePublished" field from the text. Include any imports. 
262 | 263 | from bs4 import BeautifulSoup 264 | 265 | def get_date_published_field(text: str): 266 | \""" 267 | Function to extract the date published. 268 | \""" 269 | soup = BeautifulSoup(text, parser="html.parser") 270 | date_published_field = soup.find('meta', itemprop="datePublished") 271 | date_published_field = date_published_field['content'] 272 | return date_published_field 273 | 274 | ---- 275 | 276 | Here is a sample of text: 277 | 278 | {{chunk:}} 279 | 280 | Question: Write a python function called "get_{{function_field:}}_field" to extract the "{{attribute:}}" field from the text. Include any imports.""" 281 | ] 282 | 283 | 284 | class Step: 285 | def __init__(self, prompt) -> None: 286 | self.prompt = prompt 287 | 288 | def execute(self): 289 | pass 290 | 291 | 292 | -------------------------------------------------------------------------------- /evaporate/schema_identification.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import statistics 4 | import random 5 | from tqdm import tqdm 6 | from collections import Counter, defaultdict 7 | from typing import List, Dict, Tuple, Set 8 | 9 | from evaporate.prompts import Step, SCHEMA_ID_PROMPTS 10 | from evaporate.utils import apply_prompt 11 | from evaporate.profiler_utils import clean_metadata 12 | 13 | 14 | def directly_extract_from_chunks_w_value( 15 | file2chunks, 16 | sample_files, 17 | manifest_session, 18 | overwrite_cache=False, 19 | topic=None, 20 | use_dynamic_backoff=True, 21 | ): 22 | total_tokens_prompted = 0 23 | field2value = defaultdict(list) 24 | field2count = Counter() 25 | file2results = defaultdict() 26 | num_chunks_per_file = [len(file2chunks[file]) for file in file2chunks] 27 | avg_num_chunks_per_file = statistics.mean(num_chunks_per_file) 28 | stdev_num_chunks_per_file = statistics.stdev(num_chunks_per_file) 29 | 30 | for i, file in enumerate(sample_files): 31 | chunks = file2chunks[file] 32 | print(f"Chunks in sample file {file}: {len(chunks)}") 33 | 34 | for i, file in tqdm( 35 | enumerate(sample_files), 36 | total=len(sample_files), 37 | desc="Directly extracting metadata from chunks", 38 | ): 39 | chunks = file2chunks[file] 40 | extractionset = set() 41 | file_results = {} 42 | for chunk_num, chunk in enumerate(chunks): 43 | if (chunk_num > avg_num_chunks_per_file + stdev_num_chunks_per_file) and use_dynamic_backoff: 44 | break 45 | prompt_template = SCHEMA_ID_PROMPTS[0] 46 | prompt = prompt_template.format(chunk=chunk, topic=topic) 47 | try: 48 | result, num_toks = apply_prompt( 49 | Step(prompt), 50 | max_toks=500, 51 | manifest=manifest_session, 52 | overwrite_cache=overwrite_cache 53 | ) 54 | except: 55 | print("Failed to apply prompt to chunk.") 56 | continue 57 | total_tokens_prompted += num_toks 58 | result = result.split("---")[0].strip("\n") 59 | results = result.split("\n") 60 | results = [r.strip("-").strip() for r in results] 61 | results = [r[2:].strip() if len(r) > 2 and r[1] == "." 
else r for r in results ] 62 | for result in results: 63 | try: 64 | field = result.split(": ")[0].strip(":") 65 | value = ": ".join(result.split(": ")[1:]) 66 | except: 67 | print(f"Skipped: {result}") 68 | continue 69 | field_versions = [ 70 | field, 71 | field.replace(" ", ""), 72 | field.replace("-", ""), 73 | field.replace("_", ""), 74 | ] 75 | if not any([f.lower() in chunk.lower() for f in field_versions]) and use_dynamic_backoff: 76 | continue 77 | if not value and use_dynamic_backoff: 78 | continue 79 | field = field.lower().strip("-").strip("_").strip(" ").strip(":") 80 | if field in extractionset and use_dynamic_backoff: 81 | continue 82 | field2value[field].append(value) 83 | extractionset.add(field) 84 | field2count[field] += 1 85 | file_results[field] = value 86 | file2results[file] = file_results 87 | return field2value, field2count, total_tokens_prompted 88 | 89 | 90 | def get_metadata_string_w_value(field2value, exclude=[], key=0): 91 | field2num_extractions = Counter() 92 | for field, values in field2value.items(): 93 | field2num_extractions[field] += len(values) 94 | 95 | reranked_metadata = {} 96 | try: 97 | max_count = field2num_extractions.most_common(1)[0][1] 98 | except: 99 | return '' 100 | fields = [] 101 | sort_field2num_extractions = sorted( 102 | field2num_extractions.most_common(), 103 | key=lambda x: (x[1], x[0]), 104 | reverse=True 105 | ) 106 | for item in sort_field2num_extractions: 107 | field, count = item[0], item[1] 108 | if field.lower() in exclude: 109 | continue 110 | if count == 1 and max_count > 1: 111 | continue 112 | idx = min(key, len(field2value[field]) - 1) 113 | values = [field2value[field][idx]] 114 | if idx < len(field2value[field]) - 1: 115 | values.append(field2value[field][idx + 1]) 116 | reranked_metadata[field] = values 117 | if len(reranked_metadata) > 200: 118 | break 119 | fields.append(field) 120 | 121 | random.seed(key) 122 | keys=reranked_metadata.keys() 123 | random.shuffle(list(keys)) 124 | reordered_dict = {} 125 | for key in keys: 126 | reordered_dict[key] = reranked_metadata[key] 127 | reranked_metadata_str = str(reordered_dict) 128 | return reranked_metadata_str 129 | 130 | 131 | def rerank( 132 | field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count, 133 | most_in_context_example, topic, manifest_session, overwrite_cache=False 134 | ): 135 | total_tokens_prompted = 0 136 | votes_round1 = Counter() 137 | for i in range(3): 138 | reranked_metadata_str = get_metadata_string_w_value(field2value, exclude=exclude, key=i) 139 | if not reranked_metadata_str: 140 | continue 141 | 142 | prompt = \ 143 | f"""{most_in_context_example}Attributes: 144 | {reranked_metadata_str} 145 | 146 | List the most useful keys to include in a SQL database about "{topic}", if any. 
147 | Answer:""" 148 | try: 149 | result, num_toks = apply_prompt( 150 | Step(prompt), 151 | max_toks=500, 152 | manifest=manifest_session, 153 | overwrite_cache=overwrite_cache, 154 | ) 155 | except: 156 | print("Failed to apply prompt") 157 | continue 158 | total_tokens_prompted += num_toks 159 | result = result.split("---")[0].strip("\n") 160 | results = result.split("\n") 161 | result = results[0].replace("[", "").replace("]", "").replace("'", "").replace('"', '') 162 | result = result.split(", ") 163 | result = [r.lower() for r in result] 164 | 165 | indices = [idx for idx, r in enumerate(result) if not r] 166 | if result and indices: 167 | result = result[:indices[0]] 168 | 169 | # Deduplicate but preserve order 170 | result = list(dict.fromkeys(result)) 171 | for r in result: 172 | r = r.strip("_").strip("-") 173 | r = r.strip("'").strip('"').strip() 174 | if not r or r in exclude or r not in base_extraction_count: 175 | continue 176 | votes_round1[r] += 2 177 | 178 | fields = sorted(list(votes_round1.keys())) 179 | for r in fields: 180 | r = r.strip("_").strip("-") 181 | r = r.strip("'").strip('"').strip() 182 | if not r or r in exclude or r not in base_extraction_count: 183 | continue 184 | if votes_round1[r] > 1: 185 | cleaned_counter[r] = votes_round1[r] * base_extraction_count[r] 186 | order_of_addition.append(r) 187 | else: 188 | cleaned_counter[r] = base_extraction_count[r] 189 | order_of_addition.append(r) 190 | exclude.append(r) 191 | 192 | return cleaned_counter, order_of_addition, exclude, total_tokens_prompted 193 | 194 | 195 | def rerank_metadata( 196 | base_extraction_count, field2value, topic, manifest_session, overwrite_cache 197 | ): 198 | 199 | most_in_context_example = \ 200 | """Attributes: 201 | {'name': 'Jessica', 'student major': 'Computer Science', 'liscense': 'accredited', 'college name': 'University of Michigan', ''GPA': '3.9', 'student email': 'jess@umich.edu', 'rating': '42', 'title': 'details'} 202 | 203 | List the most useful keys to include in a SQL database for "students", if any. 
204 | Answer: ['name', 'student major', 'college name', 'GPA', 'student email'] 205 | 206 | ---- 207 | 208 | """ 209 | 210 | total_tokens_prompted = 0 211 | cleaned_counter = Counter() 212 | exclude = [] 213 | order_of_addition = [] 214 | 215 | cleaned_counter, order_of_addition, exclude, total_tokens_prompted = rerank( 216 | field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count, 217 | most_in_context_example, topic, manifest_session, overwrite_cache=overwrite_cache 218 | ) 219 | 220 | cleaned_counter, order_of_addition, exclude, total_tokens_prompted = rerank( 221 | field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count, 222 | most_in_context_example, topic, manifest_session, overwrite_cache=overwrite_cache 223 | ) 224 | 225 | fields = sorted(list(base_extraction_count.keys())) 226 | for field in fields: 227 | if field not in cleaned_counter: 228 | cleaned_counter[field] = base_extraction_count[field] / 2 229 | order_of_addition.append(field) 230 | return cleaned_counter, total_tokens_prompted, order_of_addition 231 | 232 | 233 | #################### SAVE GENERATIVE INDEX OF FILE BASED METADATA ######################### 234 | def identify_schema(run_string, args, file2chunks: Dict, file2contents: Dict, sample_files: List, manifest_sessions: Dict, group_name: str, profiler_args): 235 | # get sample and eval files, convert the sample scripts to chunks 236 | random.seed(0) 237 | total_tokens_prompted = 0 238 | 239 | field2value, extract_w_value, num_toks = directly_extract_from_chunks_w_value( 240 | file2chunks, 241 | sample_files, 242 | manifest_sessions[profiler_args.GOLD_KEY], 243 | overwrite_cache=profiler_args.overwrite_cache, 244 | topic=args.topic, 245 | use_dynamic_backoff=profiler_args.use_dynamic_backoff, 246 | ) 247 | total_tokens_prompted += num_toks 248 | 249 | base_extraction_count, num_toks, order_of_addition = rerank_metadata( 250 | extract_w_value, 251 | field2value, 252 | args.topic, 253 | manifest_sessions[profiler_args.GOLD_KEY], 254 | profiler_args.overwrite_cache, 255 | ) 256 | total_tokens_prompted += num_toks 257 | 258 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json", "w") as f: 259 | json.dump(base_extraction_count, f) 260 | 261 | with open(f"{args.generative_index_path}/{run_string}_order_of_addition.json", "w") as f: 262 | json.dump(order_of_addition, f) 263 | 264 | return total_tokens_prompted 265 | -------------------------------------------------------------------------------- /evaporate/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import Counter, defaultdict 4 | 5 | from manifest import Manifest 6 | from evaporate.configs import get_args 7 | from evaporate.prompts import Step 8 | from openai import OpenAI 9 | 10 | cur_idx = 0 11 | TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY") 12 | #If using together AI, you will need to set the TOGETHER_API_KEY to your API key 13 | 14 | 15 | def together_call(prompt, model, streaming = False, max_tokens = 1024): 16 | client = OpenAI( 17 | api_key=TOGETHER_API_KEY, 18 | base_url='https://api.together.xyz', 19 | 20 | ) 21 | messages = [{ 22 | "role": "system", 23 | "content": "You are an AI assistant", 24 | }, { 25 | "role": "user", 26 | "content": prompt, 27 | }] 28 | chat_completion = client.chat.completions.create(messages=messages, 29 | model=model, 30 | max_tokens=max_tokens, 31 | #response_format={ "type": "json_object" }, 32 | 
stream=streaming) 33 | response = chat_completion.choices[0].message.content 34 | return response 35 | 36 | def apply_prompt(step : Step, max_toks = 50, do_print=False, manifest=None, overwrite_cache=False): 37 | global cur_idx 38 | manifest_lst = manifest.copy() 39 | if len(manifest) == 1: 40 | manifest = manifest_lst[0] 41 | else: 42 | manifest = manifest_lst[cur_idx] 43 | 44 | # sometimes we want to rotate keys 45 | cur_idx = cur_idx + 1 46 | if cur_idx >= len(manifest_lst)-1: 47 | cur_idx = 0 48 | 49 | prompt = step.prompt 50 | response, num_tokens = get_response( 51 | prompt, 52 | manifest, 53 | max_toks = max_toks, 54 | overwrite=overwrite_cache, 55 | stop_token="---" 56 | ) 57 | step.response = response 58 | if do_print: 59 | print(response) 60 | return response, num_tokens 61 | 62 | 63 | def get_file_attribute(attribute): 64 | attribute = attribute.lower() 65 | attribute = attribute.replace("/", "_").replace(")", "").replace("-", "_") 66 | attribute = attribute.replace("(", "").replace(" ", "_") 67 | if len(attribute) > 30: 68 | attribute = attribute[:30] 69 | return attribute 70 | 71 | 72 | def get_all_files(data_dir): 73 | files = [] 74 | for file in os.listdir(data_dir): 75 | if os.path.isfile(os.path.join(data_dir, file)): 76 | files.append(os.path.join(data_dir, file)) 77 | else: 78 | files.extend(get_all_files(os.path.join(data_dir, file))) 79 | return files 80 | 81 | 82 | def get_directory_hierarchy(data_dir): 83 | if not data_dir.endswith("/") and os.path.isdir(data_dir): 84 | data_dir = data_dir + "/" 85 | directories2subdirs = defaultdict(list) 86 | for file in os.listdir(data_dir): 87 | new_dir = os.path.join(data_dir, file) 88 | if not new_dir.endswith("/") and os.path.isdir(new_dir): 89 | new_dir = new_dir + "/" 90 | if os.path.isdir(new_dir): 91 | directories2subdirs[data_dir].append(new_dir) 92 | if os.listdir(new_dir): 93 | more_subdirs = get_directory_hierarchy(new_dir) 94 | for k, v in more_subdirs.items(): 95 | directories2subdirs[k].extend(v) 96 | else: 97 | directories2subdirs[new_dir] = [] 98 | else: 99 | directories2subdirs[data_dir].append(new_dir) 100 | return directories2subdirs 101 | 102 | 103 | def get_unique_file_types(files): 104 | suffix2file = {} 105 | suffix2count = Counter() 106 | for file in files: 107 | suffix = file.split(".")[-1] 108 | if not suffix: 109 | suffix = "txt" 110 | suffix2count[suffix] += 1 111 | if suffix not in suffix2file: 112 | suffix2file[suffix] = file 113 | return suffix2file, suffix2count 114 | 115 | 116 | def get_structure(dataset_name, profiler_args): 117 | args = get_args(profiler_args) 118 | if not os.path.exists(args.cache_dir): 119 | os.makedirs(args.cache_dir) 120 | 121 | if not os.path.exists(args.generative_index_path): 122 | os.makedirs(args.generative_index_path) 123 | 124 | if not os.path.exists(args.generative_index_path): 125 | os.makedirs(args.generative_index_path) 126 | 127 | # all files 128 | cache_path = f"{args.cache_dir}/all_files.json" 129 | if not os.path.exists(cache_path) or args.overwrite_cache: 130 | files = get_all_files(args.data_dir) 131 | with open(cache_path, "w") as f: 132 | json.dump(files, f) 133 | else: 134 | with open(cache_path) as f: 135 | files = json.load(f) 136 | 137 | # all directories 138 | cache_path = f"{args.cache_dir}/all_dirs.json" 139 | if not os.path.exists(cache_path) or args.overwrite_cache: 140 | directory_hierarchy = get_directory_hierarchy(args.data_dir) 141 | with open(cache_path, "w") as f: 142 | json.dump(directory_hierarchy, f) 143 | else: 144 | with 
open(cache_path) as f: 145 | directory_hierarchy = json.load(f) 146 | 147 | suffix2file, suffix2count = get_unique_file_types(files) 148 | file_examples = "\n".join(list(suffix2file.values())) 149 | file_types = ", ".join((suffix2file.keys())) 150 | return directory_hierarchy, files, file_examples, file_types, args 151 | 152 | 153 | def get_files_in_group(dir_path): 154 | file_group = [] 155 | for i, (root,dirs,files) in enumerate(os.walk(dir_path, topdown=True)): 156 | files = [f"{root}/{f}" for f in files] 157 | file_group.extend(files) 158 | print(f"Working with a sample size of : {len(file_group)} files.") 159 | return file_group 160 | 161 | 162 | # MANIFEST 163 | def get_manifest_sessions(MODELS, MODEL2URL=None, KEYS=[]): 164 | manifest_sessions = defaultdict(list) 165 | for model in MODELS: 166 | if any(kwd in model for kwd in ["davinci", "curie", "babbage", "ada", "cushman"]): 167 | if not KEYS: 168 | raise ValueError("You must provide a list of keys to use these models.") 169 | for key in KEYS: 170 | manifest, model_name = get_manifest_session( 171 | client_name="openai", 172 | client_engine=model, 173 | client_connection=key, 174 | ) 175 | manifest_sessions[model].append(manifest) 176 | elif any(kwd in model for kwd in ["gpt-4", "gpt-3.5"]): 177 | if not KEYS: 178 | raise ValueError("You must provide a list of keys to use these models.") 179 | for key in KEYS: 180 | manifest, model_name = get_manifest_session( 181 | client_name="openaichat", 182 | client_engine=model, 183 | client_connection=key, 184 | ) 185 | manifest_sessions[model].append(manifest) 186 | else: 187 | if(model not in MODEL2URL): 188 | manifest = {} 189 | manifest["__name"] = model 190 | print("using together AI") 191 | else: 192 | print("using huggingface") 193 | manifest, model_name = get_manifest_session( 194 | client_name="huggingface", 195 | client_engine=model, 196 | client_connection=MODEL2URL[model], 197 | ) 198 | manifest_sessions[model].append(manifest) 199 | return manifest_sessions 200 | 201 | 202 | def get_manifest_session( 203 | client_name="huggingface", 204 | client_engine=None, 205 | client_connection="http://127.0.0.1:5000", 206 | cache_connection=None, 207 | temperature=0, 208 | top_p=1.0, 209 | ): 210 | if client_name == "huggingface" and temperature == 0: 211 | params = { 212 | "temperature": 0.001, 213 | "do_sample": False, 214 | "top_p": top_p, 215 | } 216 | elif client_name in {"openai", "ai21", "openaichat"}: 217 | params = { 218 | "temperature": temperature, 219 | "top_p": top_p, 220 | "engine": client_engine, 221 | } 222 | else: 223 | raise ValueError(f"{client_name} is not a valid client name") 224 | 225 | cache_params = { 226 | "cache_name": "sqlite", 227 | "cache_connection": cache_connection, 228 | } 229 | 230 | manifest = Manifest( 231 | client_name=client_name, 232 | client_connection=client_connection, 233 | **params, 234 | **cache_params, 235 | ) 236 | 237 | params = manifest.client_pool.get_current_client().get_model_params() 238 | model_name = params["model_name"] 239 | if "engine" in params: 240 | model_name += f"_{params['engine']}" 241 | return manifest, model_name 242 | 243 | 244 | def get_response( 245 | prompt, 246 | manifest, 247 | overwrite=False, 248 | max_toks=10, 249 | stop_token=None, 250 | gold_choices=[], 251 | verbose=False, 252 | ): 253 | prompt = prompt.strip() 254 | if gold_choices: 255 | gold_choices = [" " + g.strip() for g in gold_choices] 256 | if type(manifest) == dict and manifest["__name"] != "openai": 257 | response = together_call(prompt, 
manifest["__name"]) 258 | num_tokens = 0 259 | else: 260 | response_obj = manifest.run( 261 | prompt, 262 | gold_choices=gold_choices, 263 | overwrite_cache=overwrite, 264 | return_response=True, 265 | ) 266 | response_obj = response_obj.get_json_response()["choices"][0] 267 | log_prob = response_obj["text_logprob"] 268 | response = response_obj["text"] 269 | num_tokens = response_obj['usage']['total_tokens'] 270 | else: 271 | if type(manifest) == dict and manifest["__name"] != "openai": 272 | response = together_call(prompt, manifest["__name"]) 273 | num_tokens = 0 274 | else: 275 | response_obj = manifest.run( 276 | prompt, 277 | max_tokens=max_toks, 278 | stop_token=stop_token, 279 | overwrite_cache=overwrite, 280 | return_response=True 281 | ) 282 | num_tokens = -1 283 | try: 284 | num_tokens = response_obj.get_usage_obj().usages[0].total_tokens 285 | except: 286 | num_tokens = 0 287 | print("Fail to get total tokens used") 288 | response_obj = response_obj.get_json_response() 289 | response = response_obj["choices"][0]["text"] 290 | stop_token = "---" 291 | response = response.strip().split(stop_token)[0].strip() if stop_token else response.strip() 292 | log_prob = None 293 | if verbose: 294 | print("\n***Prompt***\n", prompt) 295 | print("\n***Response***\n", response) 296 | if log_prob: 297 | return response, log_prob 298 | return response, num_tokens 299 | 300 | -------------------------------------------------------------------------------- /evaporate/evaluate_profiler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | import numpy as np 3 | from evaporate.prompts import (PICK_VALUE_CONTEXT, Step,) 4 | from evaporate.utils import apply_prompt 5 | 6 | 7 | def clean_comparison(responses, field): 8 | clean_responses = [] 9 | if type(responses) == str: 10 | responses = [responses] 11 | for response in responses: 12 | response = response.lower() 13 | field = field.lower() 14 | field_reformat = field.replace("_", "-") 15 | 16 | for char in ["'", field, field_reformat, ":", "<", ">", '"', "none"]: 17 | response = response.replace(char, " ") 18 | for char in [",", ".", "?", "!", ";", "(", ")", "[", "]", "{", "}", "-", "none", "\n", "\t", "\r"]: 19 | response = response.replace(char, " ") 20 | response = response.replace(" ", " ") 21 | response = response.split() 22 | response = [r.strip() for r in response] 23 | response = [r for r in response if r] 24 | response = ' '.join(response) 25 | clean_responses.append(response) 26 | clean_responses = ", ".join(clean_responses) 27 | return clean_responses 28 | 29 | 30 | def normalize_value_type(metadata, attribute): 31 | # make everything a list of strings since functions can return diverse types 32 | cleaned_items = [] 33 | if type(metadata) == str: 34 | metadata = [metadata] 35 | for item in metadata: 36 | if type(item) == list: 37 | item = [str(i) for i in item] 38 | item = ", ".join(item) 39 | elif type(item) == tuple: 40 | item = list(item) 41 | item = [str(i) for i in item] 42 | item = ", ".join(item) 43 | elif item is None: 44 | item = '' 45 | elif type(item) != str: 46 | item = [str(item)] 47 | item = ", ".join(item) 48 | if item: 49 | cleaned_items.append(item) 50 | return cleaned_items 51 | 52 | 53 | def pick_a_gold_label(golds, attribute="", manifest_session=None, overwrite_cache=False): 54 | """ 55 | To counteract the large model hallucinating on various chunks affecting the evaluation of good functions. 
56 | """ 57 | 58 | pred_str = "- " + "\n- ".join(golds) 59 | 60 | prompt_template = PICK_VALUE_CONTEXT[0] 61 | prompt = prompt_template.format(pred_str=pred_str, attribute=attribute) 62 | try: 63 | check, num_toks = apply_prompt( 64 | Step(prompt), 65 | max_toks=100, 66 | manifest=manifest_session, 67 | overwrite_cache=overwrite_cache 68 | ) 69 | except: 70 | return golds, 0 71 | check = check.split("\n") 72 | check = [c for c in check if c] 73 | if check: 74 | if "none" in check[0].lower(): 75 | check = golds 76 | else: 77 | check = check[0] 78 | return check, num_toks 79 | 80 | 81 | def text_f1( 82 | preds=[], 83 | golds=[], 84 | extraction_fraction=1.0, 85 | attribute=None, 86 | extraction_fraction_thresh=0.8, 87 | use_abstension=True, 88 | ): 89 | """Compute average F1 of text spans. 90 | Taken from Squad without prob threshold for no answer. 91 | """ 92 | total_f1 = 0 93 | total_recall = 0 94 | total_prec = 0 95 | f1s = [] 96 | total = 0 97 | 98 | if extraction_fraction >= extraction_fraction_thresh and use_abstension: 99 | new_preds = [] 100 | new_golds = [] 101 | for pred, gold in zip(preds, golds): 102 | if pred: 103 | new_preds.append(pred) 104 | new_golds.append(gold) 105 | preds = new_preds 106 | golds = new_golds 107 | if not preds: 108 | return 0.0, 0.0 109 | for pred, gold in zip(preds, golds): 110 | if type(pred) == str: 111 | pred_toks = pred.split() 112 | else: 113 | pred_toks = pred 114 | if type(gold) == str: 115 | gold_toks_list = [gold.split()] 116 | else: 117 | assert 0, print(gold) 118 | gold_toks_list = gold 119 | 120 | if type(gold_toks_list) == list and gold_toks_list: 121 | for gold_toks in gold_toks_list: 122 | 123 | # If both lists are lenght 1, split to account for example like: 124 | # ["a b"], ["a"] -> ["a","b"], ["a"] 125 | if len(gold_toks) == 1 and len(pred_toks) == 1: 126 | gold_toks = gold_toks[0].split() 127 | pred_toks = pred_toks[0].split() 128 | 129 | common = Counter(pred_toks) & Counter(gold_toks) 130 | num_same = sum(common.values()) 131 | if len(gold_toks) == 0 or len(pred_toks) == 0: 132 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 133 | total_f1 += int(gold_toks == pred_toks) 134 | f1s.append(int(gold_toks == pred_toks)) 135 | total_recall += int(gold_toks == pred_toks) 136 | elif num_same == 0: 137 | total_f1 += 0 138 | f1s.append(0) 139 | else: 140 | precision = 1.0 * num_same / len(pred_toks) 141 | recall = 1.0 * num_same / len(gold_toks) 142 | f1 = (2 * precision * recall) / (precision + recall) 143 | total_f1 += f1 144 | total_recall += recall 145 | total_prec += precision 146 | f1s.append(f1) 147 | 148 | total += 1 149 | if not total: 150 | return 0.0, 0.0 151 | f1_avg = total_f1 / total 152 | f1_median = np.percentile(f1s, 50) 153 | return f1_avg, f1_median 154 | 155 | 156 | def evaluate( 157 | all_extractions:list, 158 | gold_key:str, 159 | field:str, 160 | manifest_session=None, 161 | overwrite_cache=False, 162 | combiner_mode='mv', 163 | extraction_fraction_thresh=0.8, 164 | use_abstension=True, 165 | ): 166 | normalized_field_name = field 167 | for char in ["'", ":", "<", ">", '"', "_", "-", " ", "none"]: 168 | normalized_field_name = normalized_field_name.replace(char, "") 169 | 170 | key2golds = defaultdict(list) 171 | key2preds = defaultdict(list) 172 | total_tokens_prompted = 0 173 | 174 | # handle FM golds on D_eval 175 | gold_file2metadata = all_extractions[gold_key] 176 | cleaned_gold_metadata = {} 177 | for filepath, gold_metadata in gold_file2metadata.items(): 178 | gold_metadata = 
normalize_value_type(gold_metadata, field) 179 | if len(gold_metadata) > 1: 180 | gold_metadata, num_toks = pick_a_gold_label( 181 | gold_metadata, 182 | attribute=field, 183 | manifest_session=manifest_session, 184 | overwrite_cache=overwrite_cache 185 | ) 186 | total_tokens_prompted += num_toks 187 | gold_metadata = clean_comparison(gold_metadata, field) 188 | cleaned_gold_metadata[filepath] = gold_metadata 189 | # handle function preds on D_eval 190 | for i, (key, file2metadata) in enumerate(all_extractions.items()): 191 | if key == gold_key: 192 | continue 193 | for filepath, metadata in file2metadata.items(): 194 | gold_metadata = cleaned_gold_metadata[filepath] 195 | pred_metadata = normalize_value_type(metadata, field) 196 | pred_metadata = clean_comparison(pred_metadata, field) 197 | key2golds[key].append(gold_metadata) 198 | key2preds[key].append(pred_metadata) 199 | 200 | # Handling abstensions 201 | 202 | metrics = {} 203 | for key, golds in key2golds.items(): 204 | num_extractions = 0 205 | for golds in key2golds[key]: 206 | if golds and not any(golds.lower() == wd for wd in ['none']): 207 | num_extractions += 1 208 | extraction_fraction = float(num_extractions) / float(len(key2golds[key])) 209 | if combiner_mode == "top_k": 210 | # Don't use the extraction fraction in the naive setting for scoring 211 | extraction_fraction = 0.0 212 | #print(f"Extraction fraction: {extraction_fraction}") 213 | preds = key2preds[key] 214 | f1, f1_med = text_f1( 215 | preds, golds, 216 | extraction_fraction=extraction_fraction, 217 | attribute=field, 218 | extraction_fraction_thresh=extraction_fraction_thresh, 219 | use_abstension=use_abstension, 220 | ) 221 | priorf1, priorf1_med = text_f1(preds, golds, extraction_fraction=0.0, attribute=field) 222 | metrics[key] = { 223 | "average_f1": f1, 224 | "median_f1": f1_med, 225 | "extraction_fraction": extraction_fraction, 226 | "prior_average_f1": priorf1, 227 | "prior_median_f1": priorf1_med, 228 | } 229 | 230 | return metrics, key2golds, total_tokens_prompted 231 | 232 | 233 | def get_topk_scripts_per_field( 234 | script2metrics, 235 | function_dictionary, 236 | all_extractions, 237 | gold_key='', 238 | k=3, 239 | do_end_to_end=False, 240 | keep_thresh = 0.5, 241 | cost_thresh = 1, 242 | combiner_mode='mv', 243 | ): 244 | script2avg = dict( 245 | sorted(script2metrics.items(), 246 | reverse=True, 247 | key=lambda x: (x[1]['average_f1'], x[1]['median_f1'])) 248 | ) 249 | 250 | top_k_scripts = [k for k, v in script2avg.items() if k != gold_key] 251 | top_k_values = [ 252 | max(v['average_f1'], v['median_f1']) for k, v in script2avg.items() if k != gold_key 253 | ] 254 | if not top_k_values: 255 | return [] 256 | 257 | best_value = top_k_values[0] 258 | best_script = top_k_scripts[0] 259 | if best_value < keep_thresh and do_end_to_end: 260 | return [] 261 | 262 | filtered_fn_scripts = { 263 | k:v for k, v in script2metrics.items() if ( 264 | v['average_f1'] >= keep_thresh or v['median_f1'] >= keep_thresh 265 | ) and "function" in k 266 | } 267 | top_k_fns = [] 268 | num_fns = 0 269 | if filtered_fn_scripts: 270 | script2avg = dict( 271 | sorted(filtered_fn_scripts.items(), 272 | reverse=True, 273 | key=lambda x: (x[1]['average_f1'], x[1]['median_f1'])) 274 | ) 275 | 276 | top_k_fns = [ 277 | k for k, v in script2avg.items() if k != gold_key and abs( 278 | max(v['average_f1'], v['median_f1'])-best_value 279 | ) < cost_thresh 280 | ] 281 | num_fns = len(top_k_fns) 282 | 283 | if num_fns: 284 | top_k_scripts = top_k_scripts[0:min(k, num_fns)] 285 | else: 
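        # No generated function cleared the quality threshold for this field, so no scripts are kept.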
286 | return [] 287 | 288 | # construct final set of functions 289 | final_set = [] 290 | for key in top_k_scripts: 291 | if key in top_k_fns: 292 | final_set.append(key) 293 | 294 | if len(final_set) > k: 295 | final_set = final_set[:k] 296 | if not final_set and not do_end_to_end: 297 | return [top_k_scripts[0]] 298 | 299 | # print results 300 | # print(f"Top {k} scripts:") 301 | # for script in final_set: 302 | # print(f"- {script}; Score: {script2metrics[script]}") 303 | print(f"Best script overall: {best_script}; Score: {script2metrics[best_script]}") 304 | return final_set 305 | 306 | -------------------------------------------------------------------------------- /evaporate/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | from evaporate.run_profiler import prerun_profiler, identify_attributes, get_attribute_function 4 | from evaporate.profiler import get_model_extractions 5 | from evaporate.configs import set_profiler_args 6 | from evaporate.evaluate_synthetic_utils import text_f1, get_file_attribute 7 | from evaporate.evaluate_profiler import pick_a_gold_label, evaluate, get_topk_scripts_per_field 8 | from evaporate.retrieval import get_most_similarity 9 | from evaporate.profiler import get_functions, apply_final_profiling_functions,apply_final_ensemble,combine_extractions 10 | import os 11 | import json 12 | from collections import defaultdict, Counter 13 | import numpy as np 14 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 15 | 16 | 17 | class EvaporateData: 18 | def __init__(self, profiler_args): 19 | self.GOLD_MODEL = profiler_args["direct_extract_model"] 20 | profiler_args["GOLD_KEY"] = "gold_extraction_file" 21 | self.GOLD_KEY = "gold_extraction_file" 22 | 23 | self.profiler_args= set_profiler_args(profiler_args) 24 | self.data_dict = prerun_profiler(self.profiler_args) 25 | self.runtime = {} 26 | self.token_used = {} 27 | self.accuracy = {} 28 | self.attributes = [] 29 | self.direct_result = {} 30 | self.function_dictionary = {} 31 | self.all_extractions = {} 32 | self.manifest_sessions = self.data_dict["manifest_sessions"] 33 | self.selected_func_key = {} 34 | self.extract_result = {} 35 | self.all_metrics = None 36 | self.gold_extractions = self.data_dict["gold_extractions"] 37 | 38 | def save_results(): 39 | pass 40 | 41 | def get_attribute(self, do_end_to_end = False): 42 | if do_end_to_end or self.profiler_args.do_end_to_end: 43 | self.attributes, total_time, num_toks, evaluation_result = identify_attributes(self.profiler_args, self.data_dict, evaluation=True) 44 | self.runtime["get_attribute"] = total_time 45 | self.token_used["get_attribute"] = num_toks 46 | self.accuracy["get_attribute"] = evaluation_result 47 | else: 48 | self.attributes = self.data_dict["gold_attributes"] 49 | return self.attributes 50 | 51 | 52 | def direct_extract(self, use_retrieval_model = True, is_getting_sample = False, gold = ""): 53 | if(self.attributes == []): 54 | print("Please run get_attribute first") 55 | return 56 | files = list(self.data_dict["file2chunks"].keys()) 57 | if(is_getting_sample): 58 | files = self.data_dict["sample_files"] 59 | time_begin = time.time() 60 | token_used = 0 61 | for attribute in self.attributes: 62 | if(attribute in self.direct_result): 63 | print("already extract ", attribute) 64 | #continue 65 | new_file_chunk_dict = {} 66 | if(use_retrieval_model): 67 | baseline_sentence = attribute + ":"+ gold[attribute] 68 | for file in files: 69 | sentences = 
self.data_dict["file2chunks"][file] 70 | new_file_chunk_dict[file] = [get_most_similarity(baseline_sentence, sentences)] 71 | else: 72 | new_file_chunk_dict = self.data_dict["file2chunks"] 73 | extractions, num_toks, errored_out = get_model_extractions( 74 | new_file_chunk_dict, 75 | files, 76 | attribute, 77 | self.manifest_sessions[self.GOLD_MODEL], 78 | self.GOLD_MODEL, 79 | overwrite_cache=self.profiler_args.overwrite_cache, 80 | collecting_preds=True, 81 | ) 82 | token_used += num_toks 83 | self.direct_result[attribute] = {} 84 | for file in extractions: 85 | golds = [] 86 | for tmp in extractions[file]: 87 | golds.append( "- " + "\n- ".join(tmp)) 88 | golds = "- " + "\n- ".join(golds) 89 | if(use_retrieval_model): 90 | try: 91 | self.direct_result[attribute][file] = extractions[file][0] 92 | except: 93 | print("error in ", attribute, file, extractions[file]) 94 | else: 95 | self.direct_result[attribute][file] = pick_a_gold_label(golds, attribute, self.manifest_session) 96 | print("finish extract ", attribute) 97 | self.runtime["direct_extract"] = time.time() - time_begin 98 | self.token_used["direct_extract"] = token_used 99 | return self.direct_result, self.evaluate(self.direct_result) 100 | 101 | def get_extract_functions(self): 102 | self.runtime["get_extract_functions"] = 0 103 | self.token_used["get_extract_functions"] = 0 104 | begin_time = time.time() 105 | total_tokens_prompted = 0 106 | for attribute in self.attributes: 107 | if attribute in self.function_dictionary: 108 | print("already generate ", attribute, " function") 109 | continue 110 | self.all_extractions[attribute] = {} 111 | self.function_dictionary[attribute] = {} 112 | for model in self.profiler_args.EXTRACTION_MODELS: 113 | manifest_session = self.manifest_sessions[model] 114 | functions, function_promptsource, num_toks = get_functions( 115 | self.data_dict["file2chunks"], 116 | self.data_dict["sample_files"], 117 | {}, 118 | attribute, 119 | manifest_session, 120 | overwrite_cache=self.profiler_args.overwrite_cache, 121 | ) 122 | total_tokens_prompted += num_toks 123 | for fn_key, fn in functions.items(): 124 | self.all_extractions[attribute][fn_key], num_function_errors = apply_final_profiling_functions( 125 | self.data_dict["file2contents"], 126 | self.data_dict["sample_files"], 127 | fn, 128 | attribute, 129 | ) 130 | self.function_dictionary[attribute][fn_key] = {} 131 | self.function_dictionary[attribute][fn_key]['function'] = fn 132 | self.function_dictionary[attribute][fn_key]['promptsource'] = function_promptsource[fn_key] 133 | self.function_dictionary[attribute][fn_key]['extract_model'] = model 134 | self.runtime["get_extract_functions"] = time.time() - begin_time 135 | self.token_used["get_extract_functions"] = total_tokens_prompted 136 | return self.function_dictionary 137 | 138 | def weak_supervision(self, use_gold_key = False): 139 | result = self.direct_extract 140 | self.runtime["weak_supervision"] = 0 141 | self.token_used["weak_supervision"] = 0 142 | begin_time = time.time() 143 | total_tokens_prompted = 0 144 | for attribute in self.attributes: 145 | if attribute in self.selected_func_key: 146 | print("already weak supervision ", attribute) 147 | continue 148 | self.all_extractions[attribute]["gold_extraction_file"] = self.gold_extractions 149 | if(use_gold_key): 150 | self.GOLD_KEY = "gold_extraction_file" 151 | else: 152 | if(result == {}): 153 | print("Please run direct_extract first") 154 | return 155 | self.all_extractions[attribute]["gold-key"] = result[attribute] 156 | self.GOLD_KEY 
= "gold-key" 157 | self.all_metrics, key2golds, num_toks = evaluate( 158 | self.all_extractions[attribute], 159 | self.GOLD_KEY, 160 | field=attribute, 161 | manifest_session=self.manifest_sessions[self.profiler_args.GOLD_KEY], 162 | overwrite_cache=self.profiler_args.overwrite_cache, 163 | combiner_mode=self.profiler_args.combiner_mode, 164 | extraction_fraction_thresh=self.profiler_args.extraction_fraction_thresh, 165 | use_abstension=self.profiler_args.use_abstension, 166 | ) 167 | total_tokens_prompted += num_toks 168 | selected_keys = get_topk_scripts_per_field( 169 | self.all_metrics, 170 | self.function_dictionary[attribute], 171 | self.all_extractions, 172 | self.GOLD_KEY, 173 | k=self.profiler_args.num_top_k_scripts, 174 | do_end_to_end=self.profiler_args.do_end_to_end, 175 | combiner_mode=self.profiler_args.combiner_mode, 176 | keep_thresh = 0.00 177 | ) 178 | self.selected_func_key[attribute] = selected_keys 179 | self.runtime["weak_supervision"] = time.time() - begin_time 180 | self.token_used["weak_supervision"] = total_tokens_prompted 181 | return self.selected_func_key 182 | 183 | def apply_functions(self): 184 | total_tokens_prompted = 0 185 | self.runtime["apply_functions"] = 0 186 | self.token_used["apply_functions"] = 0 187 | self.extract_result = {} 188 | begin_time = time.time() 189 | for attribute in self.attributes: 190 | print(f"Apply the scripts to the data lake and save the metadata. Taking the top {self.profiler_args.num_top_k_scripts} scripts per field.") 191 | top_k_extractions, num_toks = apply_final_ensemble( 192 | self.profiler_args.full_file_groups, 193 | self.data_dict["file2chunks"], 194 | self.data_dict['file2contents'], 195 | self.selected_func_key[attribute], 196 | self.all_metrics, 197 | attribute, 198 | self.function_dictionary[attribute], 199 | data_lake=self.profiler_args.data_lake, 200 | manifest_sessions=self.manifest_sessions, 201 | function_cache=True, 202 | MODELS=self.profiler_args.EXTRACTION_MODELS, 203 | overwrite_cache=self.profiler_args.overwrite_cache, 204 | do_end_to_end=self.profiler_args.do_end_to_end, 205 | ) 206 | total_tokens_prompted += num_toks 207 | 208 | file2metadata, num_toks = combine_extractions( 209 | self.profiler_args, 210 | top_k_extractions, 211 | self.all_metrics, 212 | combiner_mode=self.profiler_args.combiner_mode, 213 | train_extractions=self.all_extractions[attribute], 214 | attribute=attribute, 215 | gold_key = self.GOLD_KEY, 216 | extraction_fraction_thresh=self.profiler_args.extraction_fraction_thresh, 217 | ) 218 | total_tokens_prompted += num_toks 219 | self.extract_result[attribute] = file2metadata 220 | self.runtime["apply_functions"] = time.time() - begin_time 221 | self.token_used["apply_functions"] = total_tokens_prompted 222 | return self.extract_result, self.evaluate(self.extract_result) 223 | 224 | def evaluate(self, result): 225 | f1_pred = {} 226 | for attribute in self.attributes: 227 | attribute2 = get_file_attribute(attribute) 228 | preds = [] 229 | golds = [] 230 | for file in result[attribute]: 231 | if(file in self.gold_extractions.keys()): 232 | preds.append(result[attribute][file]) 233 | golds.append(self.gold_extractions[file][attribute]) 234 | try: 235 | f1_pred[attribute] = text_f1(preds, golds) 236 | except: 237 | print(attribute) 238 | #turn to f1_pred[extraction_name] to a list of f1 scores 239 | f1_result = np.array(list(f1_pred.values())) 240 | return {"mean":f1_result.mean(), "std":f1_result.std(),"result":f1_pred} 
-------------------------------------------------------------------------------- /evaporate/weak_supervision/binary_deps.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | from itertools import chain, product, combinations 4 | from scipy.sparse import issparse 5 | import more_itertools 6 | import torch 7 | 8 | 9 | class DependentPGM: 10 | """ 11 | This class describes a PGM learned from labeled data with specified edge structure. 12 | 13 | Args: 14 | edges: list of edges that are dependent 15 | train_votes: n x m array of votes in {0, 1} 16 | train_gold: n array of true labels in {0, 1} 17 | """ 18 | def __init__( 19 | self, edges, train_votes, train_gold, abstains = False, classes = [0, 1], abstain_value = -1) -> None: 20 | """ 21 | Initialize the PGM by computing its junction tree factorization (c_tree and c_data) 22 | and by computing individual LF accuracy and class balance. 23 | """ 24 | 25 | self.edges = edges 26 | self.train_votes = train_votes 27 | self.train_gold = train_gold 28 | 29 | self.classes = classes 30 | self.k = len(classes) 31 | assert len(np.unique(self.train_gold)) == self.k 32 | 33 | self.abstains = abstains 34 | assert len(np.unique(self.train_votes)) == int(abstains) + self.k 35 | self.abstain_value = abstain_value 36 | 37 | self.n, self.m = self.train_votes.shape 38 | 39 | self.nodes = np.arange(self.m) 40 | self.higher_order = len(edges) != 0 41 | 42 | # construct data structures containing dependency graph information (maximal cliques and separator sets) 43 | self._set_clique_tree() 44 | self._set_clique_data() 45 | 46 | # compute LF accuracies and class balance 47 | self._get_accs_and_cb() 48 | 49 | def _get_scaled(self): 50 | if self.classes == [0, 1]: 51 | self.train_votes_scaled = 2*self.train_votes - 1 52 | self.train_gold_scaled = 2*self.train_gold - 1 53 | if self.abstains: 54 | self.train_votes_scaled[self.train_votes == self.abstain_value] = 0 55 | else: 56 | self.train_votes_scaled = self.train_votes 57 | self.train_gold_scaled = self.train_gold 58 | 59 | 60 | 61 | 62 | 63 | def _set_clique_tree(self): 64 | G1 = nx.Graph() 65 | G1.add_nodes_from(self.nodes) 66 | G1.add_edges_from(self.edges) 67 | 68 | # Check if graph is chordal 69 | # TODO: Add step to triangulate graph if not 70 | if not nx.is_chordal(G1): 71 | raise nx.NetworkXError("Graph triangulation not implemented.") 72 | 73 | # Create maximal clique graph G2 74 | # Each node is a maximal clique C_i 75 | # Let w = |C_i \cap C_j|; C_i, C_j have an edge with weight w if w > 0 76 | G2 = nx.Graph() 77 | for i, c in enumerate(nx.chordal_graph_cliques(G1)): 78 | G2.add_node(i, members=c) 79 | for i in G2.nodes(): 80 | for j in G2.nodes(): 81 | S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"]) 82 | w = len(S) 83 | if w > 0: 84 | G2.add_edge(i, j, weight=w, members=S) 85 | 86 | self.c_tree = nx.maximum_spanning_tree(G2) # should be maximum??? 
Because we want maximal separator sets 87 | # Return a minimum spanning tree of G2 88 | 89 | def _set_clique_data(self): 90 | # Create a helper data structure which maps cliques (as tuples of member 91 | # sources) --> {start_index, end_index, maximal_cliques}, where 92 | # the last value is a set of indices in this data structure 93 | self.c_data = dict() 94 | for i in range(self.m): 95 | self.c_data[i] = { 96 | "vertices": [i], 97 | "max_cliques": set( # which max clique i belongs to 98 | [ 99 | j 100 | for j in self.c_tree.nodes() 101 | if i in self.c_tree.nodes[j]["members"] 102 | ] 103 | ), 104 | } 105 | 106 | # Get the higher-order clique statistics based on the clique tree 107 | # First, iterate over the maximal cliques (nodes of c_tree) and 108 | # separator sets (edges of c_tree) 109 | if self.higher_order: 110 | counter = 0 111 | for item in chain(self.c_tree.nodes(), self.c_tree.edges()): 112 | if isinstance(item, int): 113 | C = self.c_tree.nodes[item] 114 | C_type = "node" 115 | elif isinstance(item, tuple): 116 | C = self.c_tree[item[0]][item[1]] 117 | C_type = "edge" 118 | else: 119 | raise ValueError(item) 120 | members = list(C["members"]) 121 | nc = len(members) 122 | 123 | # Else add one column for each possible value 124 | if nc != 1: 125 | # Add to self.c_data as well 126 | #idx = counter + m 127 | self.c_data[tuple(members)] = { 128 | "vertices": members, 129 | "max_cliques": set([item]) if C_type == "node" else set(item), 130 | } 131 | counter += 1 132 | 133 | 134 | def _get_accs_and_cb(self): 135 | classes = [0, 1] 136 | self.gold_idxs = [np.where(self.train_gold == c)[0] for c in classes] 137 | 138 | self.accs = np.zeros((self.m, 2)) # [i, j, k] = Pr(prompt_i = j| y = k) 139 | for p in range(self.m): 140 | for i in classes: 141 | self.accs[p, i] = len(np.where(self.train_votes[self.gold_idxs[i], p] == 1)[0]) / len(self.gold_idxs[i]) 142 | 143 | self.accs = np.clip(self.accs, 0.0001, 0.9999) 144 | self.balance = len(self.gold_idxs[1]) / self.n 145 | 146 | def get_clique_probs(self, idxs, vals, y): 147 | """ 148 | Computes marginal probability over voters indexed by idx, Pr(votes_idxs = vals | y). 149 | """ 150 | truth_matrix = np.ones(len(self.gold_idxs[y])).astype(bool) 151 | for i, lf in enumerate(idxs): 152 | truth_matrix = np.logical_and(truth_matrix, self.train_votes[self.gold_idxs[y], lf] == vals[i]) 153 | 154 | if len(np.where(truth_matrix == True)[0]) == 0: 155 | return 0.00001 156 | return len(np.where(truth_matrix == True)[0]) / len(self.gold_idxs[y]) 157 | 158 | 159 | def get_cond_probs(self, votes, y): 160 | """ 161 | Computes the probability Pr(votes | y). 
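    Factorizes over the junction tree: the class prior times the marginal of each maximal clique (or singleton node) given y, divided by each separator-set marginal raised to (number of containing maximal cliques - 1).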
162 | """ 163 | pr_y = self.balance if y == 1 else 1 - self.balance 164 | prod = pr_y 165 | 166 | for i in self.c_tree.nodes(): 167 | node = self.c_tree.nodes[i] 168 | members = list(node['members']) 169 | if len(members) == 1: 170 | v = members[0] 171 | print(f"multiplying by {votes[v] * self.accs[v, y]}") 172 | prod *= votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y]) 173 | else: 174 | print(members) 175 | print(f"multiplying by {self.get_clique_probs(members, votes[members], y)}") 176 | 177 | prod *= self.get_clique_probs(members, votes[members], y) 178 | 179 | for i in self.c_tree.edges(): 180 | edge = self.c_tree.edges[i] 181 | members = list(edge['members']) 182 | if len(members) == 1: 183 | v = members[0] 184 | deg = len(self.c_data[v]['max_cliques']) 185 | prod /= (votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y]))**(deg-1) 186 | 187 | print(members) 188 | print(f"Dividing by {votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y])} to the {deg - 1} power") 189 | 190 | else: 191 | deg = len(self.c_data[tuple(members)]['max_cliques']) 192 | prod /= (self.get_clique_probs(members, votes[members], y))**(deg-1) 193 | 194 | print(members) 195 | print(f"Dividing by {self.get_clique_probs(members, votes[members], y)} to the {deg - 1} power") 196 | 197 | print(prod) 198 | return prod 199 | 200 | def get_probs(self, votes): 201 | """ 202 | Computes the probability Pr(y = 1 | votes). 203 | """ 204 | pos = self.get_cond_probs(votes, 1) 205 | neg = self.get_cond_probs(votes, 0) 206 | if pos == 0: 207 | return 0 208 | else: 209 | return pos / (pos + neg) 210 | 211 | def evaluate(self, test_votes, test_gold): 212 | """ 213 | Using our learned PGM, output rounded estimates of Pr(y = 1 | votes) and computes its accuracy. 214 | 215 | Args: 216 | test_votes: vote array to perform inference on in {0, 1} 217 | test_gold: true labels to compare to in {0, 1} 218 | """ 219 | n_test = len(test_votes) 220 | 221 | output_rounded = np.zeros(n_test) 222 | output_probs = np.zeros(n_test) 223 | err = 0 224 | for i in range(n_test): 225 | output_probs[i] = self.get_probs(test_votes[i]) 226 | output_rounded[i] = np.round(output_probs[i]) 227 | err += output_rounded[i] != test_gold[i] 228 | 229 | accuracy = 1 - err / n_test 230 | 231 | return output_probs, output_rounded, accuracy 232 | 233 | 234 | def is_triangulated(nodes, edges): 235 | """ 236 | If a graph is triangulated (e.g. if a junction tree factorization exists). 237 | """ 238 | G1 = nx.Graph() 239 | G1.add_nodes_from(nodes) 240 | G1.add_edges_from(edges) 241 | return nx.is_chordal(G1) 242 | 243 | 244 | def structure_learning(m, votes, gold, acc_theta, classes = [0, 1], l1_lambda=0.2): 245 | """ 246 | Structure learning algorithm (Ising model selection) from Ravikumar (2010). 247 | 248 | Args: 249 | - votes: n_train x m array of training votes 250 | - gold: n_train array of gold labels on the training data 251 | - acc_theta: E[vote_i y] (where vote and y are scaled to [-1, 1]). This is a scaled version of accuracy that we will initialize some of the 252 | parameters in our PGM with in order to specify that we don't want to optimize over the edges between votes and y. 253 | We only are learning edges among votes! 254 | - classes: the list of classes the data can take on. 
255 | - l1_lambda: l1 regularization strength 256 | """ 257 | 258 | # scale the data 259 | classes = np.sort(np.unique(gold)) 260 | vote_classes = np.sort(np.unique(votes)) 261 | if 0 in classes and 1 in classes: 262 | votes_scaled = 2*votes - 1 263 | gold_scaled = 2*gold - 1 264 | if len(vote_classes) == len(classes) + 1: 265 | votes_scaled[votes == -1] = 0 266 | else: 267 | votes_scaled = votes 268 | gold_scaled = gold 269 | 270 | acc_theta = torch.from_numpy(acc_theta).type(torch.FloatTensor) 271 | all_thetas = np.zeros((m, m)) # learned thetas from alg 272 | 273 | # for each prompt, we fit a logistic regression model on it with prompt_i's output as the response variable and all otehr prompt outputs as the covariates. 274 | # big_theta is a vector of weights that denote dependence on each prompt (0 is independence). 275 | for v in range(m): 276 | print(f"Learning neighborhood of vertex {v}.") 277 | if len(classes) == 2: 278 | big_theta = learn_neighborhood(m, v, votes_scaled, gold_scaled, acc_theta, l1_lambda) 279 | else: 280 | big_theta = learn_neighborhood_multi(m, v, votes_scaled, gold_scaled, acc_theta, l1_lambda, classes) 281 | all_thetas[v] = big_theta 282 | 283 | return all_thetas 284 | 285 | 286 | # v is the vertex whose neighborhood graph we are estimating 287 | def learn_neighborhood(m, vertex, votes, gold, accs, l1_lambda, epochs = 50000): 288 | """ 289 | Learn the neighborhood graph for a vertex. 290 | 291 | Args: 292 | - m: number of prompts 293 | - vertex: the index of the prompt we are selecting as the response variable 294 | - votes: votes on training data 295 | - gold: gold label of training data 296 | - accs: training accuracies of each prompt we use to initialize the PGM parameters with 297 | - l1_lambda: regularization strength 298 | - epochs: number of iterations 299 | """ 300 | n = len(gold) 301 | vote_y = np.concatenate((votes, gold.reshape(n, 1)), axis=1) 302 | 303 | xr = vote_y[:, vertex] 304 | x_notr = np.delete(vote_y, vertex, axis=1) 305 | 306 | xr = torch.from_numpy(xr).type(torch.FloatTensor) 307 | x_notr = torch.from_numpy(x_notr).type(torch.FloatTensor) 308 | 309 | 310 | theta = torch.zeros(m) # last index is for accuracy between vertex and y 311 | theta[m - 1] = accs[vertex] # initialize this to be the train accuracy. We do want this to be an optimizable variable still though. 312 | theta.requires_grad_() 313 | 314 | optimizer = torch.optim.SGD([theta], lr=0.0001) 315 | for t in range(epochs): 316 | optimizer.zero_grad() 317 | 318 | # logistic regression from Ravikumar et al 319 | fx = (torch.log(torch.exp(torch.matmul(x_notr, theta)) 320 | + torch.exp(-torch.matmul(x_notr, theta))).mean()) 321 | loss = fx - torch.multiply(xr, x_notr.T).mean(dim=1).dot(theta) + l1_lambda * torch.linalg.vector_norm(theta[:m], ord=1) 322 | 323 | loss.backward() 324 | optimizer.step() 325 | 326 | #if t % 1000 == 0: 327 | # print(f"Loss: {loss}") 328 | 329 | big_theta = np.concatenate([theta.detach().numpy()[:vertex], [0], theta.detach().numpy()[vertex:m - 1]]) 330 | return big_theta 331 | 332 | # v is the vertex whose neighborhood graph we are estimating 333 | def learn_neighborhood_multi(m, vertex, votes, gold, accs, l1_lambda, classes, epochs = 50000): 334 | # votes: in range {0, ... 
k} 335 | n = len(gold) 336 | vote_y = np.concatenate((votes, gold.reshape(n, 1)), axis=1) 337 | 338 | xr = vote_y[:, vertex] 339 | x_notr = np.delete(vote_y, vertex, axis=1) 340 | 341 | xr = torch.from_numpy(xr).type(torch.FloatTensor) 342 | x_notr = torch.from_numpy(x_notr).type(torch.FloatTensor) 343 | 344 | 345 | theta = torch.zeros(m) # last index is for accuracy between vertex and y 346 | theta[m - 1] = accs[vertex] # initialize this 347 | theta.requires_grad_() 348 | 349 | optimizer = torch.optim.SGD([theta], lr=0.0001) 350 | for t in range(epochs): 351 | optimizer.zero_grad() 352 | 353 | # logistic regression from Ravikumar et al 354 | mu = 0 355 | for i in range(x_notr.shape[1]): 356 | # mu = \sum_i theta_i * \sum_data sign{x_r = x_i} 357 | mu += (2*(xr == x_notr[:, i])-1).type(torch.FloatTensor).mean() * theta[i] 358 | 359 | fx = 0 360 | for k in classes: 361 | # \sum_y exp( \sum_i theta_i sign(x_i = y)) "normalization" 362 | fx += torch.exp(torch.matmul((2*(x_notr == k)-1).type(torch.FloatTensor), theta)).mean() 363 | 364 | loss = fx - mu + l1_lambda * torch.linalg.vector_norm(theta[:m], ord=1) 365 | 366 | loss.backward() 367 | optimizer.step() 368 | 369 | #if t % 1000 == 0: 370 | # print(f"Loss: {loss}") 371 | 372 | big_theta = np.concatenate([theta.detach().numpy()[:vertex], [0], theta.detach().numpy()[vertex:m - 1]]) 373 | return big_theta 374 | 375 | def main(): 376 | # load data 377 | vote_arr_train = np.load('./data/youtube-spam/train_votes.npy').T 378 | vote_arr_test = np.load('./data/youtube-spam/test_votes.npy').T 379 | gold_arr_train = np.load('./data/youtube-spam/train_gold.npy').T 380 | gold_arr_test = np.load('./data/youtube-spam/test_gold.npy').T 381 | 382 | # vote_arr_train = np.concatenate((vote_arr_train[:, 0: 2], vote_arr_train[:, 4:]), axis=1) 383 | # vote_arr_test = np.concatenate((vote_arr_test[:, 0: 2], vote_arr_test[:, 4:]), axis=1) 384 | 385 | n_train, num_prompts = vote_arr_train.shape 386 | 387 | 388 | # make validation set 389 | np.random.seed(4) 390 | val_idxs = np.random.choice(np.arange(n_train), size= 28, replace=False) 391 | vote_arr_val = vote_arr_train[val_idxs, :] 392 | vote_arr_train = np.delete(vote_arr_train, val_idxs, axis=0) 393 | 394 | gold_arr_val = gold_arr_train[val_idxs] 395 | gold_arr_train = np.delete(gold_arr_train, val_idxs) 396 | 397 | nodes = np. 
arange(num_prompts) 398 | 399 | 400 | # specify edgeset 401 | # edges =[(0, 1)] 402 | #model = DependentPGM(edges, vote_arr_train, gold_arr_train) 403 | #probs, output, acc = model.evaluate(vote_arr_test, gold_arr_test) 404 | #print(acc) 405 | 406 | 407 | # Brute-force iteration through a bunch of edges 408 | all_edges = list(combinations(nodes, 2)) 409 | small_edgesets = list(more_itertools.powerset(all_edges)) 410 | #small_edgesets = list(combinations(all_edges, 0)) + list(combinations(all_edges, 1)) + list(combinations(all_edges, 2)) + list(combinations(all_edges, 3)) 411 | scores = np.zeros(len(small_edgesets)) 412 | 413 | for i, edgeset in enumerate(small_edgesets): 414 | if len(edgeset) > 4: 415 | break 416 | if not is_triangulated(nodes, edgeset): 417 | continue 418 | model = DependentPGM(edgeset, vote_arr_train, gold_arr_train) 419 | 420 | probs, output, scores[i] = model.evaluate(vote_arr_val, gold_arr_val) 421 | if i % 100 == 0: 422 | print(f"Edgeset: {edgeset} \n score: {scores[i]}") 423 | 424 | print(f"Best edgeset score: {scores.max()}") 425 | print(f"Best edgeset: {small_edgesets[scores.argmax()]}") 426 | 427 | edges = small_edgesets[scores.argmax()] 428 | 429 | vote_arr_train = np.concatenate((vote_arr_train, vote_arr_val)) 430 | gold_arr_train = np.concatenate((gold_arr_train, gold_arr_val)) 431 | 432 | model = DependentPGM(edges, vote_arr_train, gold_arr_train) 433 | probs, output, acc = model.evaluate(vote_arr_test, gold_arr_test) 434 | print(f"Final model accuracy: {acc}") 435 | 436 | 437 | if __name__ == "__main__": 438 | main() 439 | -------------------------------------------------------------------------------- /evaporate/configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import datetime 4 | 5 | def get_run_string( 6 | data_lake, today, file_groups, profiler_args, do_end_to_end, 7 | train_size, dynamicbackoff, models 8 | ): 9 | body = profiler_args.body_only # Baseline systems only operate on the HTML body 10 | model_ct = len(models) 11 | if profiler_args.use_qa_model: 12 | model_ct += 1 13 | run_string = f"dataLake{data_lake}_date{today}_fileSize{len(file_groups)}_trainSize{train_size}_numAggregate{profiler_args.num_top_k_scripts}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}_body{body}_cascading{do_end_to_end}_useBackoff{dynamicbackoff}_MODELS{model_ct}" 14 | return run_string 15 | 16 | def get_data_lake_info(args): 17 | extractions_file = None 18 | 19 | if 1: 20 | DATA_DIR = args.data_dir 21 | file_groups = os.listdir(args.data_dir) 22 | if not DATA_DIR.endswith("/"): 23 | DATA_DIR += "/" 24 | file_groups = [f"{DATA_DIR}{file_group}" for file_group in file_groups if not file_group.startswith(".")] 25 | full_file_groups = file_groups.copy() 26 | extractions_file = args.gold_extractions_file 27 | parser = "txt" 28 | 29 | return file_groups, extractions_file, parser, full_file_groups 30 | 31 | #default args for the experiment 32 | def get_experiment_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--data_lake", 36 | type=str, 37 | help="Name of the data lake to operate over. 
Must be in configs.py" 38 | ) 39 | parser.add_argument( 40 | "--data_dir", 41 | type=str, 42 | help="Path to raw data-lake documents", 43 | ) 44 | parser.add_argument( 45 | "--base_data_dir", 46 | type=str, 47 | default="/", 48 | help="Path to save intermediate result", 49 | ) 50 | parser.add_argument( 51 | "--cache_dir", 52 | type=str, 53 | default= "", 54 | help = "Path to cache intermediate files during system execution", 55 | ) 56 | parser.add_argument( 57 | "--generative_index_path", 58 | type=str, 59 | default= "", 60 | help = "Path to store the generated structured view of the data lake", 61 | ) 62 | parser.add_argument( 63 | "--gold_extractions_file", 64 | type=str, 65 | default= "", 66 | help = "Path to store the generated structured view of the data lake", 67 | ) 68 | parser.add_argument( 69 | "--do_end_to_end", 70 | action='store_true', 71 | default=False, 72 | help="True for OpenIE, False for ClosedIE" 73 | ) 74 | 75 | parser.add_argument( 76 | "--num_attr_to_cascade", 77 | type=int, 78 | default=35, 79 | help="Number of attributes to generate functions for" 80 | ) 81 | 82 | parser.add_argument( 83 | "--num_top_k_scripts", 84 | type=int, 85 | default=10, 86 | help="Number of generated functions to combine over for each attribute" 87 | ) 88 | 89 | parser.add_argument( 90 | "--train_size", 91 | type=int, 92 | default=10, 93 | help="Number of files to prompt on" 94 | ) 95 | 96 | parser.add_argument( 97 | "--combiner_mode", 98 | type=str, 99 | default='ws', 100 | help="Combiner mode for combining the outputs of the generated functions", 101 | choices=['ws', 'mv', 'top_k'] 102 | ) 103 | 104 | parser.add_argument( 105 | "--use_dynamic_backoff", 106 | action="store_true", 107 | default=True, 108 | help="Whether to generate functions or do Evaporate-Direct", 109 | ) 110 | 111 | parser.add_argument( 112 | "--KEYS", 113 | type=str, 114 | default=[], 115 | help="List of keys to use the model api", 116 | nargs='*' 117 | ) 118 | parser.add_argument( 119 | "--MODELS", 120 | type=str, 121 | default=["gpt-4"], 122 | help="List of models to use for the extraction step" 123 | ) 124 | parser.add_argument( 125 | "--EXTRACTION_MODELS", 126 | type=str, 127 | default=["gpt-4"], 128 | help="List of models to use for the extraction step" 129 | ) 130 | parser.add_argument( 131 | "--MODEL2URL", 132 | type=str, 133 | default={}, 134 | help="Dict mapping models to their urls" 135 | ) 136 | parser.add_argument( 137 | "--use_qa_model", 138 | action="store_true", 139 | default=False, 140 | help="Whether to use a QA model for the extraction step" 141 | ) 142 | parser.add_argument( 143 | "--GOLD_KEY", 144 | type=str, 145 | default="gpt-4", 146 | help="Key to use for the gold standard" 147 | ) 148 | 149 | parser.add_argument( 150 | "--eval_size", 151 | type=int, 152 | default=15, 153 | ) 154 | 155 | parser.add_argument( 156 | "--max_chunks_per_file", 157 | type=int, 158 | default=-1, 159 | ) 160 | 161 | parser.add_argument( 162 | "--chunk_size", 163 | type=int, 164 | default=3000, 165 | ) 166 | 167 | parser.add_argument( 168 | "--extraction_fraction_thresh", 169 | type=int, 170 | default=0.9, 171 | help="for abstensions approach", 172 | ) 173 | 174 | parser.add_argument( 175 | "--remove_tables", 176 | action="store_true", 177 | default=False, 178 | help="Remove tables from the html files?", 179 | ) 180 | 181 | parser.add_argument( 182 | "--body_only", 183 | action="store_true", 184 | default=False, 185 | help="Only use HTML body", 186 | ) 187 | 188 | parser.add_argument( 189 | "--max_metadata_fields", 190 | 
type=int, 191 | default=15, 192 | ) 193 | 194 | parser.add_argument( 195 | "--overwrite_cache", 196 | action='store_true', 197 | default=False, 198 | help="overwrite the manifest cache" 199 | ) 200 | parser.add_argument( 201 | "--swde_plus", 202 | action="store_true", 203 | default=False, 204 | help="Whether to use the extended SWDE dataset to measure OpenIE performance", 205 | ) 206 | 207 | parser.add_argument( 208 | "--schema_id_sizes", 209 | type=int, 210 | default=0, 211 | help="Number of documents to use for schema identification stage, if it differs from extraction", 212 | ) 213 | 214 | parser.add_argument( 215 | "--slice_results", 216 | action="store_true", 217 | default=False, 218 | help="Whether to measure the results by attribute-slice", 219 | ) 220 | 221 | parser.add_argument( 222 | "--fn_generation_prompt_num", 223 | type=int, 224 | default=-1, 225 | help="For ablations on function generation with diversity, control which prompt we use. Default is all.", 226 | ) 227 | 228 | parser.add_argument( 229 | "--upper_bound_fns", 230 | action="store_true", 231 | default=False, 232 | help="For ablations that select functions using ground truth instead of the FM.", 233 | ) 234 | 235 | parser.add_argument( 236 | "--use_alg_filtering", 237 | type=str, 238 | default=True, 239 | help="Whether to filter functions based on quality.", 240 | ) 241 | 242 | parser.add_argument( 243 | "--use_abstension", 244 | type=str, 245 | default=True, 246 | help="Whether to use the abstensions approach.", 247 | ) 248 | 249 | parser.add_argument( 250 | "--set_dicts", 251 | type=str, 252 | default='', 253 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.", 254 | ) 255 | 256 | parser.add_argument( 257 | "--topic", 258 | type=list, 259 | default=[], 260 | help="Topic of the data lake", 261 | ) 262 | experiment = parser.parse_args() 263 | return experiment 264 | 265 | #get args related to data storage paths 266 | def get_args(profiler_args): 267 | if(profiler_args.generative_index_path == ""): 268 | profiler_args.generative_index_path = os.path.join(profiler_args.base_data_dir, "generative_indexes/", profiler_args.data_lake) 269 | if(profiler_args.cache_dir == ""): 270 | profiler_args.cache_dir = os.path.join(profiler_args.base_data_dir, "cache/", profiler_args.data_lake) 271 | if(profiler_args.gold_extractions_file == ""): 272 | profiler_args.gold_extractions_file = os.path.join(profiler_args.base_data_dir, "gold_extractions.json" ) 273 | parser = argparse.ArgumentParser( 274 | "LLM explorer.", 275 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 276 | ) 277 | 278 | parser.add_argument( 279 | "--overwrite_cache", 280 | type=bool, 281 | default=profiler_args.overwrite_cache, 282 | help="Whether to overwrite the caching for prompts." 
283 | ) 284 | 285 | parser.add_argument( 286 | "--data_lake", 287 | type=str, 288 | default=profiler_args.data_lake, 289 | help="Name of the data lake" 290 | ) 291 | 292 | parser.add_argument( 293 | "--data_dir", 294 | type=str, 295 | default = profiler_args.data_dir, 296 | help="Path to raw data-lake documents", 297 | ) 298 | 299 | parser.add_argument( 300 | "--generative_index_path", 301 | type=str, 302 | default = profiler_args.generative_index_path, 303 | help="Path to store the generated structured view of the data lake", 304 | ) 305 | 306 | parser.add_argument( 307 | "--cache_dir", 308 | type=str, 309 | default=profiler_args.cache_dir, 310 | help="Path to cache intermediate files during system execution", 311 | ) 312 | 313 | parser.add_argument( 314 | "--set_dicts", 315 | type=str, 316 | default=profiler_args.set_dicts, 317 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.", 318 | ) 319 | 320 | parser.add_argument( 321 | "--gold_extractions_file", 322 | type=str, 323 | default=profiler_args.gold_extractions_file, 324 | help="Path to store the generated structured view of the data lake", 325 | ) 326 | parser.add_argument( 327 | "--topic", 328 | type=list, 329 | default=profiler_args.topic, 330 | help="Topic of the data lake", 331 | ) 332 | 333 | args = parser.parse_args(args=[]) 334 | return args 335 | 336 | #for notebook settings 337 | def set_profiler_args(information): 338 | parser = argparse.ArgumentParser( 339 | "LLM explorer.", 340 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 341 | ) 342 | parser.add_argument( 343 | "--data_lake", 344 | type=str, 345 | help="Name of the data lake to operate over. Must be in configs.py" 346 | ) 347 | parser.add_argument( 348 | "--data_dir", 349 | type=str, 350 | help="Path to raw data-lake documents", 351 | ) 352 | parser.add_argument( 353 | "--base_data_dir", 354 | type=str, 355 | default="/", 356 | help="Path to save intermediate result", 357 | ) 358 | parser.add_argument( 359 | "--cache_dir", 360 | type=str, 361 | default= "", 362 | help = "Path to cache intermediate files during system execution", 363 | ) 364 | parser.add_argument( 365 | "--generative_index_path", 366 | type=str, 367 | default= "", 368 | help = "Path to store the generated structured view of the data lake", 369 | ) 370 | parser.add_argument( 371 | "--gold_extractions_file", 372 | type=str, 373 | default= "", 374 | help = "Path to store the generated structured view of the data lake", 375 | ) 376 | 377 | parser.add_argument( 378 | "--num_attr_to_cascade", 379 | type=int, 380 | default=35, 381 | help="Number of attributes to generate functions for" 382 | ) 383 | 384 | parser.add_argument( 385 | "--num_top_k_scripts", 386 | type=int, 387 | default=10, 388 | help="Number of generated functions to combine over for each attribute" 389 | ) 390 | 391 | 392 | parser.add_argument( 393 | "--combiner_mode", 394 | type=str, 395 | default='ws', 396 | help="Combiner mode for combining the outputs of the generated functions", 397 | choices=['ws', 'mv', 'top_k'] 398 | ) 399 | 400 | parser.add_argument( 401 | "--use_dynamic_backoff", 402 | action="store_true", 403 | default=True, 404 | help="Whether to generate functions or do Evaporate-Direct", 405 | ) 406 | 407 | parser.add_argument( 408 | "--KEYS", 409 | type=str, 410 | default=[], 411 | help="List of keys to use the model api", 412 | nargs='*' 413 | ) 414 | parser.add_argument( 415 | "--MODELS", 416 | type=str, 417 | default=["gpt-4"], 418 | help="List of models to use for the extraction step" 
419 | ) 420 | parser.add_argument( 421 | "--EXTRACTION_MODELS", 422 | type=str, 423 | default=["gpt-4"], 424 | help="List of models to use for the extraction step" 425 | ) 426 | parser.add_argument( 427 | "--MODEL2URL", 428 | type=str, 429 | default={}, 430 | help="Dict mapping models to their urls" 431 | ) 432 | parser.add_argument( 433 | "--use_qa_model", 434 | action="store_true", 435 | default=False, 436 | help="Whether to use a QA model for the extraction step" 437 | ) 438 | parser.add_argument( 439 | "--GOLD_KEY", 440 | type=str, 441 | default="gpt-4", 442 | help="Key to use for the gold standard" 443 | ) 444 | 445 | parser.add_argument( 446 | "--eval_size", 447 | type=int, 448 | default=15, 449 | ) 450 | 451 | parser.add_argument( 452 | "--chunk_size", 453 | type=int, 454 | default=3000, 455 | ) 456 | 457 | parser.add_argument( 458 | "--max_chunks_per_file", 459 | type=int, 460 | default=-1, 461 | ) 462 | 463 | 464 | parser.add_argument( 465 | "--extraction_fraction_thresh", 466 | type=int, 467 | default=0.9, 468 | help="for abstensions approach", 469 | ) 470 | 471 | parser.add_argument( 472 | "--remove_tables", 473 | action="store_true", 474 | default=False, 475 | help="Remove tables from the html files?", 476 | ) 477 | 478 | parser.add_argument( 479 | "--body_only", 480 | action="store_true", 481 | default=False, 482 | help="Only use HTML body", 483 | ) 484 | 485 | parser.add_argument( 486 | "--max_metadata_fields", 487 | type=int, 488 | default=15, 489 | ) 490 | 491 | parser.add_argument( 492 | "--overwrite_cache", 493 | action='store_true', 494 | default=False, 495 | help="overwrite the manifest cache" 496 | ) 497 | parser.add_argument( 498 | "--swde_plus", 499 | action="store_true", 500 | default=False, 501 | help="Whether to use the extended SWDE dataset to measure OpenIE performance", 502 | ) 503 | 504 | parser.add_argument( 505 | "--schema_id_sizes", 506 | type=int, 507 | default=0, 508 | help="Number of documents to use for schema identification stage, if it differs from extraction", 509 | ) 510 | 511 | parser.add_argument( 512 | "--slice_results", 513 | action="store_true", 514 | default=False, 515 | help="Whether to measure the results by attribute-slice", 516 | ) 517 | 518 | parser.add_argument( 519 | "--fn_generation_prompt_num", 520 | type=int, 521 | default=-1, 522 | help="For ablations on function generation with diversity, control which prompt we use. 
Default is all.", 523 | ) 524 | 525 | parser.add_argument( 526 | "--upper_bound_fns", 527 | action="store_true", 528 | default=False, 529 | help="For ablations that select functions using ground truth instead of the FM.", 530 | ) 531 | 532 | parser.add_argument( 533 | "--use_alg_filtering", 534 | action='store_true', 535 | default=False, 536 | help="Whether to filter functions based on quality.", 537 | ) 538 | 539 | parser.add_argument( 540 | "--use_abstension", 541 | action='store_true', 542 | default=False, 543 | help="Whether to use the abstensions approach.", 544 | ) 545 | 546 | parser.add_argument( 547 | "--do_end_to_end", 548 | action='store_true', 549 | default=False, 550 | help="True for OpenIE, False for ClosedIE" 551 | ) 552 | 553 | parser.add_argument( 554 | "--set_dicts", 555 | type=str, 556 | default='', 557 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.", 558 | ) 559 | 560 | parser.add_argument( 561 | "--topic", 562 | type=list, 563 | default=[], 564 | help="Topic of the data lake", 565 | ) 566 | 567 | args = parser.parse_args(args=[]) 568 | for key, value in information.items(): 569 | setattr(args, key, value) 570 | today = datetime.datetime.today().strftime("%m%d%Y") 571 | file_groups, extractions_file, parser, full_file_groups = get_data_lake_info(args) 572 | setattr(args, "file_groups", file_groups) 573 | setattr(args, "extractions_file", extractions_file) 574 | setattr(args, "parser", parser) 575 | setattr(args, "full_file_groups", full_file_groups) 576 | if("train_size" not in args): 577 | args.train_size = 10 578 | if "use_dynamic_backoff" not in args: 579 | args.use_dynamic_backoff = True 580 | setattr(args, "run_string", get_run_string(args.data_lake, today, args.full_file_groups, args, args.do_end_to_end, args.train_size, args.use_dynamic_backoff, args.EXTRACTION_MODELS)) 581 | if(args.generative_index_path == ""): 582 | args.generative_index_path = os.path.join(args.base_data_dir, "generative_indexes/", args.data_lake) 583 | if(args.cache_dir == ""): 584 | args.cache_dir = os.path.join(args.base_data_dir, "cache/", args.data_lake) 585 | if(args.gold_extractions_file == ""): 586 | args.gold_extractions_file = os.path.join(args.base_data_dir, "gold_extractions.json" ) 587 | return args 588 | 589 | 590 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from evaporate.main import EvaporateData\n", 20 | "import json" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "MANIFEST_URL = \"http://127.0.0.1:5000\" # please make sure that a local manifest session is running with your model at this address\n", 30 | "DATA_DIR = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k\"\n", 31 | "GOLD_PATH = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/table.json\"\n", 32 | "BASE_DATA_DIR = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/\"\n", 33 | "model = 
\"mistralai/Mistral-7B-Instruct-v0.2\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Set Up" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "using together\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "Chunking files: 100%|██████████| 100/100 [00:00<00:00, 1089.38it/s]\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "profiler_args= {\n", 65 | " \"data_lake\": \"fda\", \n", 66 | " \"combiner_mode\": \"mv\", \n", 67 | " \"do_end_to_end\": False,\n", 68 | " \"KEYS\": [\"\"], \n", 69 | " \"MODELS\":[model],\n", 70 | " \"EXTRACTION_MODELS\": [model],\n", 71 | " \"MODEL2URL\" : {},\n", 72 | " \"data_dir\": f\"{DATA_DIR}\", \n", 73 | " \"chunk_size\": 3000,\n", 74 | " \"base_data_dir\": f\"{BASE_DATA_DIR}\", \n", 75 | " \"gold_extractions_file\": GOLD_PATH,\n", 76 | " \"direct_extract_model\": model\n", 77 | "}\n", 78 | "\n", 79 | "evaporate = EvaporateData(profiler_args)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# Get Schema Attribute" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "['purpose for submission', 'type of test', 'classification', 'product code', 'panel', 'indications for use', 'predicate device name', 'proposed labeling', 'conclusion', '510k number', 'applicant', 'predicate 510k number', 'proprietary and established names', 'regulation section', 'measurand', 'intended use']\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "attributes = evaporate.get_attribute(do_end_to_end=False)\n", 104 | "print(attributes)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "# Get Direct Attribute" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stderr", 121 | "output_type": "stream", 122 | "text": [ 123 | "Extracting attribute purpose for submission using LM: 100%|██████████| 10/10 [00:28<00:00, 2.80s/it]\n" 124 | ] 125 | }, 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "finish extract purpose for submission\n" 131 | ] 132 | }, 133 | { 134 | "name": "stderr", 135 | "output_type": "stream", 136 | "text": [ 137 | "Extracting attribute type of test using LM: 100%|██████████| 10/10 [00:14<00:00, 1.41s/it]" 138 | ] 139 | }, 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "finish extract type of test\n", 145 | "{'purpose for submission': {'/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K150526.txt': ['The given context is a section from a Decision Memorandum for a 510(k) submission. In this section', 'the purpose for submission is stated as New Device. 
This indicates that the submission is for a new medical device that has not been previously cleared or approved by the FDA.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151046.txt': ['To obtain substantial equivalence determination for the illumigene® HSV 1&2 DNA Amplification Assay'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K180886.txt': ['To obtain a substantial equivalence for the addition of Delafloxacin at concentrations of 0.002-32 µg/mL for susceptibility testing of non-fastidious Gram negative organisms'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K181525.txt': ['Purpose for Submission: New Device'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151265.txt': ['New Submission'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K171641.txt': ['This is a new 510(k) application for the determination of Substantial Equivalence for the Mesa Biotech Accula Flu A/Flu B Test and associated instrument. Mesa Biotech', 'Inc. has submitted a combined 510(k) and CLIA waiver package for dual review.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K182472.txt': ['To obtain a Substantial Equivalence determination for the Cepheid Xpert GBS LB Control Panel for use with the Cepheid Xpert GBS LB Assay on the GeneXpert Instrument System.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K161714.txt': ['Answer: The Immunalysis Barbiturates Urine Enzyme Immunoassay is a homogeneous enzyme immunoassay with a cutoff of 200 ng/mL. The assay is intended for use in laboratories for the qualitative and semi-quantitative analysis of Barbiturates in human urine with automated clinical chemistry analyzers. This assay is calibrated against Secobarbital. This in vitro diagnostic device is for prescription use only. The semi-quantitative mode is for purposes of enabling laboratories to determine an appropriate dilution of the specimen for confirmation by a confirmatory method such as Gas Chromatography/ Mass Spectrometry (GC-MS) or Liquid Chromatography/ Tandem Mass Spectrometry (LC-MS/MS) or permitting laboratories to establish quality control procedures. The Immunalysis Barbiturates Urine Enzyme Immunoassay provides only a preliminary analytical test result. A more specific alternate chemical method must be used in order to obtain a confirmed analytical result. GC-MS or LC-MS/MS is the preferred confirmatory method. Clinical consideration and professional judgment should be applied to any drug of abuse test result', 'particularly when preliminary positive results are used. The Immunalysis Multi-Drug Calibrators: The Immunalysis Multi-Drug Calibrators are intended for in vitro diagnostic use for the calibration of assays for the following analytes: Benzoylecgonine', 'Methamphetamine', 'Morphine', 'PCP', 'Secobarbital and Oxazepam. The calibrators are designed for prescription use with immunoassays.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K162042.txt': ['2. 
Indication(s) for use:'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K170974.txt': ['Purpose for Submission: New device']}, 'type of test': {'/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K150526.txt': ['Type of Test: Quantitative', 'turbidimetric'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151046.txt': ['Type of Test: Qualitative in vitro diagnostic device for the direct detection and differentiation of HSV-1 and HSV-2 DNA in cutaneous and mucocutaneous lesion specimens from symptomatic patients suspected of Herpetic infections.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K180886.txt': ['Type of Test: Quantitative Antimicrobial Susceptibility Test growth based detection'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K181525.txt': ['In the provided context', 'the type of test is mentioned as Quantitative immunoturbidimetric assay. This information is important for understanding the nature of the test being performed and the technology used in the analysis. Immunoturbidimetry is a common analytical technique used in clinical chemistry and immunology to measure the concentration of various substances in a sample', 'such as proteins or enzymes', 'by detecting the turbidity or cloudiness caused by the interaction between the analyte and specific antibodies or reagents. In this case', 'the test is being used to quantitatively determine the free protein S antigen in human plasma.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151265.txt': ['Type of Test: Quantitative Amperometric Assay; glucose dehydrogenase - flavin adenine dinucleotide (GDH-FAD)'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K171641.txt': ['Type of Test: RT-PCR amplification followed by hybridization and colorimetric visualization of amplified products on a test strip'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K182472.txt': ['Type of Test: The Cepheid Xpert GBS LB Control Panel is intended for use as external quality control materials to monitor the performance of in vitro laboratory nucleic acid testing procedures for the qualitative detection of Group B Streptococcus (GBS) performed with the Cepheid Xpert GBS LB Assay on the GeneXpert Instrument System.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K161714.txt': ['Homogenous Enzyme Immunoassay', 'Qualitative and Semi-quantitative'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K162042.txt': ['Type of Test: Quantitative', 'mid-infrared (MIR) spectrophotometric assay'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K170974.txt': ['Quantitative and Semi-quantitative Flow Cytometric Immunoassays']}}\n", 146 | "{'mean': 0.723529274358417, 'std': 0.08378645930419634, 'result': {'purpose for submission': (0.606440811221299, 0.7472527472527473), 'type of test': (0.7014743670578509, 0.8389491719017703)}}\n" 147 | ] 148 | 
}, 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "with open(GOLD_PATH, \"r\") as f:\n", 159 | " gold = json.load(f)\n", 160 | " gold = gold[\"/data/evaporate/fda-ai-pmas/510k/K151917.txt\"]\n", 161 | "#state reference dict for using retrieval model\n", 162 | "direct_attribute, direct_eval = evaporate.direct_extract(is_getting_sample = True, gold = gold, use_retrieval_model= True)\n", 163 | "print(direct_attribute)\n", 164 | "print(direct_eval)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "# Get Attribute functions" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 7, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stderr", 181 | "output_type": "stream", 182 | "text": [ 183 | "Generating functions for attribute purpose for submission: 100%|██████████| 10/10 [07:08<00:00, 42.83s/it]\n" 184 | ] 185 | }, 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "Timeout 0\n", 191 | "Timeout\n", 192 | "Timeout 0\n", 193 | "Timeout\n" 194 | ] 195 | }, 196 | { 197 | "name": "stderr", 198 | "output_type": "stream", 199 | "text": [ 200 | "Generating functions for attribute type of test: 100%|██████████| 10/10 [07:07<00:00, 42.80s/it]\n" 201 | ] 202 | }, 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "Timeout 0\n", 208 | "Timeout\n", 209 | "Timeout 0\n", 210 | "Timeout\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "extract_functions = evaporate.get_extract_functions()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "# Weak Supervision (select the best method)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 8, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "Best script overall: function_94; Score: {'average_f1': 0.10487804878048781, 'median_f1': 0.0, 'extraction_fraction': 1.0, 'prior_average_f1': 0.10487804878048781, 'prior_median_f1': 0.0}\n", 235 | "Best script overall: function_0; Score: {'average_f1': 0.1, 'median_f1': 0.0, 'extraction_fraction': 1.0, 'prior_average_f1': 0.1, 'prior_median_f1': 0.0}\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "selected_keys = evaporate.weak_supervision(use_gold_key=True)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "# apply on the data lake" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 9, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "Apply the scripts to the data lake and save the metadata. 
Taking the top 10 scripts per field.\n", 260 | "Applying function function_94...\n", 261 | "Applying function function_1...\n", 262 | "Applying function function_2...\n", 263 | "Applying function function_3...\n", 264 | "Applying function function_5...\n", 265 | "Applying function function_6...\n", 266 | "Applying function function_9...\n", 267 | "Applying function function_10...\n", 268 | "Applying function function_11...\n", 269 | "Applying function function_12...\n" 270 | ] 271 | }, 272 | { 273 | "name": "stderr", 274 | "output_type": "stream", 275 | "text": [ 276 | "Applying key function_94: 100%|██████████| 100/100 [00:00<00:00, 234187.83it/s]\n", 277 | "Applying key function_1: 100%|██████████| 100/100 [00:00<00:00, 216871.98it/s]\n", 278 | "Applying key function_2: 100%|██████████| 100/100 [00:00<00:00, 163393.22it/s]\n", 279 | "Applying key function_3: 100%|██████████| 100/100 [00:00<00:00, 135869.91it/s]\n", 280 | "Applying key function_5: 100%|██████████| 100/100 [00:00<00:00, 184122.21it/s]\n", 281 | "Applying key function_6: 100%|██████████| 100/100 [00:00<00:00, 512125.03it/s]\n", 282 | "Applying key function_9: 100%|██████████| 100/100 [00:00<00:00, 497544.96it/s]\n", 283 | "Applying key function_10: 100%|██████████| 100/100 [00:00<00:00, 68635.31it/s]\n", 284 | "Applying key function_11: 100%|██████████| 100/100 [00:00<00:00, 158514.89it/s]\n", 285 | "Applying key function_12: 100%|██████████| 100/100 [00:00<00:00, 322638.77it/s]" 286 | ] 287 | }, 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "Apply the scripts to the data lake and save the metadata. Taking the top 10 scripts per field.\n", 293 | "Applying function function_0...\n", 294 | "Applying function function_1...\n", 295 | "Applying function function_2...\n", 296 | "Applying function function_3...\n", 297 | "Applying function function_4...\n", 298 | "Applying function function_5...\n", 299 | "Applying function function_6...\n", 300 | "Applying function function_7...\n" 301 | ] 302 | }, 303 | { 304 | "name": "stderr", 305 | "output_type": "stream", 306 | "text": [ 307 | "\n" 308 | ] 309 | }, 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "Applying function function_8...\n", 315 | "Applying function function_9...\n" 316 | ] 317 | }, 318 | { 319 | "name": "stderr", 320 | "output_type": "stream", 321 | "text": [ 322 | "Applying key function_0: 100%|██████████| 100/100 [00:00<00:00, 121785.83it/s]\n", 323 | "Applying key function_1: 100%|██████████| 100/100 [00:00<00:00, 40932.02it/s]\n", 324 | "Applying key function_2: 100%|██████████| 100/100 [00:00<00:00, 29112.96it/s]\n", 325 | "Applying key function_3: 100%|██████████| 100/100 [00:00<00:00, 96243.78it/s]\n", 326 | "Applying key function_4: 100%|██████████| 100/100 [00:00<00:00, 135474.94it/s]\n", 327 | "Applying key function_5: 100%|██████████| 100/100 [00:00<00:00, 35907.06it/s]\n", 328 | "Applying key function_6: 100%|██████████| 100/100 [00:00<00:00, 181179.44it/s]\n", 329 | "Applying key function_7: 100%|██████████| 100/100 [00:00<00:00, 279062.14it/s]\n", 330 | "Applying key function_8: 100%|██████████| 100/100 [00:00<00:00, 195995.51it/s]\n", 331 | "Applying key function_9: 100%|██████████| 100/100 [00:00<00:00, 172747.28it/s]" 332 | ] 333 | }, 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "{'mean': 0.0027519179133103184, 'std': 0.00294827166632422, 'result': {'purpose for submission': (0.0040076716532412736, 0.0), 'type of test': 
(0.006999999999999999, 0.0)}}\n" 339 | ] 340 | }, 341 | { 342 | "name": "stderr", 343 | "output_type": "stream", 344 | "text": [ 345 | "\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "function_extract, function_eval = evaporate.apply_functions()\n", 351 | "print(function_eval)" 352 | ] 353 | } 354 | ], 355 | "metadata": { 356 | "kernelspec": { 357 | "display_name": "Python 3 (ipykernel)", 358 | "language": "python", 359 | "name": "python3" 360 | }, 361 | "language_info": { 362 | "codemirror_mode": { 363 | "name": "ipython", 364 | "version": 3 365 | }, 366 | "file_extension": ".py", 367 | "mimetype": "text/x-python", 368 | "name": "python", 369 | "nbconvert_exporter": "python", 370 | "pygments_lexer": "ipython3", 371 | "version": "3.8.18" 372 | } 373 | }, 374 | "nbformat": 4, 375 | "nbformat_minor": 4 376 | } 377 | -------------------------------------------------------------------------------- /evaporate/profiler_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import argparse 4 | import random 5 | from bs4 import BeautifulSoup 6 | from collections import Counter, defaultdict 7 | 8 | 9 | def set_profiler_args(profiler_args): 10 | 11 | parser = argparse.ArgumentParser( 12 | "LLM profiler.", 13 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 14 | ) 15 | 16 | parser.add_argument( 17 | "--chunk_size", 18 | type=int, 19 | default=5000 20 | ) 21 | 22 | parser.add_argument( 23 | "--train_size", 24 | type=int, 25 | default=15, 26 | ) 27 | 28 | parser.add_argument( 29 | "--eval_size", 30 | type=int, 31 | default=15, 32 | ) 33 | 34 | parser.add_argument( 35 | "--max_chunks_per_file", 36 | type=int, 37 | default=-1, 38 | ) 39 | 40 | parser.add_argument( 41 | "--num_top_k_scripts", 42 | type=int, 43 | default=1, 44 | help="of all the scripts we generate for the metadata fields, number to retain after scoring their qualities", 45 | ) 46 | 47 | parser.add_argument( 48 | "--extraction_fraction_thresh", 49 | type=int, 50 | default=0.9, 51 | help="for abstensions approach", 52 | ) 53 | 54 | parser.add_argument( 55 | "--remove_tables", 56 | type=bool, 57 | default=False, 58 | help="Remove tables from the html files?", 59 | ) 60 | 61 | parser.add_argument( 62 | "--body_only", 63 | type=bool, 64 | default=False, 65 | help="Only use HTML body", 66 | ) 67 | 68 | parser.add_argument( 69 | "--max_metadata_fields", 70 | type=int, 71 | default=15, 72 | ) 73 | 74 | parser.add_argument( 75 | "--use_dynamic_backoff", 76 | type=bool, 77 | default=True, 78 | help="Whether to do the function generation workflow or directly extract from chunks", 79 | ) 80 | 81 | parser.add_argument( 82 | "--use_qa_model", 83 | type=bool, 84 | default=False, 85 | help="Whether to apply the span-extractor QA model.", 86 | ) 87 | 88 | parser.add_argument( 89 | "--overwrite_cache", 90 | type=int, 91 | default=0, 92 | help="overwrite the manifest cache" 93 | ) 94 | 95 | # models to use in the pipeline 96 | parser.add_argument( 97 | "--MODELS", 98 | type=list, 99 | help="models to use in the pipeline" 100 | ) 101 | 102 | parser.add_argument( 103 | "--KEYS", 104 | type=list, 105 | help="keys for openai models" 106 | ) 107 | 108 | parser.add_argument( 109 | "--GOLDKEY", 110 | type=str, 111 | help="models to use in the pipeline" 112 | ) 113 | 114 | parser.add_argument( 115 | "--MODEL2URL", 116 | type=dict, 117 | default={}, 118 | help="models to use in the pipeline" 119 | ) 120 | 121 | parser.add_argument( 122 | "--swde_plus", 123 | type=bool, 124 | default=False, 125 
| help="Whether to use the extended SWDE dataset to measure OpenIE performance", 126 | ) 127 | 128 | parser.add_argument( 129 | "--schema_id_sizes", 130 | type=int, 131 | default=0, 132 | help="Number of documents to use for schema identification stage, if it differs from extraction", 133 | ) 134 | 135 | parser.add_argument( 136 | "--slice_results", 137 | type=bool, 138 | default=False, 139 | help="Whether to measure the results by attribute-slice", 140 | ) 141 | 142 | parser.add_argument( 143 | "--fn_generation_prompt_num", 144 | type=int, 145 | default=-1, 146 | help="For ablations on function generation with diversity, control which prompt we use. Default is all.", 147 | ) 148 | 149 | parser.add_argument( 150 | "--upper_bound_fns", 151 | type=bool, 152 | default=False, 153 | help="For ablations that select functions using ground truth instead of the FM.", 154 | ) 155 | 156 | parser.add_argument( 157 | "--combiner_mode", 158 | type=str, 159 | default='mv', 160 | help="For ablations that select functions using ground truth instead of the FM.", 161 | ) 162 | 163 | parser.add_argument( 164 | "--use_alg_filtering", 165 | type=str, 166 | default=True, 167 | help="Whether to filter functions based on quality.", 168 | ) 169 | 170 | parser.add_argument( 171 | "--use_abstension", 172 | type=str, 173 | default=True, 174 | help="Whether to use the abstensions approach.", 175 | ) 176 | 177 | args = parser.parse_args(args=[]) 178 | for arg, val in profiler_args.items(): 179 | setattr(args, arg, val) 180 | return args 181 | 182 | 183 | #################### GET SOME SAMPLE FILES TO SEED THE METADATA SEARCH ######################### 184 | 185 | def sample_scripts(files, train_size=5): 186 | # "Train" split 187 | random.seed(0) 188 | if train_size <= len(files): 189 | sample_files = random.sample(files, train_size) 190 | else: 191 | sample_files = files 192 | sample_contents = [] 193 | for sample_file in sample_files: 194 | with open(sample_file, 'r') as f: 195 | sample_contents.append(f.read()) 196 | return sample_files 197 | 198 | 199 | #################### BOILERPLATE CHUNKING CODE, CRITICAL FOR LONG SEUQENCES #################### 200 | def chunk_file( 201 | parser, file, chunk_size=5000, mode="train", remove_tables=False, body_only=False 202 | ): 203 | content = get_file_contents(file) 204 | if "html" in parser: 205 | content, chunks = get_html_parse( 206 | content, 207 | chunk_size=chunk_size, 208 | mode=mode, 209 | remove_tables=remove_tables, 210 | body_only=body_only 211 | ) 212 | else: 213 | content, chunks = get_txt_parse(content, chunk_size=chunk_size, mode=mode) 214 | return content, chunks 215 | 216 | 217 | # HTML --> CHUNKS 218 | def clean_html(content): 219 | for tag in ['script', 'style', 'svg']: 220 | content = content.split("\n") 221 | clean_content = [] 222 | in_script = 0 223 | for c in content: 224 | if c.strip().strip("\t").startswith(f"<{tag}"): 225 | in_script = 1 226 | endstr = "" 227 | if endstr in c or "/>" in c: 228 | in_script = 0 229 | if not in_script: 230 | clean_content.append(c) 231 | content = "\n".join(clean_content) 232 | return content 233 | 234 | 235 | def get_flattened_items(content, chunk_size=500): 236 | flattened_divs = str(content).split("\n") 237 | flattened_divs = [ch for ch in flattened_divs if ch.strip() and ch.strip("\n").strip()] 238 | 239 | clean_flattened_divs = [] 240 | for div in flattened_divs: 241 | if len(str(div)) > chunk_size: 242 | sub_divs = div.split("><") 243 | if len(sub_divs) == 1: 244 | clean_flattened_divs.append(div) 245 | else: 246 | 
clean_flattened_divs.append(sub_divs[0] + ">") 247 | for sd in sub_divs[1:-1]: 248 | clean_flattened_divs.append("<" + sd + ">") 249 | clean_flattened_divs.append("<" + sub_divs[-1]) 250 | else: 251 | clean_flattened_divs.append(div) 252 | return clean_flattened_divs 253 | 254 | 255 | def get_html_parse(content, chunk_size=5000, mode="train", remove_tables=False, body_only=False): 256 | if remove_tables: 257 | soup = BeautifulSoup(content) 258 | tables = soup.find_all("table") 259 | for table in tables: 260 | if "infobox" not in str(table): 261 | content = str(soup) 262 | content = content.replace(str(table), "") 263 | soup = BeautifulSoup(content) 264 | 265 | if body_only: 266 | soup = BeautifulSoup(content) 267 | content = str(soup.find("body")) 268 | soup = BeautifulSoup(content) 269 | 270 | else: 271 | content = clean_html(content) 272 | clean_flattened_divs = [] 273 | flattened_divs = get_flattened_items(content, chunk_size=chunk_size) 274 | for i, div in enumerate(flattened_divs): 275 | new_div = re.sub(r'style="[^"]*"', '', str(div)) 276 | new_div = re.sub(r'', '', str(new_div)) 277 | new_div = re.sub(r'', '', str(new_div)) 278 | new_div = re.sub(r'', '', str(new_div)) 279 | new_div = "\n".join([l for l in new_div.split("\n") if l.strip() and l.strip("\n").strip()]) 280 | # new_div = BeautifulSoup(new_div) #.fsind_all("div")[0] 281 | if new_div: 282 | clean_flattened_divs.append(new_div) 283 | 284 | if mode == "eval": 285 | return [] 286 | 287 | grouped_divs = [] 288 | current_div = [] 289 | current_length = 0 290 | max_length = chunk_size 291 | join_str = " " if use_raw_text else "\n" 292 | for div in clean_flattened_divs: 293 | str_div = str(div) 294 | len_div = len(str_div) 295 | if (current_length + len_div > max_length): 296 | grouped_divs.append(join_str.join(current_div)) 297 | current_div = [] 298 | current_length = 0 299 | elif not current_div and (current_length + len_div > max_length): 300 | grouped_divs.append(str_div) 301 | continue 302 | current_div.append(str_div) 303 | current_length += len_div 304 | 305 | return content, grouped_divs 306 | 307 | 308 | # GENERIC TXT --> CHUNKS 309 | def get_txt_parse(content, chunk_size=5000, mode="train"): 310 | # convert to chunks 311 | if mode == "train": 312 | chunks = content.split("\n") 313 | clean_chunks = [] 314 | for chunk in chunks: 315 | if len(chunk) > chunk_size: 316 | sub_chunks = chunk.split(". 
") 317 | clean_chunks.extend(sub_chunks) 318 | else: 319 | clean_chunks.append(chunk) 320 | 321 | chunks = clean_chunks.copy() 322 | clean_chunks = [] 323 | for chunk in chunks: 324 | if len(chunk) > chunk_size: 325 | sub_chunks = chunk.split(", ") 326 | clean_chunks.extend(sub_chunks) 327 | else: 328 | clean_chunks.append(chunk) 329 | 330 | final_chunks = [] 331 | cur_chunk = [] 332 | cur_chunk_size = 0 333 | for chunk in clean_chunks: 334 | if cur_chunk_size + len(chunk) > chunk_size: 335 | final_chunks.append("\n".join(cur_chunk)) 336 | cur_chunk = [] 337 | cur_chunk_size = 0 338 | cur_chunk.append(chunk) 339 | cur_chunk_size += len(chunk) 340 | if cur_chunk: 341 | final_chunks.append("\n".join(cur_chunk)) 342 | else: 343 | final_chunks = [] 344 | return content, final_chunks 345 | 346 | 347 | def get_file_contents(file): 348 | text = '' 349 | if file.endswith(".swp"): 350 | return text 351 | try: 352 | with open(file) as f: 353 | text = f.read() 354 | except: 355 | with open(file, "rb") as f: 356 | text = f.read().decode("utf-8", "ignore") 357 | return text 358 | 359 | 360 | def clean_metadata(field): 361 | return field.replace("\t", " ").replace("\n", " ").strip().lower() 362 | 363 | 364 | def filter_file2chunks(file2chunks, sample_files, attribute): 365 | def get_attribute_parts(attribute): 366 | for char in ["/", "-", "(", ")", "[", "]", "{", "}", ":"]: 367 | attribute = attribute.replace(char, " ") 368 | attribute_parts = attribute.lower().split() 369 | return attribute_parts 370 | 371 | # filter chunks with simple keyword search 372 | attribute_chunks = defaultdict(list) 373 | starting_num_chunks = 0 374 | ending_num_chunks = 0 375 | ending_in_sample_chunks = 0 376 | starting_in_sample_chunks = 0 377 | for file, chunks in file2chunks.items(): 378 | starting_num_chunks += len(chunks) 379 | if file in sample_files: 380 | starting_in_sample_chunks += len(chunks) 381 | cleaned_chunks = [] 382 | for chunk in chunks: 383 | if attribute.lower() in chunk.lower(): 384 | cleaned_chunks.append(chunk) 385 | if len(cleaned_chunks) == 0: 386 | for chunk in chunks: 387 | if attribute.lower().replace(" ", "") in chunk.lower().replace(" ", ""): 388 | cleaned_chunks.append(chunk) 389 | if len(cleaned_chunks) == 0: 390 | chunk2num_word_match = Counter() 391 | for chunk_num, chunk in enumerate(chunks): 392 | attribute_parts = get_attribute_parts(attribute.lower()) 393 | for wd in attribute_parts: 394 | if wd.lower() in chunk.lower(): 395 | chunk2num_word_match[chunk_num] += 1 396 | # sort chunks by number of words that match 397 | sorted_chunks = sorted(chunk2num_word_match.items(), key=lambda x: x[1], reverse=True) 398 | if len(sorted_chunks) > 0: 399 | cleaned_chunks.append(chunks[sorted_chunks[0][0]]) 400 | if len(sorted_chunks) > 1: 401 | cleaned_chunks.append(chunks[sorted_chunks[1][0]]) 402 | ending_num_chunks += len(cleaned_chunks) 403 | num_chunks = len(cleaned_chunks) 404 | num_chunks = min(num_chunks, 2) 405 | 406 | cleaned_chunks = cleaned_chunks[:num_chunks] 407 | attribute_chunks[file] = cleaned_chunks 408 | if file in sample_files: 409 | ending_in_sample_chunks += len(attribute_chunks[file]) 410 | file2chunks = attribute_chunks 411 | if ending_num_chunks == 0 or ending_in_sample_chunks == 0: 412 | print(f"Removing because no chunks for attribute {attribute} in any file") 413 | return None 414 | print(f"For attribute {attribute}\n-- Starting with {starting_num_chunks} chunks\n-- Ending with {ending_num_chunks} chunks") 415 | print(f"-- {starting_in_sample_chunks} starting chunks in 
sample files\n-- {ending_in_sample_chunks} chunks in sample files") 416 | 417 | return file2chunks 418 | 419 | 420 | def clean_function_predictions(extraction, attribute=None): 421 | if extraction is None: 422 | return '' 423 | if type(extraction) == list: 424 | if extraction and type(extraction[0]) == list: 425 | full_answer = [] 426 | for answer in extraction: 427 | if type(answer) == list: 428 | dedup_list = [] 429 | for a in answer: 430 | if a not in dedup_list: 431 | dedup_list.append(a) 432 | answer = dedup_list 433 | answer = [str(a).strip().strip("\n") for a in answer] 434 | full_answer.append(", ".join(answer)) 435 | else: 436 | full_answer.append(answer.strip().strip("\n")) 437 | full_answer = [a.strip() for a in full_answer] 438 | extraction = ", ".join(full_answer) 439 | elif extraction and len(extraction) == 1 and extraction[0] is None: 440 | extraction = '' 441 | else: 442 | dedup_list = [] 443 | for a in extraction: 444 | if a not in dedup_list: 445 | dedup_list.append(a) 446 | extraction = dedup_list 447 | extraction = [(str(e)).strip().strip("\n") for e in extraction] 448 | extraction = ", ".join(extraction) 449 | if type(extraction) == "str" and extraction.lower() == "none": 450 | extraction = "" 451 | extraction = extraction.strip().replace(" ", " ") 452 | if extraction.lower().startswith(attribute.lower()): 453 | idx = extraction.lower().find(attribute.lower()) 454 | extraction = extraction[idx+len(attribute):].strip() 455 | for char in [':', ","]: 456 | extraction = extraction.strip(char).strip() 457 | extraction = extraction.replace(",", ", ").replace(" ", " ") 458 | return extraction 459 | 460 | 461 | def check_vs_train_extractions(train_extractions, final_extractions, gold_key, attribute = None): 462 | clean_final_extractions = {} 463 | 464 | gold_values = train_extractions[gold_key] 465 | modes = [] 466 | start_toks = [] 467 | end_toks = [] 468 | for file, gold in gold_values.items(): 469 | if type(gold) == dict: 470 | gold = gold[attribute] 471 | if type(gold) == list: 472 | if gold and type(gold[0]) == list: 473 | gold = [g[0] for g in gold] 474 | gold = ", ".join(gold) 475 | else: 476 | gold = ", ".join(gold) 477 | gold = gold.lower() 478 | pred = final_extractions[file].lower() 479 | if not pred or not gold: 480 | continue 481 | if ("<" in pred and "<" not in gold) or (">" in pred and ">" not in gold): 482 | check_pred = BeautifulSoup(pred).text 483 | if check_pred in gold or gold in check_pred: 484 | modes.append("soup") 485 | elif gold in pred and len(pred) > len(gold): 486 | modes.append("longer") 487 | idx = pred.index(gold) 488 | if idx > 0: 489 | start_toks.append(pred[:idx-1]) 490 | end_idx = idx + len(gold) 491 | if end_idx < len(pred): 492 | end_toks.append(pred[end_idx:]) 493 | 494 | def long_substr(data): 495 | substr = '' 496 | if len(data) > 1 and len(data[0]) > 0: 497 | for i in range(len(data[0])): 498 | for j in range(len(data[0])-i+1): 499 | if j > len(substr) and is_substr(data[0][i:i+j], data): 500 | substr = data[0][i:i+j] 501 | return substr 502 | 503 | def is_substr(find, data): 504 | if len(data) < 1 and len(find) < 1: 505 | return False 506 | for i in range(len(data)): 507 | if find not in data[i]: 508 | return False 509 | return True 510 | 511 | longest_end_tok = long_substr(end_toks) 512 | longest_start_tok = long_substr(start_toks) 513 | if len(set(modes)) == 1: 514 | num_golds = len(gold_values) 515 | for file, extraction in final_extractions.items(): 516 | if "longer" in modes: 517 | # gold longer than pred 518 | if len(end_toks) == 
num_golds and longest_end_tok in extraction and extraction.count(longest_end_tok) == 1: 519 | idx = extraction.index(longest_end_tok) 520 | extraction = extraction[:idx] 521 | if len(start_toks) == num_golds and longest_start_tok in extraction and extraction.count(longest_start_tok) == 1: 522 | idx = extraction.index(longest_start_tok) 523 | extraction = extraction[idx:] 524 | elif "soup" in modes: 525 | extraction = BeautifulSoup(extraction).text 526 | clean_final_extractions[file] = extraction 527 | else: 528 | clean_final_extractions = final_extractions 529 | return clean_final_extractions 530 | -------------------------------------------------------------------------------- /evaporate/run_profiler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import json 5 | import datetime 6 | from tqdm import tqdm 7 | import pickle 8 | import argparse 9 | from collections import defaultdict, Counter 10 | 11 | from evaporate.utils import get_structure, get_manifest_sessions, get_file_attribute 12 | from evaporate.profiler_utils import chunk_file, sample_scripts 13 | from evaporate.schema_identification import identify_schema 14 | from evaporate.profiler import run_profiler, get_file_attribute 15 | from evaporate.evaluate_synthetic import main as evaluate_synthetic_main 16 | from evaporate.configs import get_experiment_args 17 | 18 | 19 | random.seed(0) 20 | def get_data_lake_info(args, data_lake): 21 | extractions_file = None 22 | 23 | if 1: 24 | DATA_DIR = args.data_dir 25 | file_groups = os.listdir(args.data_dir) 26 | if not DATA_DIR.endswith("/"): 27 | DATA_DIR += "/" 28 | file_groups = [f"{DATA_DIR}{file_group}" for file_group in file_groups if not file_group.startswith(".")] 29 | full_file_groups = file_groups.copy() 30 | extractions_file = args.gold_extractions_file 31 | parser = "txt" 32 | 33 | return file_groups, extractions_file, parser, full_file_groups 34 | 35 | 36 | def chunk_files(file_group, parser, chunk_size, remove_tables, max_chunks_per_file, body_only): 37 | file2chunks = {} 38 | file2contents = {} 39 | for file in tqdm(file_group, total=len(file_group), desc="Chunking files"): 40 | content, chunks = chunk_file( 41 | parser, 42 | file, 43 | chunk_size=chunk_size, 44 | mode="train", 45 | remove_tables=remove_tables, 46 | body_only=body_only 47 | ) 48 | if max_chunks_per_file > 0: 49 | chunks = chunks[:max_chunks_per_file] 50 | file2chunks[file] = chunks 51 | file2contents[file] = content 52 | return file2chunks, file2contents 53 | 54 | 55 | # chunking & preparing data 56 | def prepare_data(profiler_args, file_group, data_args, parser = "html"): 57 | data_lake = profiler_args.data_lake 58 | if profiler_args.body_only: 59 | body_only = profiler_args.body_only 60 | suffix = f"_bodyOnly{body_only}" 61 | else: 62 | suffix = "" 63 | # prepare the datalake: chunk all files 64 | manifest_sessions = get_manifest_sessions(profiler_args.MODELS, MODEL2URL=profiler_args.MODEL2URL, KEYS=profiler_args.KEYS) 65 | if os.path.exists(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_{suffix}_file2chunks.pkl"): 66 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_{suffix}_file2chunks.pkl", "rb") as f: 67 | file2chunks = pickle.load(f) 68 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_{suffix}_file2contents.pkl", "rb") as f: 69 | file2contents = 
pickle.load(f) 70 | else: 71 | file2chunks, file2contents = chunk_files( 72 | file_group, 73 | parser, 74 | profiler_args.chunk_size, 75 | profiler_args.remove_tables, 76 | profiler_args.max_chunks_per_file, 77 | profiler_args.body_only 78 | ) 79 | if not os.path.exists(data_args.cache_dir): 80 | os.mkdir(data_args.cache_dir) 81 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2chunks.pkl", "wb") as f: 82 | pickle.dump(file2chunks, f) 83 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2contents.pkl", "wb") as f: 84 | pickle.dump(file2contents, f) 85 | return file2chunks, file2contents, manifest_sessions 86 | 87 | 88 | def get_run_string( 89 | data_lake, today, file_groups, profiler_args, do_end_to_end, 90 | train_size, dynamicbackoff, models 91 | ): 92 | body = profiler_args.body_only # Baseline systems only operate on the HTML body 93 | model_ct = len(models) 94 | if profiler_args.use_qa_model: 95 | model_ct += 1 96 | run_string = f"dataLake{data_lake}_date{today}_fileSize{len(file_groups)}_trainSize{train_size}_numAggregate{profiler_args.num_top_k_scripts}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}_body{body}_cascading{do_end_to_end}_useBackoff{dynamicbackoff}_MODELS{model_ct}" 97 | return run_string 98 | 99 | 100 | def get_gold_metadata(args): 101 | # get the list of gold metadata for closed-IE runs 102 | try: 103 | with open(args.gold_extractions_file) as f: 104 | gold_file2extractions = json.load(f) 105 | except: 106 | with open(args.gold_extractions_file, "rb") as f: 107 | gold_file2extractions = pickle.load(f) 108 | frequency = Counter() 109 | for file, dic in gold_file2extractions.items(): 110 | for k, v in dic.items(): 111 | if k != "topic_entity_name": 112 | if type(v) == str and v: 113 | frequency[k] += 1 114 | elif type(v) == list and v and v[0]: 115 | frequency[k] += 1 116 | sorted_frequency = sorted(frequency.items(), key=lambda x: x[1], reverse=True) 117 | gold_metadata = [x[0] for x in sorted_frequency] 118 | gold_attributes = [m.lower() for m in gold_metadata if m not in ['topic_entity_name']] 119 | return gold_attributes 120 | 121 | 122 | def determine_attributes_to_remove(attributes, args, run_string, num_attr_to_cascade): 123 | attributes_reordered = {} 124 | attributes_to_remove = [] 125 | attributes_to_metrics = {} 126 | attribute_to_first_extractions = {} 127 | mappings_names = {} 128 | for num, attribute in enumerate(attributes): 129 | attribute = attribute.lower() 130 | file_attribute = get_file_attribute(attribute) 131 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json"): 132 | continue 133 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json"): 134 | continue 135 | if num >= num_attr_to_cascade: 136 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json") 137 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") 138 | continue 139 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json") as f: 140 | metrics = json.load(f) 141 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_top_k_keys.json") as f: 142 | selected_keys = json.load(f) 143 | with 
open(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") as f: 144 | file2metadata = json.load(f) 145 | attributes_reordered[attribute] = metrics[selected_keys[0]] 146 | 147 | if selected_keys and metrics: 148 | for a, m in attributes_to_metrics.items(): 149 | if attribute.lower() in a.lower() or a.lower() in attribute.lower(): 150 | if m == metrics[selected_keys[0]]['average_f1']: 151 | attributes_to_remove.append(attribute) 152 | mappings_names[a] = attribute 153 | mappings_names[attribute] = a 154 | break 155 | 156 | first_extractions = [m for i, (f, m) in enumerate(file2metadata.items()) if i < 5] 157 | if any(f != "" for f in first_extractions): 158 | first_extractions = " ".join(first_extractions) 159 | for a, m in attribute_to_first_extractions.items(): 160 | if m == first_extractions: 161 | attributes_to_remove.append(attribute) 162 | mappings_names[a] = attribute 163 | mappings_names[attribute] = a 164 | break 165 | 166 | if attribute in attributes_to_remove: 167 | continue 168 | if selected_keys: 169 | attributes_to_metrics[attribute] = metrics[selected_keys[0]]['average_f1'] 170 | attribute_to_first_extractions[attribute] = first_extractions 171 | return attributes_to_remove, mappings_names, attributes 172 | 173 | 174 | def measure_openie_results( 175 | attributes, 176 | args, 177 | profiler_args, 178 | run_string, 179 | gold_attributes, 180 | attributes_to_remove, 181 | file_groups, 182 | mappings_names 183 | ): 184 | file2extractions = defaultdict(dict) 185 | unique_attributes = set() 186 | num_extractions2results = {} 187 | data_lake = profiler_args.data_lake 188 | for attr_num, attribute in enumerate(attributes): 189 | attribute = attribute.lower() 190 | file_attribute = get_file_attribute(attribute) 191 | if os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json"): 192 | if attribute in attributes_to_remove: 193 | print(f"Removing: {attribute}") 194 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") 195 | continue 196 | 197 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") as f: 198 | file2metadata = json.load(f) 199 | for file, extraction in file2metadata.items(): 200 | file2extractions[file][attribute] = extraction 201 | unique_attributes.add(attribute) 202 | 203 | if file2extractions: 204 | num_extractions = len(unique_attributes) 205 | nums = [1, 2, 3, 4, len(attributes) - 1, len(gold_attributes)] 206 | if file2extractions and ((num_extractions) % 5 == 0 or num_extractions in nums) or attr_num == len(attributes) - 1: 207 | if num_extractions in num_extractions2results: 208 | continue 209 | with open(f"{args.generative_index_path}/{run_string}_file2extractions.json", "w") as f: 210 | json.dump(file2extractions, f) 211 | 212 | results = evaluate_synthetic_main( 213 | run_string, 214 | args, 215 | profiler_args, 216 | data_lake, 217 | sample_files=file_groups, 218 | stage='openie', 219 | mappings_names=mappings_names 220 | ) 221 | num_extractions2results[num_extractions] = results 222 | return num_extractions2results 223 | 224 | def prerun_profiler(profiler_args): 225 | file2chunks, file2contents, manifest_sessions = prepare_data( 226 | profiler_args, profiler_args.full_file_groups, profiler_args, profiler_args.parser 227 | ) 228 | manifest_sessions = { 229 | k: v for k, v in manifest_sessions.items() if k in profiler_args.MODELS 230 | } 231 | gold_attributes = get_gold_metadata(profiler_args) 232 | try: 233 | with 
open(profiler_args.gold_extractions_file) as f: 234 | gold_extractions_tmp = json.load(f) 235 | except: 236 | with open(profiler_args.gold_extractions_file, "rb") as f: 237 | gold_extractions_tmp = pickle.load(f) 238 | gold_extractions = {} 239 | for file, extractions in gold_extractions_tmp.items(): 240 | gold_extractions[os.path.join(profiler_args.data_dir, file.split('/')[-1])] = extractions 241 | manifest_sessions[profiler_args.GOLD_KEY] = {} 242 | manifest_sessions[profiler_args.GOLD_KEY]['__name'] = 'gold_extraction_file' 243 | for attribute in gold_attributes: 244 | manifest_sessions[profiler_args.GOLD_KEY][attribute] = {} 245 | for file in profiler_args.file_groups: 246 | manifest_sessions[profiler_args.GOLD_KEY][attribute][file] = gold_extractions[file][attribute] 247 | 248 | sample_files = sample_scripts( 249 | profiler_args.file_groups, 250 | train_size=profiler_args.train_size, 251 | ) 252 | data_dict = { 253 | "file2chunks": file2chunks, 254 | "file2contents": file2contents, 255 | "manifest_sessions": manifest_sessions, 256 | "gold_attributes": gold_attributes, 257 | "sample_files" : sample_files, 258 | "gold_extractions": gold_extractions 259 | } 260 | return data_dict 261 | 262 | def identify_attributes(profiler_args, data_dict, evaluation = False): 263 | file2chunks = data_dict["file2chunks"] 264 | file2contents = data_dict["file2contents"] 265 | manifest_sessions = data_dict["manifest_sessions"] 266 | sample_files = sample_scripts( 267 | profiler_args.file_groups, 268 | train_size=profiler_args.train_size, 269 | ) 270 | t0 = time.time() 271 | num_toks = identify_schema( 272 | profiler_args.run_string, 273 | profiler_args, 274 | file2chunks, 275 | file2contents, 276 | sample_files, 277 | manifest_sessions, 278 | profiler_args.data_lake, 279 | profiler_args 280 | ) 281 | t1 = time.time() 282 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_identified_schema.json") as f: 283 | most_common_fields = json.load(f) 284 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_order_of_addition.json") as f: 285 | order_of_addition = json.load(f) 286 | order = {item: (len(order_of_addition) - i) for i, item in enumerate(order_of_addition)} 287 | ctr = Counter(most_common_fields) 288 | pred_metadata = sorted( 289 | ctr.most_common(profiler_args.num_attr_to_cascade), 290 | key=lambda x: (x[1], order[x[0]]), 291 | reverse=True 292 | ) 293 | attributes = [item[0].lower() for item in pred_metadata] 294 | if evaluation : 295 | evaluation_result = evaluate_synthetic_main( 296 | profiler_args.run_string, 297 | profiler_args, 298 | profiler_args, 299 | profiler_args.data_lake, 300 | stage='schema_id' 301 | ) 302 | else: 303 | evaluation_result = None 304 | return attributes, t1-t0, num_toks, evaluation_result 305 | 306 | def get_attribute_function(profiler_args, data_dict, attribute): 307 | sample_files = sample_scripts( 308 | profiler_args.file_groups, 309 | train_size=profiler_args.train_size, 310 | ) 311 | t0 = time.time() 312 | num_toks, success = run_profiler( 313 | profiler_args.run_string, 314 | profiler_args, 315 | data_dict["file2chunks"], 316 | data_dict["file2contents"], 317 | sample_files, 318 | profiler_args.full_file_groups, 319 | data_dict["manifest_sessions"], 320 | attribute, 321 | profiler_args 322 | ) 323 | t1 = time.time() 324 | try: 325 | file_attribute = get_file_attribute(attribute) 326 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_{file_attribute}_functions.json") as f: 327 | 
function_dictionary = json.load(f) 328 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_{file_attribute}_top_k_keys.json") as f: 329 | selected_keys = json.load(f) 330 | except: 331 | selected_keys = None 332 | function_dictionary = None 333 | return function_dictionary, selected_keys, t1-t0, num_toks 334 | def run_experiment(profiler_args): 335 | do_end_to_end = profiler_args.do_end_to_end 336 | num_attr_to_cascade = profiler_args.num_attr_to_cascade 337 | train_size = profiler_args.train_size 338 | data_lake = profiler_args.data_lake 339 | 340 | print(f"Data lake") 341 | today = datetime.datetime.today().strftime("%m%d%Y") 342 | 343 | _, _, _, _, args = get_structure(data_lake, profiler_args) 344 | file_groups, extractions_file, parser, full_file_groups = get_data_lake_info(args, data_lake) 345 | file2chunks, file2contents, manifest_sessions = prepare_data( 346 | profiler_args, full_file_groups, args, parser 347 | ) 348 | manifest_sessions = { 349 | k: v for k, v in manifest_sessions.items() if k in profiler_args.MODELS 350 | } 351 | gold_attributes = get_gold_metadata(args) 352 | 353 | results_by_train_size = defaultdict(dict) 354 | total_time_dict = defaultdict(dict) 355 | 356 | if 1: 357 | total_tokens_prompted = 0 358 | 359 | print(f"\n\nData-lake: {data_lake}, Train size: {train_size}") 360 | setattr(profiler_args, 'train_size', train_size) 361 | 362 | run_string = get_run_string( 363 | data_lake, today, full_file_groups, profiler_args, 364 | do_end_to_end, train_size, 365 | profiler_args.use_dynamic_backoff, 366 | profiler_args.EXTRACTION_MODELS, 367 | ) 368 | 369 | sample_files = sample_scripts( 370 | file_groups, 371 | train_size=profiler_args.train_size, 372 | ) 373 | 374 | # top-level schema identification 375 | if do_end_to_end: 376 | t0 = time.time() 377 | num_toks = identify_schema( 378 | run_string, 379 | args, 380 | file2chunks, 381 | file2contents, 382 | sample_files, 383 | manifest_sessions, 384 | data_lake, 385 | profiler_args 386 | ) 387 | t1 = time.time() 388 | total_time = t1-t0 389 | total_tokens_prompted += num_toks 390 | total_time_dict[f'schemaId'][f'totalTime_trainSize{train_size}'] = int(total_time) 391 | 392 | results = evaluate_synthetic_main( 393 | run_string, 394 | args, 395 | profiler_args, 396 | data_lake, 397 | stage='schema_id' 398 | ) 399 | results_by_train_size[train_size]['schema_id'] = results 400 | 401 | if 1: 402 | if do_end_to_end: 403 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json") as f: 404 | most_common_fields = json.load(f) 405 | with open(f"{args.generative_index_path}/{run_string}_order_of_addition.json") as f: 406 | order_of_addition = json.load(f) 407 | order = {item: (len(order_of_addition) - i) for i, item in enumerate(order_of_addition)} 408 | ctr = Counter(most_common_fields) 409 | pred_metadata = sorted( 410 | ctr.most_common(num_attr_to_cascade), 411 | key=lambda x: (x[1], order[x[0]]), 412 | reverse=True 413 | ) 414 | attributes = [item[0].lower() for item in pred_metadata] 415 | else: 416 | attributes = gold_attributes 417 | 418 | # top-level information extraction 419 | num_collected = 0 420 | for i, attribute in enumerate(attributes): 421 | print(f"\n\nExtracting {attribute} ({i+1} / {len(attributes)})") 422 | t0 = time.time() 423 | num_toks, success = run_profiler( 424 | run_string, 425 | args, 426 | file2chunks, 427 | file2contents, 428 | sample_files, 429 | full_file_groups, 430 | manifest_sessions, 431 | attribute, 432 | profiler_args 433 | ) 434 | t1 = 
time.time() 435 | total_time = t1-t0 436 | total_tokens_prompted += num_toks 437 | total_time_dict[f'extract'][f'totalTime_trainSize{train_size}'] = int(total_time) 438 | if success: 439 | num_collected += 1 440 | if num_collected >= num_attr_to_cascade: 441 | break 442 | 443 | # run closed ie eval 444 | results = evaluate_synthetic_main( 445 | run_string, 446 | args, 447 | profiler_args, 448 | data_lake, 449 | gold_attributes=gold_attributes, 450 | stage='extract' 451 | ) 452 | results_by_train_size[train_size]['extract'] = results 453 | 454 | # Determine whether to remove any attributes based on the extractions 455 | # Potentially can rerank the attributes based on the metric comparison to big model 456 | if do_end_to_end: 457 | attributes_to_remove, mappings_names, attributes = determine_attributes_to_remove( 458 | attributes, 459 | args, 460 | run_string, 461 | num_attr_to_cascade, 462 | ) 463 | numextractions2results = measure_openie_results( 464 | attributes, 465 | args, 466 | profiler_args, 467 | run_string, 468 | gold_attributes, 469 | attributes_to_remove, 470 | full_file_groups, 471 | mappings_names 472 | ) 473 | if 'openie' not in results_by_train_size[train_size]: 474 | results_by_train_size[train_size]['openie'] = {} 475 | results_by_train_size[train_size]['openie'] = numextractions2results 476 | 477 | results_by_train_size[train_size]['total_tokens_prompted'] = total_tokens_prompted 478 | results_by_train_size[train_size]['num_total_files'] = len(full_file_groups) 479 | results_by_train_size[train_size]['num_sample_files'] = len(sample_files) 480 | result_path_dir = os.path.join(profiler_args.base_data_dir, "results_dumps") 481 | if not os.path.exists(result_path_dir): 482 | os.mkdir(result_path_dir) 483 | print(run_string) 484 | with open(f"{result_path_dir}/{run_string}_results_by_train_size.pkl", "wb") as f: 485 | pickle.dump(results_by_train_size, f) 486 | print(f"Saved!") 487 | 488 | print(f"Total tokens prompted: {total_tokens_prompted}") 489 | 490 | 491 | 492 | 493 | 494 | def main(): 495 | #todo: make into two types of args: running_args and data_args, set by the user 496 | profiler_args = get_experiment_args() 497 | 498 | # model_dict = { 499 | # 'MODELS': ["text-davinci-003"], 500 | # 'EXTRACTION_MODELS': ["text-davinci-003"], 501 | # 'GOLD_KEY': "text-davinci-003", 502 | # } 503 | # Example of how to use a locally-hosted FM 504 | # model_dict = { 505 | # 'MODELS': [" EleutherAI/gpt-j-6B"], 506 | # 'EXTRACTION_MODELS': [" EleutherAI/gpt-j-6B"], 507 | # 'GOLD_KEY': " EleutherAI/gpt-j-6B", 508 | # 'MODEL2URL': { 509 | # " EleutherAI/gpt-j-6B": "http://127.0.0.1:5000" 510 | # }, 511 | # } 512 | 513 | run_experiment(profiler_args) 514 | 515 | 516 | if __name__ == "__main__": 517 | main() 518 | -------------------------------------------------------------------------------- /evaporate/evaluate_synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import pickle 5 | 6 | import html 7 | from bs4 import BeautifulSoup 8 | from collections import Counter, defaultdict 9 | from evaporate.utils import get_file_attribute 10 | from evaporate.evaluate_synthetic_utils import text_f1 11 | 12 | 13 | # Compute recall from two sets 14 | def set_recall(pred, gt): 15 | return len(set(pred) & set(gt)) / len(set(gt)) 16 | 17 | 18 | # Compute precision from two sets 19 | def set_precision(pred, gt): 20 | return len(set(pred) & set(gt)) / len(set(pred)) 21 | 22 | 23 | # Compute F1 from precision and 
recall 24 | def compute_f1(precision, recall): 25 | if recall > 0. or precision > 0.: 26 | return 2. * (precision * recall) / (precision + recall) 27 | else: 28 | return 0. 29 | 30 | 31 | def evaluate_schema_identification(run_string, args, group_name, train_size=-1): 32 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json") as f: 33 | most_common_fields = json.load(f) 34 | 35 | try: 36 | with open(args.gold_extractions_file) as f: 37 | gold_file2extractions = json.load(f) 38 | except: 39 | with open(args.gold_extractions_file, "rb") as f: 40 | gold_file2extractions = pickle.load(f) 41 | 42 | for file, dic in gold_file2extractions.items(): 43 | gold_metadata = list(dic.keys()) 44 | gold_metadata = [m for m in gold_metadata if m not in ['topic_entity_name']] 45 | break 46 | 47 | ctr = Counter(most_common_fields) 48 | results = {} 49 | for k in [len(gold_metadata), 1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100, len(most_common_fields)]: 50 | if not most_common_fields: 51 | results[k] = { 52 | "recall": 0, 53 | "precision": 0, 54 | "f1": 0, 55 | "num_gold_attributes": k, 56 | } 57 | continue 58 | gold_metadata = [item.lower() for item in gold_metadata] 59 | pred_metadata = ctr 60 | 61 | limit = k 62 | pred_metadata = sorted(pred_metadata.most_common(limit), key=lambda x: (x[1], x[0]), reverse=True) 63 | pred_metadata = [item[0].lower() for item in pred_metadata] 64 | cleaned_pred_metadata = set() 65 | for pred in pred_metadata: 66 | if not pred: 67 | continue 68 | cleaned_pred_metadata.add(pred) 69 | cleaned_gold_metadata = set() 70 | for gold in gold_metadata: 71 | cleaned_gold_metadata.add(gold) 72 | 73 | recall = [x for x in cleaned_gold_metadata if x in cleaned_pred_metadata] 74 | precision = [x for x in cleaned_pred_metadata if x in cleaned_gold_metadata] 75 | recall = len(recall) / len(cleaned_gold_metadata) 76 | precision = len(precision) / len(cleaned_pred_metadata) 77 | f1 = compute_f1(precision, recall) 78 | 79 | results[k] = { 80 | "recall": recall, 81 | "precision": precision, 82 | "f1": f1, 83 | "num_gold_attributes": limit, 84 | } 85 | 86 | print(f"@k = %d --- Recall: %.3f, Precision: %.3f, F1: %.3f" % (k, recall, precision, f1)) 87 | print() 88 | return results 89 | 90 | 91 | def clean_comparison(extraction, attribute='', exact_match=False): 92 | # formatting transformations 93 | 94 | if type(extraction) == list: 95 | if extraction and type(extraction[0]) == list: 96 | full_answer = [] 97 | for answer in extraction: 98 | if type(answer) == list: 99 | dedup_list = [] 100 | for a in answer: 101 | if a not in dedup_list: 102 | dedup_list.append(a) 103 | answer = dedup_list 104 | answer = [str(a).strip().strip("\n") for a in answer] 105 | full_answer.append(", ".join(answer)) 106 | else: 107 | full_answer.append(answer.strip().strip("\n")) 108 | full_answer = [a.strip() for a in full_answer] 109 | extraction = ", ".join(full_answer) 110 | else: 111 | dedup_list = [] 112 | for a in extraction: 113 | if a not in dedup_list: 114 | dedup_list.append(a) 115 | extraction = dedup_list 116 | extraction = [(str(e)).strip().strip("\n") for e in extraction if e] 117 | extraction = ", ".join(extraction) 118 | elif type(extraction) == "str" and ( 119 | extraction.lower() == "none" or any( 120 | phrase in extraction.lower() for phrase in ["not reported", "none available", "n/a"] 121 | ) 122 | ): 123 | extraction = "" 124 | if type(extraction) == float and math.isnan(extraction): 125 | extraction = '' 126 | if type(extraction) != str: 127 | extraction = str(extraction) 
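# At this point `extraction` is a plain string: list values were deduplicated and
# joined with ", ", a float NaN became "", and any other type was cast with str().
# Note that the `type(extraction) == "str"` check above compares a type object to
# the literal string "str", so the "none"/"not reported" placeholder handling never
# fires and those strings simply fall through to the generic cleaning below.
# The block below strips HTML tags, echoes of the attribute name, punctuation, and
# HTML entities; illustratively (not a test from the repo),
# clean_comparison("<b>Aspirin</b>", attribute="drug") would come out as "aspirin".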
128 | 129 | if type(extraction) == str: 130 | if ("<" in extraction) and (">" in extraction): 131 | extraction = BeautifulSoup(extraction).text 132 | 133 | extraction = extraction.strip().replace(" ", " ").lower() 134 | attribute_variations = [f"{attribute}(s)".lower(), f"{attribute.strip()}(s)".lower(), attribute.lower(), attribute] 135 | for a in attribute_variations: 136 | extraction = extraction.replace(a, "").strip() 137 | for char in ["'", '"', "(", ")", ",", "/", "]", "[", ":"]: 138 | extraction = extraction.replace(char, "").strip() 139 | extraction = html.unescape(extraction) 140 | for char in ['&', '&', "-", "_", "\n", "\t", "http:", "<", ">"]: 141 | extraction = extraction.replace(char, " ").strip() 142 | if exact_match: 143 | extraction = extraction.replace(" ", "") 144 | if extraction == " ": 145 | extraction = "" 146 | extraction = extraction.strip() 147 | return extraction 148 | 149 | 150 | def evaluate_extraction_quality(run_string, args, gold_extractions_file, gold_extractions_file_dir, gold_attributes=None): 151 | all_attribute_f1 = 0 152 | all_attribute_total = 0 153 | total_runtime_overall_attributes = 0 154 | attribute2f1 = {} 155 | attribute2scripts = {} 156 | 157 | # load gold extractions file 158 | try: 159 | with open(gold_extractions_file) as f: 160 | gold_extractions_tmp = json.load(f) 161 | except: 162 | with open(gold_extractions_file, "rb") as f: 163 | gold_extractions_tmp = pickle.load(f) 164 | gold_extractions = {} 165 | for file, extractions in gold_extractions_tmp.items(): 166 | gold_extractions[os.path.join(gold_extractions_file_dir, file.split('/')[-1])] = extractions 167 | 168 | for attribute in gold_attributes: 169 | attribute = attribute.lower() 170 | # load predicted extractions 171 | fileattribute = get_file_attribute(attribute) 172 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{fileattribute}_file2metadata.json"): 173 | print(f"Missing file for {attribute}") 174 | continue 175 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_file2metadata.json") as f: 176 | file2metadata = json.load(f) 177 | 178 | try: 179 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_functions.json") as f: 180 | function_dictionary = json.load(f) 181 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_top_k_keys.json") as f: 182 | selected_keys = json.load(f) 183 | attribute2scripts[attribute] = selected_keys 184 | except: 185 | function_dictionary = {} 186 | selected_keys = [] 187 | pass 188 | total_runtime = 0 189 | for key in selected_keys: 190 | if key in function_dictionary: 191 | runtime = function_dictionary[key]['runtime'] 192 | total_runtime += runtime 193 | 194 | preds = [] 195 | golds = [] 196 | for file, gold_entry in gold_extractions.items(): 197 | for attr, gold_value in gold_entry.items(): 198 | attr = clean_comparison(attr) 199 | attribute = clean_comparison(attribute) 200 | if attr.lower() != attribute.lower(): 201 | continue 202 | if file not in file2metadata: 203 | continue 204 | pred_value = file2metadata[file] 205 | 206 | value_check = '' 207 | pred_value_check = '' 208 | if type(pred_value) == list and type(pred_value[0]) == str: 209 | pred_value_check = sorted([p.strip() for p in pred_value]) 210 | elif type(pred_value) == str and "," in pred_value: 211 | pred_value_check = sorted([p.strip() for p in pred_value.split(",")]) 212 | if type(gold_value) == list: 213 | value_check = gold_value[0] 214 | if "," in gold_value: 215 | value_check = sorted([p.strip() for p in 
gold_value.split(",")]) 216 | if value_check and pred_value_check and value_check == pred_value_check: 217 | gold_value = pred_value 218 | 219 | # SWDE doesn't include the full passage in many cases (e.g. "IMDB synopsis") 220 | pred_value = clean_comparison(pred_value, attribute=attribute) 221 | gold_value = clean_comparison(gold_value, attribute=attribute) 222 | if pred_value.lower().strip(".").startswith(gold_value.lower().strip(".")): 223 | pred_value = " ".join(pred_value.split()[:len(gold_value.split())]) 224 | preds.append(pred_value) 225 | golds.append(gold_value) 226 | 227 | if not preds: 228 | total_f1, total_f1_median = 0, 0 229 | if golds and preds: 230 | total_f1, total_f1_median = text_f1(preds, golds, attribute=attribute) 231 | else: 232 | print(f"Skipping eval of attribute: {attribute}") 233 | continue 234 | 235 | if preds: 236 | all_attribute_f1 += (total_f1) 237 | all_attribute_total += 1 238 | attribute2f1[attribute] = total_f1 239 | 240 | total_runtime_overall_attributes += total_runtime 241 | 242 | num_function_scripts = 0 243 | for k, v in attribute2f1.items(): 244 | scripts = [] 245 | if k in attribute2scripts: 246 | scripts = attribute2scripts[k] 247 | if any(s for s in scripts if "function" in s): 248 | num_function_scripts += 1 249 | print(f"{k}, text-f1 = {v} --- {scripts}") 250 | 251 | try: 252 | overall_f1 = all_attribute_f1 / all_attribute_total 253 | print(f"\nOverall f1 across %d attributes: %.3f" % (all_attribute_total, overall_f1)) 254 | print(f"Used functions for {num_function_scripts} out of {len(attribute2f1)} attributes") 255 | print(f"Average time: {total_runtime_overall_attributes/all_attribute_total} seconds, {all_attribute_total} fns.\n\n") 256 | 257 | results = { 258 | "f1": all_attribute_f1 / all_attribute_total, 259 | "total_attributes": all_attribute_total, 260 | "attribute2f1": attribute2f1, 261 | } 262 | except: 263 | results = { 264 | "f1": 0, 265 | "total_attributes": all_attribute_total, 266 | "attribute2f1": attribute2f1, 267 | } 268 | 269 | return results 270 | 271 | 272 | def determine_attribute_slices(gold_extractions, slice_results): 273 | num_occurences, num_characters = defaultdict(int), defaultdict(int) 274 | num_documents = len(gold_extractions) 275 | for file, extraction_dict in gold_extractions.items(): 276 | for key, value in extraction_dict.items(): 277 | if type(value) == str and value: 278 | num_occurences[key] += 1 279 | num_characters[key] += len(value) 280 | elif type(value) == list and value[0]: 281 | num_occurences[key] += 1 282 | num_characters[key] += len(value[0]) 283 | 284 | # calculate the average length of the attribute 285 | for attr, total_len in num_characters.items(): 286 | num_characters[attr] = total_len / num_occurences[attr] 287 | 288 | # split into the "head", "tail", and "unstructured" 289 | attribute_slices = defaultdict(set) 290 | for attr, num_occur in num_occurences.items(): 291 | attribute_slices["all"].add(attr) 292 | 293 | # skip the rest if not slicing results 294 | if not slice_results: 295 | continue 296 | 297 | num_char = num_characters[attr] 298 | if int(num_documents * 0.5) <= num_occur: 299 | attribute_slices["head"].add(attr) 300 | else: 301 | attribute_slices["tail"].add(attr) 302 | 303 | if num_char >= 20: 304 | attribute_slices["unstructured"].add(attr) 305 | else: 306 | attribute_slices["structured"].add(attr) 307 | 308 | return attribute_slices 309 | 310 | 311 | def evaluate_openie_quality( 312 | run_string, 313 | args, 314 | gold_extractions_file, 315 | sample_files=None, 316 | 
slice_results=False, 317 | mappings_names={} 318 | ): 319 | # load pred extractions file 320 | with open(f"{args.generative_index_path}/{run_string}_file2extractions.json") as f: 321 | pred_extractions = json.load(f) 322 | 323 | # alternate gold attribute naming 324 | if args.set_dicts: 325 | with open(args.set_dicts) as f: 326 | set_dicts = json.load(f) 327 | else: 328 | set_dicts = {} 329 | 330 | # load gold extractions file 331 | try: 332 | with open(gold_extractions_file) as f: 333 | gold_extractions = json.load(f) 334 | except: 335 | with open(gold_extractions_file, "rb") as f: 336 | gold_extractions = pickle.load(f) 337 | 338 | pred_attributes = set() 339 | for file, extraction_dict in pred_extractions.items(): 340 | for key, value in extraction_dict.items(): 341 | pred_attributes.add(key) 342 | 343 | # split the attribute into slices -> "head", "tail", and "unstructured" 344 | attribute_slices = determine_attribute_slices(gold_extractions, slice_results) 345 | 346 | results = {} 347 | for attribute_slice, gold_attributes in attribute_slices.items(): 348 | 349 | # lenient attribute scoring method: https://arxiv.org/pdf/2201.10608.pdf 350 | gold_attribute_mapping = {} 351 | for gold_attribute in gold_attributes: 352 | if gold_attribute in pred_attributes or not set_dicts: 353 | gold_attribute_mapping[gold_attribute] = gold_attribute 354 | continue 355 | if gold_attribute in set_dicts: 356 | alternate_golds = set_dicts[gold_attribute] 357 | else: 358 | alternate_golds = [gold_attribute] 359 | found = 0 360 | for alternate_gold in alternate_golds: 361 | if alternate_gold in pred_attributes: 362 | gold_attribute_mapping[gold_attribute] = alternate_gold 363 | found = 1 364 | if not found: 365 | if gold_attribute.strip('s') in pred_attributes: 366 | gold_attribute_mapping[gold_attribute] = gold_attribute.strip('s') 367 | elif gold_attribute+"s" in pred_attributes: 368 | gold_attribute_mapping[gold_attribute] = gold_attribute+"s" 369 | elif gold_attribute.strip('(s)') in pred_attributes: 370 | gold_attribute_mapping[gold_attribute] = gold_attribute.strip('(s)') 371 | elif gold_attribute+"(s)" in pred_attributes: 372 | gold_attribute_mapping[gold_attribute] = gold_attribute+"(s)" 373 | elif gold_attribute.replace(" ", "") in pred_attributes: 374 | gold_attribute_mapping[gold_attribute] = gold_attribute.replace(" ", "") 375 | elif any(pred_attribute.replace(" ", "") in gold_attributes for pred_attribute in pred_attributes): 376 | for pred_attribute in pred_attributes: 377 | if pred_attribute.replace(" ", "") in gold_attributes: 378 | gold_attribute_mapping[gold_attribute] = pred_attribute 379 | elif gold_attribute in mappings_names and mappings_names[gold_attribute] in pred_attributes: 380 | gold_attribute_mapping[gold_attribute] = mappings_names[gold_attribute] 381 | else: 382 | gold_attribute_mapping[gold_attribute] = gold_attribute 383 | 384 | pred_set = set() 385 | skipped = set() 386 | all_measurements = defaultdict(dict) 387 | for file, extraction_dict in pred_extractions.items(): 388 | if sample_files and file not in sample_files: 389 | continue 390 | for key, value in extraction_dict.items(): 391 | if key not in attribute_slices["all"]: 392 | if key.replace(" ", "") in attribute_slices["all"]: 393 | key = key.replace(" ", "") 394 | 395 | # skip predicted attributes that are in a different slice 396 | if key in attribute_slices["all"] and key not in gold_attributes: 397 | skipped.add(key) 398 | continue 399 | 400 | clean_key = clean_comparison(key, exact_match=True) 401 | clean_value = 
clean_comparison(value, attribute=key, exact_match=True) 402 | if clean_value: 403 | pred_set.add((file, clean_key, clean_value)) 404 | if file not in all_measurements[clean_key]: 405 | all_measurements[clean_key][file] = { 406 | "pred": "", 407 | "gold": "", 408 | } 409 | all_measurements[clean_key][file]['pred'] = clean_value 410 | 411 | clean_pred_attributes = set([x[1] for x in pred_set]) 412 | # resolve mapping between gold and pred attributes 413 | gold_attribute_mapping = {} 414 | for gold_attribute in gold_attributes: 415 | if gold_attribute in clean_pred_attributes: 416 | gold_attribute_mapping[gold_attribute] = gold_attribute 417 | continue 418 | found = False 419 | if set_dicts and gold_attribute in set_dicts: 420 | alternate_golds = set_dicts[gold_attribute] 421 | for alternate_gold in alternate_golds: 422 | if alternate_gold in clean_pred_attributes: 423 | gold_attribute_mapping[gold_attribute] = alternate_gold 424 | found = True 425 | if not found: 426 | if gold_attribute.strip('s') in clean_pred_attributes: 427 | gold_attribute_mapping[gold_attribute] = gold_attribute.strip('s') 428 | elif gold_attribute+"s" in clean_pred_attributes: 429 | gold_attribute_mapping[gold_attribute] = gold_attribute+"s" 430 | else: 431 | gold_attribute_mapping[gold_attribute] = gold_attribute 432 | 433 | num_attributes = len(clean_pred_attributes) 434 | 435 | gold_set = set() 436 | for file, extraction_dict in gold_extractions.items(): 437 | if sample_files and file not in sample_files: 438 | continue 439 | for key, value in extraction_dict.items(): 440 | # ignore attributes in a different slice 441 | if key not in gold_attributes: 442 | continue 443 | 444 | if key == "topic_entity_name": 445 | if "name" in pred_attributes: 446 | gold_attribute_mapping[key] = "name" 447 | 448 | key = gold_attribute_mapping[key] 449 | if key not in pred_attributes: 450 | if key.replace(" ", "") in pred_attributes: 451 | key = key.replace(" ", "") 452 | 453 | # sort list-based attribute values for consistency. 
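# The checks below sort the comma-separated parts of the predicted value (and, where
# possible, of the gold value); if the sorted forms agree, the gold value is
# overwritten with the prediction so that the exact-match comparison further down
# does not penalize a different ordering of multi-valued attributes.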
454 | if file in pred_extractions and key in pred_extractions[file]: 455 | pred_value = pred_extractions[file][key] 456 | value_check = '' 457 | pred_value_check = '' 458 | if type(pred_value) == list and type(pred_value[0]) == str: 459 | pred_value_check = sorted([p.strip() for p in pred_value]) 460 | elif type(pred_value) == str and "," in pred_value: 461 | pred_value_check = sorted([p.strip() for p in pred_value.split(",")]) 462 | if type(value) == list: 463 | value_check = value[0] 464 | if "," in value: 465 | value_check = sorted([p.strip() for p in value.split(",")]) 466 | if value_check and pred_value_check and value_check == pred_value_check: 467 | value = pred_value 468 | 469 | clean_key = clean_comparison(key, exact_match=True) 470 | clean_value = clean_comparison(value, attribute=key, exact_match=True) 471 | 472 | if clean_value: 473 | gold_set.add((file, clean_key, clean_value)) 474 | if file not in all_measurements[clean_key]: 475 | all_measurements[clean_key][file] = { 476 | "pred": "", 477 | "gold": "", 478 | } 479 | all_measurements[clean_key][file]['gold'] = clean_value 480 | 481 | if not pred_set or not gold_set: 482 | results[attribute_slice] = { 483 | "precision": 0, 484 | "recall": 0, 485 | "f1": 0, 486 | "num_files_evaluated": len(pred_extractions), 487 | } 488 | else: 489 | # exact match over all fields 490 | precision = set_precision(pred_set, gold_set) 491 | recall = set_recall(pred_set, gold_set) 492 | f1 = compute_f1(precision, recall) 493 | 494 | results[attribute_slice] = { 495 | "precision": precision, 496 | "recall": recall, 497 | "f1": f1, 498 | "num_files_evaluated": len(pred_extractions), 499 | } 500 | print(f"[%s] OpenIE Precision (%d attributes): Precision: %.3f Recall: %.3f F1: %.3f" % (attribute_slice, num_attributes, precision, recall, f1)) 501 | return results if slice_results else results["all"] 502 | 503 | 504 | def main( 505 | run_string, 506 | args, 507 | profiler_args, 508 | data_lake = "wiki_nba_players", 509 | sample_files=None, 510 | stage='', 511 | gold_attributes=[], 512 | mappings_names={} 513 | ): 514 | gold_extractions_file = args.gold_extractions_file 515 | train_size = profiler_args.train_size 516 | 517 | overall_results = {} 518 | 519 | if stage and stage != 'schema_id': 520 | pass 521 | else: 522 | schema_id_results = evaluate_schema_identification( 523 | run_string, 524 | args, 525 | data_lake, 526 | train_size=train_size, 527 | ) 528 | overall_results["schema_id"] = schema_id_results 529 | 530 | if stage and stage != 'extract': 531 | pass 532 | else: 533 | extraction_results = evaluate_extraction_quality( 534 | run_string, 535 | args, 536 | gold_extractions_file, 537 | profiler_args.data_dir, 538 | gold_attributes=gold_attributes 539 | ) 540 | overall_results["extraction"] = extraction_results 541 | 542 | if stage and stage != 'openie': 543 | pass 544 | else: 545 | openie_results = evaluate_openie_quality( 546 | run_string, 547 | args, 548 | gold_extractions_file, 549 | sample_files=sample_files, 550 | slice_results = profiler_args.slice_results, 551 | mappings_names = mappings_names 552 | ) 553 | overall_results["openie"] = openie_results 554 | 555 | return overall_results 556 | 557 | 558 | if __name__ == "__main__": 559 | main() --------------------------------------------------------------------------------
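A minimal usage sketch (not part of the repository) of the set-based OpenIE scoring defined in evaporate/evaluate_synthetic.py: predictions and golds are reduced to (file, attribute, value) triples and scored with set_precision, set_recall, and compute_f1. The file names and triples below are invented for illustration, and the import path assumes the package layout declared in setup.py.

from evaporate.evaluate_synthetic import set_precision, set_recall, compute_f1

# Made-up (file, attribute, value) triples, for illustration only.
pred_set = {
    ("doc1.html", "device name", "acme stent"),
    ("doc1.html", "decision date", "2021-03-01"),
    ("doc2.html", "device name", "widget pump"),
}
gold_set = {
    ("doc1.html", "device name", "acme stent"),
    ("doc2.html", "device name", "widget pump"),
    ("doc2.html", "decision date", "2020-07-15"),
}

precision = set_precision(pred_set, gold_set)  # 2 of 3 predicted triples are correct
recall = set_recall(pred_set, gold_set)        # 2 of 3 gold triples are recovered
print(compute_f1(precision, recall))           # ≈ 0.667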