├── assets
│   └── banner.png
├── .gitmodules
├── setup.py
├── evaporate
│   ├── run.sh
│   ├── evaluate_synthetic_utils.py
│   ├── retrieval.py
│   ├── weak_supervision
│   │   ├── ws_utils.py
│   │   ├── pgm.py
│   │   ├── run_ws.py
│   │   └── binary_deps.py
│   ├── prompts.py
│   ├── schema_identification.py
│   ├── utils.py
│   ├── evaluate_profiler.py
│   ├── main.py
│   ├── configs.py
│   ├── profiler_utils.py
│   ├── run_profiler.py
│   └── evaluate_synthetic.py
├── .gitignore
├── README.md
├── example.ipynb
└── demo.ipynb
/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HazyResearch/evaporate/HEAD/assets/banner.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "metal-evap"]
2 | path = metal-evap
3 | url = git@github.com:simran-arora/metal-evap.git
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | _REQUIRED = [
4 | "tqdm",
5 | "openai",
6 | "manifest-ml",
7 | "pandas",
8 | "snorkel",
9 | "cvxpy",
10 | "bs4",
11 | "snorkel-metal",
12 | "tensorboardX",
13 | "numpy == 1.20.3",
14 | "networkx == 2.3"
15 | ]
16 |
17 | setup(
18 | name="evaporate",
19 | version="0.0.1",
20 | description="evaporating data lakes with foundation models",
21 | author="simran brandon sabri avanika andrew immanuel chris",
22 | packages=["evaporate"],
23 | install_requires=_REQUIRED,
24 | )
25 |
--------------------------------------------------------------------------------
/evaporate/run.sh:
--------------------------------------------------------------------------------
1 | keys=# INSERT YOUR API KEY(S) HERE
2 | 
3 | # Evaporate closed IE
4 | python run_profiler.py \
5 | --data_lake fda_510ks \
6 | --num_attr_to_cascade 50 \
7 | --num_top_k_scripts 10 \
8 | --train_size 10 \
9 | --combiner_mode ws \
10 | --use_dynamic_backoff \
11 | --KEYS ${keys} \
12 | --data_dir /data/fda_510ks/data/evaporate/fda-ai-pmas/510k \
13 | --base_data_dir /data/evaporate/data/fda_510ks \
14 | --gold_extractions_file /data/evaporate/data/fda_510ks/table.json
15 | 
16 | # Evaporate open IE
17 | python run_profiler.py \
18 | --data_lake fda_510ks \
19 | --num_attr_to_cascade 50 \
20 | --num_top_k_scripts 10 \
21 | --train_size 10 \
22 | --combiner_mode ws \
23 | --use_dynamic_backoff \
24 | --KEYS ${keys} \
25 | --do_end_to_end \
26 | --data_dir /data/fda_510ks/data/evaporate/fda-ai-pmas/510k \
27 | --base_data_dir /data/evaporate/data/fda_510ks \
28 | --gold_extractions_file /data/evaporate/data/fda_510ks/table.json
--------------------------------------------------------------------------------
/evaporate/evaluate_synthetic_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter, defaultdict
3 |
4 | def text_f1(preds=[], golds=[], attribute=''):
5 | """Compute average F1 of text spans.
6 | Taken from Squad without prob threshold for no answer.
7 | """
8 | total_f1 = 0
9 | total_recall = 0
10 | total_prec = 0
11 | f1s = []
12 | for pred, gold in zip(preds, golds):
13 |         if isinstance(pred, list):
14 |             pred = ' '.join(pred)  # join multi-span predictions into one string
15 |         if isinstance(gold, list):
16 |             gold = ' '.join(gold)  # join multi-span golds into one string
17 | pred_toks = pred.split()
18 | gold_toks = gold.split()
19 | common = Counter(pred_toks) & Counter(gold_toks)
20 | num_same = sum(common.values())
21 | if len(gold_toks) == 0 or len(pred_toks) == 0:
22 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
23 | total_f1 += int(gold_toks == pred_toks)
24 | f1s.append(int(gold_toks == pred_toks))
25 | elif num_same == 0:
26 | total_f1 += 0
27 | f1s.append(0)
28 | else:
29 | precision = 1.0 * num_same / len(pred_toks)
30 | recall = 1.0 * num_same / len(gold_toks)
31 | f1 = (2 * precision * recall) / (precision + recall)
32 | total_f1 += f1
33 | total_recall += recall
34 | total_prec += precision
35 | f1s.append(f1)
36 | f1_avg = total_f1 / len(golds)
37 | f1_median = np.percentile(f1s, 50)
38 | return f1_avg, f1_median
39 |
40 | def get_file_attribute(attribute):
41 | attribute = attribute.lower()
42 | attribute = attribute.replace("/", "_").replace(")", "").replace("-", "_")
43 | attribute = attribute.replace("(", "").replace(" ", "_")
44 | if len(attribute) > 30:
45 | attribute = attribute[:30]
46 | return attribute
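47 | 
48 | 
49 | if __name__ == "__main__":
50 |     # Minimal sanity check with toy spans (hypothetical values, not from the
51 |     # benchmark data): token-level F1 as in Squad, averaged over examples.
52 |     preds = ["aspirin ibuprofen", "dr. burns", ""]
53 |     golds = ["aspirin ibuprofen acetaminophen", "dr. burns", ""]
54 |     avg_f1, median_f1 = text_f1(preds=preds, golds=golds)
55 |     print(f"avg F1: {avg_f1:.3f}, median F1: {median_f1:.3f}")  # -> 0.933, 1.000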
--------------------------------------------------------------------------------
/evaporate/retrieval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModel
3 |
4 |
5 | def mean_pooling(token_embeddings, mask):
6 | token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
7 | sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
8 | return sentence_embeddings
9 |
10 | def get_embeddings(sentences):
11 |     tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
12 |     model = AutoModel.from_pretrained('facebook/contriever')
13 | 
14 |     # Tokenize the input sentences
15 |     inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
16 | 
17 |     # Compute token embeddings without tracking gradients
18 |     with torch.no_grad():
19 |         outputs = model(**inputs)
20 | 
21 |     # Mean pooling over the token dimension
22 |     embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
23 |     return embeddings
24 | 
25 | def get_most_similarity(target_sentence, sentences):
26 |     target_embedding = get_embeddings([target_sentence])[0]
27 |     embeddings = get_embeddings(sentences)
28 |     similarities = torch.nn.functional.cosine_similarity(target_embedding, embeddings, dim=-1)
29 |     most_similar_idx = similarities.argmax()
30 |     return sentences[most_similar_idx]
31 | 
32 | if __name__ == "__main__":
33 |     target_sentence = "Where was Marie Curie born?"
34 |     sentences = [
35 |         "Where was Marie Curie born?",
36 |         "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
37 |         "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
38 |     ]
39 |     print(get_most_similarity(target_sentence, sentences))
40 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Custom
132 | .sqlite.cache
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evaporate
2 |
3 | 
4 | ![Evaporate banner](assets/banner.png)
5 | 
6 | 
7 | Code, datasets, and extended writeup for the paper [Language Models Enable Simple Systems for Generating Structured Views of Heterogeneous Data Lakes](https://www.vldb.org/pvldb/vol17/p92-arora.pdf).
8 |
9 | ## Setup
10 |
11 | We encourage the use of conda environments:
12 | ```bash
13 | conda create --name evaporate python=3.8
14 | conda activate evaporate
15 | ```
16 |
17 | Clone as follows:
18 | ```bash
19 | # Evaporate code
20 | git clone git@github.com:HazyResearch/evaporate.git
21 | cd evaporate
22 | pip install -e .
23 |
24 | # Weak supervision code (initialize the metal-evap submodule first)
25 | git submodule init
26 | git submodule update
27 | cd metal-evap
28 | pip install -e .
29 | cd ..
30 | # Manifest (install from source if you want to modify the set of supported models; otherwise ``setup.py`` already installs ``manifest-ml``)
31 | git clone git@github.com:HazyResearch/manifest.git
32 | cd manifest
33 | pip install -e .
34 | ```
35 |
36 | ## Datasets
37 | The data used in the paper is hosted on Hugging Face's datasets platform: https://huggingface.co/datasets/hazyresearch/evaporate.
38 |
39 | To download the datasets, run the following commands in your terminal:
40 | ```bash
41 | git lfs install
42 | git clone https://huggingface.co/datasets/hazyresearch/evaporate
43 | ```
44 |
45 | Or download it via Python:
46 | ```python
47 | from datasets import load_dataset
48 | dataset = load_dataset("hazyresearch/evaporate")
49 | ```
50 |
51 | The code expects the data to be stored at ``/data/evaporate/``, as specified by the ``CONSTANTS`` dictionary in ``constants.py``; these paths can be modified.
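52 | 
53 | For reference, each data lake gets an entry like the following in ``CONSTANTS`` (this ``fda_510ks`` example also appears in ``example.ipynb``; adjust ``BASE_DATA_DIR`` for your setup):
54 | 
55 | ```python
56 | CONSTANTS = {
57 |     "fda_510ks": {
58 |         "data_dir": os.path.join(BASE_DATA_DIR, "fda-ai-pmas/510k/"),
59 |         "database_name": "fda_510ks",
60 |         "cache_dir": ".cache/fda_510ks/",
61 |         "generative_index_path": os.path.join(BASE_DATA_DIR, "generative_indexes/fda_510ks/"),
62 |         "gold_extractions_file": os.path.join(BASE_DATA_DIR, "ground_truth/fda_510ks_gold_extractions.json"),
63 |         "topic": "fda 510k device premarket notifications",
64 |     },
65 | }
66 | ```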
52 |
53 |
54 | ## Running the code
55 | Run closed IE and open IE using the commands:
56 |
57 | ```bash
58 | cd evaporate/ && bash run.sh
59 | ```
60 |
61 | The ``keys`` in run.sh can be obtained by registering with the LLM provider. For instance, if you want to run inference with the OpenAI API models, create an account [here](https://openai.com/api/).
62 |
63 | The script includes commands for both closed and open IE runs. To walk through the code, look at ``run_profiler.py``. For open IE, the code first uses ``schema_identification.py`` to generate a list of attributes for the schema. Next, the code iterates through this list to perform extraction using ``profiler.py``. As functions are generated in ``profiler.py``, ``evaluate_profiler.py`` is used to score the function outputs against the outputs of directly prompting the LM on the sample documents.
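64 | 
65 | For a programmatic version of the same loop, see ``example.ipynb``; a condensed sketch of its flow (the ``set_profiler_args`` config below is abbreviated, so fill in the models, keys, and paths the notebook shows):
66 | 
67 | ```python
68 | import sys
69 | sys.path.append("./evaporate/")
70 | 
71 | from run_profiler import prerun_profiler, identify_attributes, get_attribute_function
72 | from configs import set_profiler_args
73 | 
74 | profiler_args = set_profiler_args({
75 |     "data_lake": "fda_510ks", "combiner_mode": "ws", "do_end_to_end": False,
76 |     "KEYS": [""],  # abbreviated; see example.ipynb for the full configuration
77 | })
78 | data_dict = prerun_profiler(profiler_args)  # load, chunk, and sample the documents
79 | 
80 | # Open IE identifies the schema with the LM; closed IE starts from gold attributes.
81 | if profiler_args.do_end_to_end:
82 |     attributes, total_time, num_toks, eval_result = identify_attributes(
83 |         profiler_args, data_dict, evaluation=True)
84 | else:
85 |     attributes = data_dict["gold_attributes"]
86 | 
87 | # Synthesize, score, and select extraction functions for each attribute.
88 | functions, selected_keys = {}, {}
89 | for attribute in attributes:
90 |     functions[attribute], selected_keys[attribute], total_time, num_toks = get_attribute_function(
91 |         profiler_args, data_dict, attribute)
92 | ```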
64 |
65 |
66 | ## Citation
67 | If you use this codebase, or otherwise found our work valuable, please cite:
68 | ```
69 | @article{arora2023evaporate,
70 | title={Language Models Enable Simple Systems for Generating Structured Views of Heterogeneous Data Lakes},
71 | author={Arora, Simran and Yang, Brandon and Eyuboglu, Sabri and Narayan, Avanika and Hojel, Andrew and Trummer, Immanuel and R\'e, Christopher},
72 | journal={arXiv:2304.09433},
73 | year={2023}
74 | }
75 | ```
76 |
--------------------------------------------------------------------------------
/evaporate/weak_supervision/ws_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 |
4 | def get_probabilties(num_lfs, num_examples, predictions, label_name_to_int):
5 |
6 | lf_array = np.zeros((num_lfs, num_examples))
7 | golds = []
8 |
9 | # Collect golds and preds
10 | for i, (k, item) in enumerate(predictions.items()):
11 | preds = item['chosen_answers_lst']
12 | preds_mapped = []
13 | for p in preds:
14 | if p in label_name_to_int:
15 | preds_mapped.append(label_name_to_int[p])
16 | else:
17 | preds_mapped.append(0)
18 | preds = preds_mapped.copy()
19 | for lf_num, p in zip(range(num_lfs), preds):
20 | lf_array[lf_num][i] = p
21 | gold = label_name_to_int[item['gold']]
22 | golds.append(gold)
23 | golds = np.array(golds)
24 | neg_indices, pos_indices = [np.where(golds == -1)[0], np.where(golds == 1)[0]]
25 | indices = {
26 | -1: neg_indices,
27 | 1: pos_indices
28 | }
29 |
30 |     # Overall accuracy of each LF: Pr(lf_i = y)
31 |     # (per-class conditionals Pr(prompt_i = j | y = k) are computed below)
32 | lf_accuracies = []
33 | for i in range(num_lfs):
34 | lf_accuracies.append(np.sum(golds == np.array(lf_array[i]))/num_examples)
35 | print(f"LF Accs: {lf_accuracies}")
36 |
37 | # [i, j, k] = Pr(prompt_i = j| y = k)
38 | classes = label_name_to_int.values()
39 | accs = np.zeros((num_lfs, len(classes), len(classes)))
40 | for p in range(num_lfs):
41 | for i in classes:
42 | for j in classes:
43 | j_idx = j
44 | if j == -1:
45 | j_idx = 0
46 | i_idx = i
47 | if i == -1:
48 | i_idx = 0
49 | accs[p, i_idx, j_idx] = len(np.where(lf_array[p, indices[i]] == j)[0]) / len(indices[i])
50 |
51 | # Compute probabilities
52 | pos_probs = []
53 | for i in range(num_lfs):
54 | sub_preds = lf_array[i][pos_indices]
55 | sub_golds = golds[pos_indices]
56 | pos_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(pos_indices))
57 | print(f"Pos Probs: {pos_probs}")
58 |
59 | neg_probs = []
60 | for i in range(num_lfs):
61 | sub_preds = lf_array[i][neg_indices]
62 | sub_golds = golds[neg_indices]
63 | neg_probs.append(np.sum(sub_golds == np.array(sub_preds))/len(neg_indices))
64 | print(f"Neg Probs: {neg_probs}\n\n")
65 |
66 | return lf_accuracies, accs, pos_probs, neg_probs, golds, indices
67 |
68 |
69 | """ Independence Assumption: take the product of probabilities as p(L1, L2, ..., LK | y) """
70 |
71 | # Pr(lf votes, y): class prior times per-LF accuracies under the independence assumption
72 | def get_cond_probs(votes, y, indices_train, golds_train, accs_train, num_lfs_test):
73 | prop_pos = len(indices_train[1])/len(golds_train)
74 | pr_y = prop_pos if y == 1 else 1 - prop_pos
75 | prod = pr_y
76 | for i in range(num_lfs_test):
77 | if y == -1:
78 | y = 0
79 | prod *= accs_train[i, y, votes[i]]
80 | return prod
81 |
82 | # Pr(y = 1 | lf votes)
83 | def get_probs(votes, indices_train, golds_train, acc_train, num_lfs_test):
84 | votes = [max(v, 0) for v in votes]
85 | numerator = get_cond_probs(votes, 1, indices_train, golds_train, acc_train, num_lfs_test)
86 | denominator = numerator + get_cond_probs(votes, -1, indices_train, golds_train, acc_train, num_lfs_test)
87 | return numerator / denominator
88 |
89 |
90 | def get_nb_accuracy(num_examples_test, num_lfs_test, predictions_test, label_name_to_int, golds_test, indices_train, golds_train, accs_train):
91 | output = np.zeros(num_examples_test)
92 | errors = 0
93 | for i, (k, item) in enumerate(predictions_test.items()):
94 | votes = item['chosen_answers_lst']
95 | votes_mapped = []
96 | for v in votes:
97 | if v in label_name_to_int:
98 | votes_mapped.append(label_name_to_int[v])
99 | else:
100 | votes_mapped.append(0)
101 | votes = votes_mapped.copy()
102 | probs = np.round(get_probs(votes, indices_train, golds_train, accs_train, num_lfs_test))
103 | output[i] = probs
104 |
105 | # Mean squared error
106 | g = golds_test[i]
107 | if golds_test[i] == -1:
108 | g = 0
109 | error = np.abs(output[i] - g)**2
110 | errors += error
111 | accuracy = 1 - (errors / num_examples_test)
112 | return accuracy, output
113 |
114 |
115 | def estimate_matrix(m, n, L):
116 | E_prod = np.zeros((m, m))
117 | l_avg = np.zeros(m)
118 | for i in range(n):
119 | l = L[i, :]
120 | l_avg += l
121 | E_prod += np.outer(l, l)
122 |
123 | l_avg = l_avg/n
124 | E_prod = E_prod/n
125 |
126 | cov = E_prod - np.outer(l_avg, l_avg)
127 |
128 | return (E_prod, cov, l_avg)
129 |
130 |
131 | def get_vote_vectors(num_samples, num_lfs, predictions, label_name_to_int):
132 | vectors = np.zeros((num_samples, num_lfs+1), float)
133 | vectors_no_y = np.zeros((num_samples, num_lfs), float)
134 | labels_vector = np.zeros((num_samples, 1), float)
135 | for i, p in enumerate(predictions.values()):
136 | votes = p['chosen_answers_lst']
137 | votes_mapped = []
138 | for v in votes:
139 | if v in label_name_to_int:
140 | votes_mapped.append(label_name_to_int[v])
141 | else:
142 | votes_mapped.append(0)
143 | votes = votes_mapped.copy()
144 | # votes = [max(v, 0) for v in votes]
145 | gold = p['gold']
146 | gold = label_name_to_int[gold]
147 | vectors_no_y[i] = np.array(votes)
148 | vectors[i] = np.array(votes + [gold]) #- lf_accuracies_train
149 | labels_vector[i] = np.array([gold])
150 | print(f"Shape: {vectors.shape}")
151 | print(f"Sample: {vectors[0]}")
152 |
153 | return vectors, vectors_no_y, labels_vector
154 |
155 | def get_feature_vector(vote_vectors, include_pairwise=False, include_singletons=True):
156 | feature_vectors = []
157 | for votes in vote_vectors:
158 | if include_singletons:
159 | feature_vector = list(votes[:])
160 | else:
161 | feature_vector = []
162 | if include_pairwise:
163 | for subset in itertools.combinations(votes[:], 2):
164 | feature_vector.append(subset[0] * subset[1])
165 | feature_vectors.append(feature_vector)
166 | X = np.matrix(feature_vectors)
167 | return X
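168 | 
169 | 
170 | if __name__ == "__main__":
171 |     # Minimal sketch of the naive Bayes combination above, with toy numbers
172 |     # (not from the paper's data): two labeling functions that are each 80%
173 |     # accurate on a balanced train split, queried with two positive votes.
174 |     golds_train = np.array([1, 1, -1, -1])
175 |     indices_train = {1: np.array([0, 1]), -1: np.array([2, 3])}
176 |     # accs_train[i, y, vote] = Pr(lf_i = vote | y), with y, vote in {0, 1}
177 |     accs_train = np.array([
178 |         [[0.8, 0.2], [0.2, 0.8]],
179 |         [[0.8, 0.2], [0.2, 0.8]],
180 |     ])
181 |     # Two agreeing 80%-accurate voters compound to ~0.94 under independence.
182 |     print(get_probs([1, 1], indices_train, golds_train, accs_train, num_lfs_test=2))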
--------------------------------------------------------------------------------
/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "%load_ext autoreload\n",
10 |     "%autoreload 2"
11 |    ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 13,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import time\n",
28 | "from tqdm import tqdm\n",
29 | "import sys\n",
30 | "sys.path.append(f\"./evaporate/\")\n",
31 | "from run_profiler import prerun_profiler, identify_attributes, get_attribute_function\n",
32 | "from configs import set_profiler_args"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "MANIFEST_URL = \"http://127.0.0.1:5000\" # please make sure that a local manifest session is running with your model at this address\n",
42 | "DATA_DIR = \"/var/cr05_data/sim_data/code/evaporate/data/\""
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
49 | "# Set Up"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 |     "profiler_args = set_profiler_args({\n",
59 | " \"data_lake\": \"wiki_nba\", \n",
60 | " \"num_attr_to_cascade\": 50, \n",
61 | " \"num_top_k_scripts\": 10, \n",
62 | " \"train_size\": 10, \n",
63 | " \"combiner_mode\": \"mv\", \n",
64 | " \"do_end_to_end\": False,\n",
65 | " \"use_dynamic_backoff\": True, \n",
66 | " \"KEYS\": [\"\"], \n",
67 |     " \"MODELS\": [\"mistralai/Mistral-7B-Instruct-v0.1\", \"gpt-4\"],\n",
68 |     " \"EXTRACTION_MODELS\": [\"mistralai/Mistral-7B-Instruct-v0.1\"],\n",
69 |     " \"GOLD_KEY\": \"gpt-4\",\n",
70 |     " \"MODEL2URL\": {\"mistralai/Mistral-7B-Instruct-v0.1\": MANIFEST_URL},\n",
71 | " \"data_dir\": f\"{DATA_DIR}/documents/\", \n",
72 | " \"base_data_dir\": f\"{DATA_DIR}\", \n",
73 | " \"gold_extractions_file\": f\"{DATA_DIR}/data_fda_510ks_table.json\"}\n",
74 | ")\n",
75 | "print(profiler_args)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "data_dict = prerun_profiler(profiler_args)\n",
85 | "print(data_dict.keys())"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "print(profiler_args.overwrite_cache)"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {},
100 | "source": [
101 | "# Get Schema Attribute"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "if profiler_args.do_end_to_end:\n",
111 | " attributes, total_time, num_toks, evaluation_result = identify_attributes(profiler_args, data_dict, evaluation=True)\n",
112 | " print(\"total_time: \", total_time, \"num_toks: \", num_toks, \"evaluation_result: \", evaluation_result)\n",
113 | "else:\n",
114 | " attributes = data_dict[\"gold_attributes\"]"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "attributes"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 |     "Reference `CONSTANTS` entry (model information and data directory paths) for the `fda_510ks` data lake:\n",
131 |     "\n",
132 |     "    CONSTANTS = {\n",
133 |     "        \"fda_510ks\": {\n",
134 |     "            \"data_dir\": os.path.join(BASE_DATA_DIR, \"fda-ai-pmas/510k/\"),\n",
135 |     "            \"database_name\": \"fda_510ks\",\n",
136 |     "            \"cache_dir\": \".cache/fda_510ks/\",\n",
137 |     "            \"generative_index_path\": os.path.join(BASE_DATA_DIR, \"generative_indexes/fda_510ks/\"),\n",
138 |     "            \"gold_extractions_file\": os.path.join(BASE_DATA_DIR, \"ground_truth/fda_510ks_gold_extractions.json\"),\n",
139 |     "            \"topic\": \"fda 510k device premarket notifications\",\n",
140 |     "        },\n",
141 |     "    }"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "# Get Attribute function"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "functions = {}\n",
158 | "selected_keys = {}\n",
159 | "for i, attribute in enumerate(attributes):\n",
160 | " print(f\"\\n\\nExtracting {attribute} ({i+1} / {len(attributes)})\")\n",
161 | " t0 = time.time()\n",
162 |     " functions[attribute], selected_keys[attribute], total_time, num_toks = get_attribute_function(\n",
163 | " profiler_args, data_dict, attribute\n",
164 | " )\n",
165 | " print(functions[attribute])"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "# Evaluation"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "from evaporate.evaluate_synthetic import main as evaluate_synthetic_main"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "results = evaluate_synthetic_main(\n",
191 | " profiler_args.run_string, \n",
192 | " profiler_args, \n",
193 | " profiler_args, \n",
194 | " profiler_args.data_lake,\n",
195 | " gold_attributes=data_dict[\"gold_attributes\"], \n",
196 | " stage='extract'\n",
197 | ")\n",
198 | "print(results)"
199 | ]
200 | }
201 | ],
202 | "metadata": {
203 | "kernelspec": {
204 | "display_name": "Python 3 (ipykernel)",
205 | "language": "python",
206 | "name": "python3"
207 | },
208 | "language_info": {
209 | "codemirror_mode": {
210 | "name": "ipython",
211 | "version": 3
212 | },
213 | "file_extension": ".py",
214 | "mimetype": "text/x-python",
215 | "name": "python",
216 | "nbconvert_exporter": "python",
217 | "pygments_lexer": "ipython3",
218 | "version": "3.8.18"
219 | }
220 | },
221 | "nbformat": 4,
222 | "nbformat_minor": 4
223 | }
224 |
--------------------------------------------------------------------------------
/evaporate/weak_supervision/pgm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | import matplotlib.pyplot as plt
4 | import scipy.stats
5 |
6 |
7 |
8 | class Ising():
9 |
10 |
11 |     def __init__(self, m, potentials, thetas=None, vals=[-1, 1]) -> None:
12 | self.m = m
13 | self.v = m + 1 # total number of vertices
14 | self.potentials = potentials
15 |
16 | self.vals = vals
17 | #TODO support values in 0, 1
18 |
19 | if thetas is not None:
20 | assert len(thetas) >= len(potentials), f"Need to specify at least {len(potentials)} theta parameters."
21 | self.thetas = thetas
22 | else:
23 | self.thetas = np.random.rand(len(potentials))
24 |
25 | self.support = np.array(list(map(list, itertools.product(vals, repeat=self.v))))
26 |
27 | self._make_pdf()
28 | self._make_cdf()
29 |
30 | self._get_means()
31 | self._get_balance()
32 | self._get_accs()
33 |
34 | def _exponential_family(self, labels):
35 | x = 0.0
36 | for i in range(len(self.potentials)):
37 | x += self.thetas[i] * labels[self.potentials[i]].prod()
38 |
39 | return np.exp(x)
40 |
41 | def _make_pdf(self):
42 | p = np.zeros(len(self.support))
43 | for i, labels in enumerate(self.support):
44 | p[i] = self._exponential_family(labels)
45 |
46 | self.z = sum(p)
47 |
48 | self.pdf = p/self.z
49 |
50 | def _make_cdf(self):
51 | self.cdf = np.cumsum(self.pdf)
52 |
53 |
54 |
55 |
56 | def joint_p(self, C, values):
57 | p = 0.0
58 | for k, labels in enumerate(self.support):
59 | flag = True
60 | for i in range(len(C)):
61 | prod = labels[C[i]].prod()
62 | if prod != values[i]:
63 | flag = False
64 |
65 | if flag == True:
66 | p += self.pdf[k]
67 |
68 | return p
69 |
70 | def expectation(self, C):
71 | return self.vals[0] * self.joint_p(C, self.vals[0] * np.ones(len(C))) + self.vals[1] * self.joint_p(C, self.vals[1] * np.ones(len(C)))
72 |
73 | def _get_means(self):
74 | self.means = np.zeros(self.m)
75 | for k in range(self.m):
76 | self.means[k] = self.expectation([[k]])
77 |
78 |
79 | def _get_balance(self):
80 | self.balance = self.joint_p([[self.m]], [1])
81 |
82 | # def _get_covariance(self):
83 |
84 | def _get_accs(self):
85 | """
86 | self.accs[k, i, j] = Pr(lf_k = j | y = i) (i, j scaled to -1, 1 if needed)
87 | """
88 | self.accs = np.zeros((self.m, 2, 2))
89 | for k in range(self.m):
90 | self.accs[k, 1, 1] = self.joint_p([[k], [self.m]], [self.vals[1], self.vals[1]]) / self.balance
91 | self.accs[k, 0, 0] = self.joint_p([[k], [self.m]], [self.vals[0], self.vals[0]]) / (1 - self.balance)
92 | self.accs[k, 1, 0] = 1 - self.accs[k, 1, 1]
93 | self.accs[k, 0, 1] = 1 - self.accs[k, 0, 0]
94 |
95 |
96 | def sample(self):
97 | r = np.random.random_sample()
98 | smaller = np.where(self.cdf < r)[0]
99 | if len(smaller) == 0:
100 | i = 0
101 | else:
102 | i = smaller.max() + 1
103 |
104 | return self.support[i]
105 |
106 | def make_data(self, n, has_label = True):
107 | L = np.zeros((n, self.m))
108 | gold = np.zeros(n)
109 | for i in range(n):
110 | l = self.sample()
111 | L[i, :] = l[:self.m]
112 |
113 | if has_label:
114 | gold[i] = l[self.m]
115 |
116 | return L.astype(int), gold.astype(int)
117 |
118 |
119 |
120 | def est_accs(m, vote, gold):
121 | # compute pr(lf | y) accuracies. Each prompt has 4 values (2x2)
122 | # we need to do this on the train/dev set
123 | classes = [0, 1]
124 | gold_idxs = [np.where(gold == -1)[0], np.where(gold == 1)[0]]
125 |
126 | accs = np.zeros((m, 2, 2)) # [i, j, k] = Pr(prompt_i = j| y = k)
127 | for p in range(m):
128 | for i in classes:
129 | for j in classes:
130 | accs[p, i, j] = len(np.where(vote[gold_idxs[i], p] == 2*j-1)[0]) / len(gold_idxs[i])
131 |
132 | return accs
133 |
134 | def est_balance(gold, n):
135 | return len(np.where(gold == 1)[0]) / n
136 |
137 | # Pr(lf votes, y)
138 | def get_cond_probs(m, votes, y, accs, balance):
139 | pr_y = balance if y == 1 else 1 - balance
140 | prod = pr_y
141 | for i in range(m):
142 |         prod *= accs[i, y, int(0.5*(votes[i] + 1))]  # map vote {-1,1} -> index {0,1}; assumes LFs are independent given y
143 | return prod
144 |
145 | # Pr(y = 1 | lf votes)
146 | def get_probs(m, votes, accs, balance):
147 | pos = get_cond_probs(m, votes, 1, accs, balance)
148 | neg = get_cond_probs(m, votes, 0, accs, balance)
149 |
150 | if pos == 0:
151 | return 0
152 | else:
153 | return pos / (pos + neg)
154 |
155 |
156 | def pick_best_prompt(m, vote, gold, n):
157 | # overall accuracies Pr(lf_p = y) on test (we don't know these)
158 | overall_train_acc = np.zeros(m)
159 | for i in range(m):
160 | overall_train_acc[i] = len(np.where((vote[:, i] == gold) == True)[0])/n
161 |
162 | return overall_train_acc.argmax()
163 |
164 |
165 | def main():
166 |
167 | # number of weak labels
168 | m = 5
169 |
170 | # total number of vertices
171 | v = m + 1
172 |
173 | # randomly parametrize exponential family to determine accuracies and correlations
174 | #theta = np.random.rand()
175 | #theta_cliques = (np.random.randint(0, 2, 5)*2 - 1)*theta
176 | #theta = np.random.rand()
177 | #theta_cliques = [1, 1, 1, 1, 1, 1, 1]
178 | thetas = np.random.rand(30)
179 |
180 | # all conditionally independent
181 | potentials = [[5], [0], [1], [4], [0, 5], [1, 5], [2, 5], [3, 5], [4, 5]]
182 |
183 | pgm = Ising(m, potentials, thetas)
184 |
185 | n_train = 10000
186 | vote_train, gold_train = pgm.make_data(n_train)
187 |
188 | n_test = 1000
189 | vote_test, gold_test = pgm.make_data(n_test)
190 |
191 | accs = est_accs(m, vote_train, gold_train)
192 | balance = est_balance(gold_train, n_train)
193 |
194 | nb_output = np.zeros(n_test) # naive bayes
195 | mv_output = np.zeros(n_test)
196 |
197 | nb_err = 0
198 | mv_err = 0
199 |
200 | for i in range(n_test):
201 | nb_output[i] = 2*np.round(get_probs(m, vote_test[i], accs, balance))-1
202 | if nb_output[i] != gold_test[i]:
203 | nb_err += 1
204 |
205 |
206 | # note: play around with MV tie breaking strategy
207 | if len(np.where(vote_test[i] == 1)[0]) >= m / 2:
208 | mv_output[i] = 1
209 | elif len(np.where(vote_test[i] == 1)[0]) < m / 2:
210 | mv_output[i] = -1
211 | else:
212 | mv_output[i] = 2*np.random.randint(0, 2)-1
213 |
214 | if mv_output[i] != gold_test[i]:
215 | mv_err += 1
216 |
217 | nb_acc = 1 - (nb_err / n_test)
218 | mv_acc = 1 - (mv_err / n_test)
219 | #fs_acc = 1 - (fs_err / n_test)
220 |
221 | best_prompt = pick_best_prompt(m, vote_train, gold_train, n_train)
222 |
223 | best_prompt_acc = len(np.where((vote_test[:, best_prompt] == gold_test) == True)[0]) / n_test
224 |
225 | print(f"Naive bayes: {nb_acc}")
226 | print(f"Best prompt: {best_prompt_acc}")
227 | print(f"Majority vote: {mv_acc}")
228 |
229 |
230 | if __name__ == "__main__":
231 | main()
--------------------------------------------------------------------------------
/evaporate/weak_supervision/run_ws.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import json
4 | import sys
5 | import pickle
6 | import random
7 | import cvxpy as cp
8 | import scipy as sp
9 | from tqdm import tqdm
10 | from evaporate.weak_supervision.methods import Aggregator
11 | from metal.label_model import LabelModel
12 | from collections import defaultdict, Counter
13 |
14 | from evaporate.evaluate_synthetic import clean_comparison
15 |
16 |
17 | def get_data(
18 | all_votes,
19 | gold_extractions_file,
20 | attribute='',
21 | has_abstains=1.0,
22 | num_elts = 5,
23 | extraction_fraction_thresh=0.9,
24 | ):
25 | """
26 | Load in dataset from task_name depending on where files are saved.
27 |
28 | - num_elts = number of ``choices'' to use in the multiple-choice setup.
29 |
30 | """
31 | label_name_to_ints = []
32 | has_abstains = has_abstains >= extraction_fraction_thresh
33 |
34 | try:
35 | with open(gold_extractions_file) as f:
36 | gold_extractions = json.load(f)
37 | except:
38 | with open(gold_extractions_file, "rb") as f:
39 | gold_extractions = pickle.load(f)
40 |
41 | test_votes = []
42 | test_golds = []
43 | total_abstains = []
44 | average_unique_votes = []
45 | missing_files = []
46 | random.seed(0)
47 | for file, extractions in tqdm(all_votes.items()):
48 | if file not in gold_extractions:
49 | missing_files.append(file)
50 | continue
51 |
52 | extractions = [clean_comparison(e) for e in extractions]
53 | if has_abstains:
54 | extractions = [e if e else 'abstain' for e in extractions]
55 |
56 | unique_votes = Counter(extractions).most_common(num_elts)
57 | unique_votes = [i for i, _ in unique_votes if i != 'abstain']
58 | average_unique_votes.append(len(unique_votes))
59 | if len(unique_votes) < num_elts:
60 | missing_elts = num_elts - len(unique_votes)
61 | for elt_num in range(missing_elts):
62 | unique_votes.append(f"dummy{elt_num}")
63 |
64 | random.shuffle(unique_votes)
65 | label_name_to_int = {elt: j for j, elt in enumerate(unique_votes)}
66 | label_name_to_ints.append(label_name_to_int)
67 |
68 | test_votes.append(np.array(
69 | [label_name_to_int[ans] if ans in label_name_to_int else -1 for ans in extractions]
70 | ))
71 |
72 | num_abstains = len([a for a in extractions if a not in label_name_to_int])
73 | total_abstains.append(num_abstains)
74 |
75 | # golds are just for class balance purposes
76 | if attribute in gold_extractions[file]:
77 | gold = gold_extractions[file][attribute]
78 | elif clean_comparison(attribute) in gold_extractions[file]:
79 | gold = gold_extractions[file][clean_comparison(attribute)]
80 | else:
81 | gold = ""
82 | gold = clean_comparison(gold)
83 | if gold in label_name_to_int:
84 | test_golds.append(label_name_to_int[gold])
85 | else:
86 | gold = random.sample(range(len(label_name_to_int)), 1)
87 | test_golds.append(gold[0])
88 |
89 | test_votes = np.array(test_votes)
90 | test_gold = np.array(test_golds)
91 |
92 | test_votes = test_votes.astype(int)
93 | test_gold = test_gold.astype(int)
94 |
95 | print(f"Average abstains across documents: {np.mean(total_abstains)}")
96 | print(f"Average unique votes per document: {np.mean(average_unique_votes)}")
97 |
98 | return test_votes, test_gold, label_name_to_ints, missing_files
99 |
100 |
101 | def get_top_deps_from_inverse_sig(J, k):
102 | m = J.shape[0]
103 | deps = []
104 | sorted_idxs = np.argsort(np.abs(J), axis=None)
105 |     # take the k largest-magnitude entries of J as candidate dependencies
106 | idxs = sorted_idxs[-k:]
107 | for idx in idxs:
108 | i = int(np.floor(idx / m))
109 | j = idx % m
110 | if (j, i) in deps:
111 | continue
112 | deps.append((i, j))
113 | return deps
114 |
115 |
116 | def learn_structure(L):
117 | m = L.shape[1]
118 | n = float(np.shape(L)[0])
119 | sigma_O = (np.dot(L.T,L))/(n-1) - np.outer(np.mean(L,axis=0), np.mean(L,axis=0))
120 |
121 |     # symmetrize the empirical covariance for numerical stability
122 | O = 1/2*(sigma_O+sigma_O.T)
123 | O_root = np.real(sp.linalg.sqrtm(O))
124 |
125 | # low-rank matrix
126 | L_cvx = cp.Variable([m,m], PSD=True)
127 |
128 | # sparse matrix
129 | S = cp.Variable([m,m], PSD=True)
130 |
131 | # S-L matrix
132 | R = cp.Variable([m,m], PSD=True)
133 |
134 | #reg params
135 | lam = 1/np.sqrt(m)
136 | gamma = 1e-8
137 |
138 | objective = cp.Minimize(0.5*(cp.norm(R @ O_root, 'fro')**2) - cp.trace(R) + lam*(gamma*cp.pnorm(S,1) + cp.norm(L_cvx, "nuc")))
139 | constraints = [R == S - L_cvx, L_cvx>>0]
140 |
141 | prob = cp.Problem(objective, constraints)
142 | result = prob.solve(verbose=False, solver=cp.SCS)
143 | opt_error = prob.value
144 |
145 | #extract dependencies
146 | J_hat = S.value
147 |
148 | if J_hat is None:
149 | raise ValueError("CVXPY failed to solve the structured learning problem, use result without dependencies.")
150 |
151 | for i in range(m):
152 | J_hat[i, i] = 0
153 | return J_hat
154 |
155 |
156 | def learn_structure_multiclass(L, k):
157 | m = L.shape[1]
158 | J_hats = np.zeros((k, m, m))
159 | for c in range(k):
160 |
161 | all_votes_c = np.where(L == c, 1, 0)
162 | J_hats[c] = learn_structure(all_votes_c)
163 |
164 | return J_hats
165 |
166 |
167 | def get_min_off_diagonal(J_hat):
168 | J_hat_copy = J_hat.copy()
169 | for i in range(len(J_hat_copy)):
170 | J_hat_copy[i, i] = np.inf
171 | return np.abs(J_hat_copy).min()
172 |
173 |
174 | def run_ws(
175 | all_votes,
176 | gold_extractions_file,
177 | symmetric=True,
178 | attribute='',
179 | has_abstains=1.0,
180 | extraction_fraction_thresh=0.9,
181 | ):
182 | test_votes, test_gold, label_name_to_ints, missing_files = get_data(
183 | all_votes,
184 | gold_extractions_file,
185 | attribute=attribute,
186 | has_abstains=has_abstains,
187 | extraction_fraction_thresh=extraction_fraction_thresh,
188 | )
189 |
190 | classes = np.sort(np.unique(test_gold))
191 | vote_classes = np.sort(np.unique(test_votes))
192 | n_test, m = test_votes.shape
193 | k = len(classes)
194 | abstains = len(vote_classes) == len(classes) + 1
195 | print(f"Abstains: {abstains}")
196 |
197 | m = test_votes.shape[1]
198 | all_votes = test_votes
199 |
200 | label_model = LabelModel(k=k, seed=123)
201 |
202 | # scale to 0, 1, 2 (0 is abstain)
203 | test_votes_scaled = (test_votes + np.ones((n_test, m))).astype(int)
204 | test_gold_scaled = (test_gold + np.ones(n_test)).astype(int)
205 | all_votes_scaled = test_votes_scaled
206 |
207 | label_model.train_model(
208 | all_votes_scaled,
209 | Y_dev=test_gold_scaled,
210 | abstains=abstains,
211 | symmetric=symmetric,
212 | n_epochs=10000,
213 | log_train_every=1000,
214 | lr=0.00001
215 | )
216 |
217 | print('Trained Label Model Metrics (No deps):')
218 | scores, preds = label_model.score(
219 | (test_votes_scaled, test_gold_scaled),
220 | metric=['accuracy','precision', 'recall', 'f1']
221 | )
222 | print(scores)
223 | all_votes_no_abstains = np.where(all_votes == -1, 0, all_votes)
224 |
225 | used_deps = False
226 | try:
227 | if len(classes) == 2:
228 | J_hat = learn_structure(all_votes_no_abstains)
229 | else:
230 | J_hats = learn_structure_multiclass(all_votes_no_abstains, len(classes))
231 | J_hat = J_hats.mean(axis=0)
232 |
233 | # if values in J are all too large, then everything is connected / structure learning isn't learning the right thing. Don't model deps then
234 | min_entry = get_min_off_diagonal(J_hat)
235 | if min_entry < 1:
236 | deps = get_top_deps_from_inverse_sig(J_hat, 1)
237 | print("Recovered dependencies: ", deps)
238 |
239 | label_model.train_model(
240 | all_votes_scaled,
241 | Y_dev=test_gold_scaled,
242 | abstains=abstains,
243 | symmetric=symmetric,
244 | n_epochs=80000,
245 | log_train_every=1000,
246 | lr=0.000001,
247 | deps=deps
248 | )
249 | print('Trained Label Model Metrics (with deps):')
250 | scores, preds = label_model.score(
251 | (test_votes_scaled, test_gold_scaled),
252 | metric=['accuracy', 'precision', 'recall', 'f1']
253 | )
254 | print(scores)
255 | used_deps = True
256 | except:
257 | print(f"Not modeling dependencies.")
258 |
259 | # convert the preds back
260 | mapped_preds = []
261 | for label_name_to_int, pred in tqdm(zip(label_name_to_ints, preds)):
262 | int_to_label_name = {v:k for k, v in label_name_to_int.items()}
263 | try:
264 | pred = int_to_label_name[pred-1]
265 | except:
266 | pred = ''
267 | mapped_preds.append(pred)
268 | return mapped_preds, used_deps, missing_files
269 |
270 |
271 | # run_ws() requires the aggregated votes and a gold extractions file, so it is
272 | # driven by the profiler (combiner_mode "ws") rather than called without arguments.
273 |
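274 | if __name__ == "__main__":
275 |     # Toy structure-learning sketch on synthetic votes (assumed setup, not
276 |     # the paper's data): five ~80%-accurate binary LFs where LF 4 mostly
277 |     # copies LF 3, so the top recovered dependency should typically involve
278 |     # the pair (3, 4).
279 |     np.random.seed(0)
280 |     n, num_lfs = 500, 5
281 |     y = np.random.randint(0, 2, size=n)
282 |     L = np.zeros((n, num_lfs), dtype=int)
283 |     for j in range(num_lfs):
284 |         flip = np.random.rand(n) < 0.2
285 |         L[:, j] = np.where(flip, 1 - y, y)
286 |     copy_noise = np.random.rand(n) < 0.05
287 |     L[:, 4] = np.where(copy_noise, 1 - L[:, 3], L[:, 3])
288 |     J_hat = learn_structure(L)
289 |     print(get_top_deps_from_inverse_sig(J_hat, 1))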
--------------------------------------------------------------------------------
/evaporate/prompts.py:
--------------------------------------------------------------------------------
1 | ############################ SCHEMA ID PROMPTS ############################
2 | SCHEMA_ID_PROMPTS = [
3 | f"""Sample text:
4 | | Charles III |
5 | | Mary Simon |
6 | Provinces and Territories
7 |
8 | - Saskatchewan
9 | - Manitoba
10 | - Ontario
11 | - Quebec
12 | - New Brunswick
13 | - Prince Edward Island
14 | - Nova Scotia
15 | - Newfoundland and Labrador
16 | - Yukon
17 | - Nunavut
18 | - Northwest Territories
19 |
20 |
21 | Question: List all relevant attributes about 'Canada' that are exactly mentioned in this sample text if any.
22 | Answer:
23 | - Monarch: Charles III
24 | - Governor General: Mary Simon
25 | - Provinces and Territories: Saskatchewan, Manitoba, Ontario, Quebec, New Brunswick, Prince Edward Island, Nova Scotia, Newfoundland and Labrador, Yukon, Nunavut, Northwest Territories
26 |
27 | ----
28 |
29 | Sample text:
30 | Patient birth date: 1990-01-01
31 | Prescribed medication: aspirin, ibuprofen, acetaminophen
32 | Prescribed dosage: 1 tablet, 2 tablets, 3 tablets
33 | Doctor's name: Dr. Burns
34 | Date of discharge: 2020-01-01
35 | Hospital address: 123 Main Street, New York, NY 10001
36 |
37 | Question: List all relevant attributes about 'medications' that are exactly mentioned in this sample text if any.
38 | Answer:
39 | - Prescribed medication: aspirin, ibuprofen, acetaminophen
40 | - Prescribed dosage: 1 tablet, 2 tablets, 3 tablets
41 |
42 | ----
43 |
44 | Sample text:
45 | {{chunk:}}
46 |
47 | Question: List all relevant attributes about '{{topic:}}' that are exactly mentioned in this sample text if any.
48 | Answer:"""
49 | ]
50 |
51 |
52 | ############################ PROMPTS FOR EXTRACTING A SPECIFIC FIELD BY DIRECTLY GIVING THE MODEL THE CONTEXT ############################
53 | METADATA_EXTRACTION_WITH_LM = [
54 | f"""Here is a file sample:
55 |
56 | Location |
57 | Cupertino, CaliforniaSince 1987 |
58 |
59 | Question: Return the full "location" span of this sample if it exists, otherwise output [].
60 | Answer: ['Cupertino, California Since 1987']
61 |
62 | ----
63 |
64 | Here is a file sample:
65 |
66 | {{chunk:}}
67 |
68 | Question: Return the full "{{attribute:}}" span of this sample if it exists, otherwise output [].
69 | Answer:""",
70 | ]
71 |
72 |
73 | METADATA_EXTRACTION_WITH_LM_ZERO_SHOT = [
74 | f"""Sample text:
75 |
76 | {{chunk:}}
77 |
78 | Question: What is the "{{attribute:}}" value in the text?
79 | Answer:"""
80 | ]
81 |
82 | EXTRA_PROMPT = [
83 | f"""Here is a file sample:
84 |
85 |
86 |
87 | Question: Return the full "price" from this sample if it exists, otherwise output [].
88 | Answer: ['$550']
89 |
90 | ----
91 |
92 | Here is a file sample:
93 |
94 | {{chunk:}}
95 |
96 | Question: Return the full "{{attribute:}}" from this sample if it exists, otherwise output [].
97 | Answer:""",
98 | ]
99 |
100 | METADATA_EXTRACTION_WITH_LM_CONTEXT = [
101 | f"""Here is a file sample:
102 |
103 | A. 510(k) Number:
104 | k143467
105 |
106 | Question: Return the full "510(k) Number" from this sample if it exists and the context around it, otherwise output [].
107 | Answer: [510(k) Number: k143467]
108 |
109 | ----
110 |
111 | Here is a file sample:
112 |
113 | The iphone price increases a lot this there. Each iphone's price is as high as 1000$.
114 |
115 | Question: Return the full "price" from this sample if it exists and the context around it, otherwise output [].
116 | Answer: [Each iphone's price is as high as 1000$]
117 |
118 | ----
119 |
120 | Here is a file sample:
121 |
122 | {{chunk:}}
123 |
124 | Question: Return the full "{{attribute:}}" from this sample if it exists and the context around it, otherwise output [].
125 | Answer:""",
126 | ]
127 |
128 | IS_VALID_ATTRIBUTE = [
129 | f"""Question: Could "2014" be a "year" value in a "students" database?
130 | Answer: Yes
131 |
132 | ----
133 |
134 | Question: Could "cupcake" be a "occupation" value in a "employee" database?
135 | Answer: No
136 |
137 | ----
138 |
139 | Question: Could "''" be a "animal" value in a "zoo" database?
140 | Answer: No
141 |
142 | ----
143 |
144 | Question: Could "police officer" be a "occupation" value in a "employee" database?
145 | Answer: Yes
146 |
147 | ----
148 |
149 | Question: Could "{{value:}}" be a "{{attr_str:}}" value in a {{topic:}} database?
150 | Answer:"""
151 | ]
152 |
153 |
154 | PICK_VALUE = [
155 | f"""Examples:
156 | - 32
157 | - 2014
158 | - 99.4
159 | - 2012
160 |
161 | Question: Which example is a "year"?
162 | Answer: 2012, 2014
163 |
164 | ----
165 |
166 | Examples:
167 | - police officer
168 | - occupation
169 |
170 | Question: Which example is a "occupation"?
171 | Answer: police officer
172 |
173 | ----
174 |
175 | Examples:
176 | {{pred_str:}}
177 |
178 | Question: Which example is a "{{attribute:}}"?
179 | Answer:"""
180 | ]
181 |
182 |
183 | PICK_VALUE_CONTEXT = [
184 | f"""Here are file samples:
185 |
186 | -The purpose for submission is to obtain substantial equivalence determination for the illumigene HSV 1&2 DNA Amplification Assay.
187 | -The purpose for submission of this document is not specified in the provided sample.
188 | -The purpose for submission of this file is not specified.
189 |
190 | Question: Extract "the purpose for submission" from the right sample , otherwise output [].
191 | Answer: to obtain substantial equivalence determination for the illumigene HSV 1&2 DNA Amplification Assay
192 |
193 | ----
194 |
195 | Here are file samples:
196 |
197 | {{pred_str:}}
198 |
199 | Question: Return the full "{{attribute:}}" from this sample if it exists, otherwise output [].
200 | Answer:""",
201 | ]
202 |
203 |
204 |
205 | ############################## PROMPTS TO GENERATE FUNCTIONS THAT PARSE FOR A SPECIFIC FIELD ##############################
206 | METADATA_GENERATION_FOR_FIELDS = [
207 | # base prompt
208 | f"""Here is a sample of text:
209 |
210 | {{chunk:}}
211 |
212 |
213 | Question: Write a python function to extract the entire "{{attribute:}}" field from text, but not any other metadata. Return the result as a list.
214 |
215 |
216 | import re
217 |
218 | def get_{{function_field:}}_field(text: str):
219 | \"""
220 | Function to extract the "{{attribute:}} field".
221 | \"""
222 | """,
223 |
224 | # prompt with flexible library imports
225 | f"""Here is a file sample:
226 |
227 | DESCRIPTION: This file answers the question, "How do I sort a dictionary by value?"
228 | DATES MODIFIED: The file was modified on the following dates:
229 | 2009-03-05T00:49:05
230 | 2019-04-07T00:22:14
231 | 2011-11-20T04:21:49
232 | USERS: The users who modified the file are:
233 | Jeff Jacobs
234 | Richard Smith
235 | Julia D'Angelo
236 | Rebecca Matthews
237 | FILE TYPE: This is a text file.
238 |
239 | Question: Write a python function called "get_dates_modified_field" to extract the "DATES MODIFIED" field from the text. Include any imports.
240 |
241 | import re
242 |
243 | def get_dates_modified_field(text: str):
244 | \"""
245 | Function to extract the dates modified.
246 | \"""
247 | parts= text.split("USERS")[0].split("DATES MODIFIED")[-1]
248 | pattern = r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}'
249 | return re.findall(pattern, text)
250 |
251 | ----
252 |
253 | Here is a file sample:
254 |
255 | U.S. GDP Rose 2.9% in the Fourth Quarter After a Year of High Inflation - WSJ
256 |
257 |
258 |
259 |
260 |
261 | Question: Write a python function called "get_date_published_field" to extract the "datePublished" field from the text. Include any imports.
262 |
263 | from bs4 import BeautifulSoup
264 |
265 | def get_date_published_field(text: str):
266 | \"""
267 | Function to extract the date published.
268 | \"""
269 | soup = BeautifulSoup(text, parser="html.parser")
270 | date_published_field = soup.find('meta', itemprop="datePublished")
271 | date_published_field = date_published_field['content']
272 | return date_published_field
273 |
274 | ----
275 |
276 | Here is a sample of text:
277 |
278 | {{chunk:}}
279 |
280 | Question: Write a python function called "get_{{function_field:}}_field" to extract the "{{attribute:}}" field from the text. Include any imports."""
281 | ]
282 |
283 |
284 | class Step:
285 | def __init__(self, prompt) -> None:
286 | self.prompt = prompt
287 |
288 | def execute(self):
289 | pass
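290 | 
291 | if __name__ == "__main__":
292 |     # Toy sketch: the templates above are plain ``str.format`` strings; the
293 |     # schema-ID prompt takes `chunk` and `topic` keyword arguments, exactly
294 |     # as schema_identification.py fills them.
295 |     print(SCHEMA_ID_PROMPTS[0].format(chunk="Patient birth date: 1990-01-01", topic="patients")[:400])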
290 |
291 |
292 |
--------------------------------------------------------------------------------
/evaporate/schema_identification.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import statistics
4 | import random
5 | from tqdm import tqdm
6 | from collections import Counter, defaultdict
7 | from typing import List, Dict, Tuple, Set
8 |
9 | from evaporate.prompts import Step, SCHEMA_ID_PROMPTS
10 | from evaporate.utils import apply_prompt
11 | from evaporate.profiler_utils import clean_metadata
12 |
13 |
14 | def directly_extract_from_chunks_w_value(
15 | file2chunks,
16 | sample_files,
17 | manifest_session,
18 | overwrite_cache=False,
19 | topic=None,
20 | use_dynamic_backoff=True,
21 | ):
22 | total_tokens_prompted = 0
23 | field2value = defaultdict(list)
24 | field2count = Counter()
25 | file2results = defaultdict()
26 | num_chunks_per_file = [len(file2chunks[file]) for file in file2chunks]
27 | avg_num_chunks_per_file = statistics.mean(num_chunks_per_file)
28 | stdev_num_chunks_per_file = statistics.stdev(num_chunks_per_file)
29 |
30 | for i, file in enumerate(sample_files):
31 | chunks = file2chunks[file]
32 | print(f"Chunks in sample file {file}: {len(chunks)}")
33 |
34 | for i, file in tqdm(
35 | enumerate(sample_files),
36 | total=len(sample_files),
37 | desc="Directly extracting metadata from chunks",
38 | ):
39 | chunks = file2chunks[file]
40 | extractionset = set()
41 | file_results = {}
42 | for chunk_num, chunk in enumerate(chunks):
43 | if (chunk_num > avg_num_chunks_per_file + stdev_num_chunks_per_file) and use_dynamic_backoff:
44 | break
45 | prompt_template = SCHEMA_ID_PROMPTS[0]
46 | prompt = prompt_template.format(chunk=chunk, topic=topic)
47 | try:
48 | result, num_toks = apply_prompt(
49 | Step(prompt),
50 | max_toks=500,
51 | manifest=manifest_session,
52 | overwrite_cache=overwrite_cache
53 | )
54 | except:
55 | print("Failed to apply prompt to chunk.")
56 | continue
57 | total_tokens_prompted += num_toks
58 | result = result.split("---")[0].strip("\n")
59 | results = result.split("\n")
60 | results = [r.strip("-").strip() for r in results]
61 | results = [r[2:].strip() if len(r) > 2 and r[1] == "." else r for r in results ]
62 | for result in results:
63 | try:
64 | field = result.split(": ")[0].strip(":")
65 | value = ": ".join(result.split(": ")[1:])
66 | except:
67 | print(f"Skipped: {result}")
68 | continue
69 | field_versions = [
70 | field,
71 | field.replace(" ", ""),
72 | field.replace("-", ""),
73 | field.replace("_", ""),
74 | ]
75 | if not any([f.lower() in chunk.lower() for f in field_versions]) and use_dynamic_backoff:
76 | continue
77 | if not value and use_dynamic_backoff:
78 | continue
79 | field = field.lower().strip("-").strip("_").strip(" ").strip(":")
80 | if field in extractionset and use_dynamic_backoff:
81 | continue
82 | field2value[field].append(value)
83 | extractionset.add(field)
84 | field2count[field] += 1
85 | file_results[field] = value
86 | file2results[file] = file_results
87 | return field2value, field2count, total_tokens_prompted
88 |
89 |
90 | def get_metadata_string_w_value(field2value, exclude=[], key=0):
91 | field2num_extractions = Counter()
92 | for field, values in field2value.items():
93 | field2num_extractions[field] += len(values)
94 |
95 | reranked_metadata = {}
96 | try:
97 | max_count = field2num_extractions.most_common(1)[0][1]
98 | except:
99 | return ''
100 | fields = []
101 | sort_field2num_extractions = sorted(
102 | field2num_extractions.most_common(),
103 | key=lambda x: (x[1], x[0]),
104 | reverse=True
105 | )
106 | for item in sort_field2num_extractions:
107 | field, count = item[0], item[1]
108 | if field.lower() in exclude:
109 | continue
110 | if count == 1 and max_count > 1:
111 | continue
112 | idx = min(key, len(field2value[field]) - 1)
113 | values = [field2value[field][idx]]
114 | if idx < len(field2value[field]) - 1:
115 | values.append(field2value[field][idx + 1])
116 | reranked_metadata[field] = values
117 | if len(reranked_metadata) > 200:
118 | break
119 | fields.append(field)
120 |
121 | random.seed(key)
122 |     keys = list(reranked_metadata.keys())
123 |     random.shuffle(keys)
124 | reordered_dict = {}
125 | for key in keys:
126 | reordered_dict[key] = reranked_metadata[key]
127 | reranked_metadata_str = str(reordered_dict)
128 | return reranked_metadata_str
129 |
130 |
131 | def rerank(
132 | field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count,
133 | most_in_context_example, topic, manifest_session, overwrite_cache=False
134 | ):
135 | total_tokens_prompted = 0
136 | votes_round1 = Counter()
137 | for i in range(3):
138 | reranked_metadata_str = get_metadata_string_w_value(field2value, exclude=exclude, key=i)
139 | if not reranked_metadata_str:
140 | continue
141 |
142 | prompt = \
143 | f"""{most_in_context_example}Attributes:
144 | {reranked_metadata_str}
145 |
146 | List the most useful keys to include in a SQL database about "{topic}", if any.
147 | Answer:"""
148 | try:
149 | result, num_toks = apply_prompt(
150 | Step(prompt),
151 | max_toks=500,
152 | manifest=manifest_session,
153 | overwrite_cache=overwrite_cache,
154 | )
155 | except:
156 | print("Failed to apply prompt")
157 | continue
158 | total_tokens_prompted += num_toks
159 | result = result.split("---")[0].strip("\n")
160 | results = result.split("\n")
161 | result = results[0].replace("[", "").replace("]", "").replace("'", "").replace('"', '')
162 | result = result.split(", ")
163 | result = [r.lower() for r in result]
164 |
165 | indices = [idx for idx, r in enumerate(result) if not r]
166 | if result and indices:
167 | result = result[:indices[0]]
168 |
169 | # Deduplicate but preserve order
170 | result = list(dict.fromkeys(result))
171 | for r in result:
172 | r = r.strip("_").strip("-")
173 | r = r.strip("'").strip('"').strip()
174 | if not r or r in exclude or r not in base_extraction_count:
175 | continue
176 | votes_round1[r] += 2
177 |
178 | fields = sorted(list(votes_round1.keys()))
179 | for r in fields:
180 | r = r.strip("_").strip("-")
181 | r = r.strip("'").strip('"').strip()
182 | if not r or r in exclude or r not in base_extraction_count:
183 | continue
184 | if votes_round1[r] > 1:
185 | cleaned_counter[r] = votes_round1[r] * base_extraction_count[r]
186 | order_of_addition.append(r)
187 | else:
188 | cleaned_counter[r] = base_extraction_count[r]
189 | order_of_addition.append(r)
190 | exclude.append(r)
191 |
192 | return cleaned_counter, order_of_addition, exclude, total_tokens_prompted
193 |
194 |
195 | def rerank_metadata(
196 | base_extraction_count, field2value, topic, manifest_session, overwrite_cache
197 | ):
198 |
199 | most_in_context_example = \
200 | """Attributes:
201 | {'name': 'Jessica', 'student major': 'Computer Science', 'license': 'accredited', 'college name': 'University of Michigan', 'GPA': '3.9', 'student email': 'jess@umich.edu', 'rating': '42', 'title': 'details'}
202 |
203 | List the most useful keys to include in a SQL database for "students", if any.
204 | Answer: ['name', 'student major', 'college name', 'GPA', 'student email']
205 |
206 | ----
207 |
208 | """
209 |
210 | total_tokens_prompted = 0
211 | cleaned_counter = Counter()
212 | exclude = []
213 | order_of_addition = []
214 |
215 |     cleaned_counter, order_of_addition, exclude, num_toks_1 = rerank(
216 |         field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count,
217 |         most_in_context_example, topic, manifest_session, overwrite_cache=overwrite_cache
218 |     )
219 | 
220 |     cleaned_counter, order_of_addition, exclude, num_toks_2 = rerank(
221 |         field2value, exclude, cleaned_counter, order_of_addition, base_extraction_count,
222 |         most_in_context_example, topic, manifest_session, overwrite_cache=overwrite_cache
223 |     )
224 |     total_tokens_prompted += num_toks_1 + num_toks_2  # accumulate both rerank passes
225 | fields = sorted(list(base_extraction_count.keys()))
226 | for field in fields:
227 | if field not in cleaned_counter:
228 | cleaned_counter[field] = base_extraction_count[field] / 2
229 | order_of_addition.append(field)
230 | return cleaned_counter, total_tokens_prompted, order_of_addition
231 |
232 |
233 | #################### SAVE GENERATIVE INDEX OF FILE BASED METADATA #########################
234 | def identify_schema(run_string, args, file2chunks: Dict, file2contents: Dict, sample_files: List, manifest_sessions: Dict, group_name: str, profiler_args):
235 | # get sample and eval files, convert the sample scripts to chunks
236 | random.seed(0)
237 | total_tokens_prompted = 0
238 |
239 | field2value, extract_w_value, num_toks = directly_extract_from_chunks_w_value(
240 | file2chunks,
241 | sample_files,
242 | manifest_sessions[profiler_args.GOLD_KEY],
243 | overwrite_cache=profiler_args.overwrite_cache,
244 | topic=args.topic,
245 | use_dynamic_backoff=profiler_args.use_dynamic_backoff,
246 | )
247 | total_tokens_prompted += num_toks
248 |
249 | base_extraction_count, num_toks, order_of_addition = rerank_metadata(
250 | extract_w_value,
251 | field2value,
252 | args.topic,
253 | manifest_sessions[profiler_args.GOLD_KEY],
254 | profiler_args.overwrite_cache,
255 | )
256 | total_tokens_prompted += num_toks
257 |
258 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json", "w") as f:
259 | json.dump(base_extraction_count, f)
260 |
261 | with open(f"{args.generative_index_path}/{run_string}_order_of_addition.json", "w") as f:
262 | json.dump(order_of_addition, f)
263 |
264 | return total_tokens_prompted
265 |
--------------------------------------------------------------------------------
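
A toy walkthrough (not part of the repo) of the vote arithmetic in rerank()/rerank_metadata() above: each attribute the FM lists earns 2 votes per round, and any attribute with more than one vote has its raw extraction count scaled by its vote total, while the rest keep their raw count.

    from collections import Counter

    # raw extraction counts observed across the sample files (toy values)
    base_extraction_count = Counter({"name": 8, "gpa": 5, "rating": 2})

    # round one: the FM's answer lists ["name", "gpa"], worth 2 votes each
    votes_round1 = Counter()
    for picked in ["name", "gpa"]:
        votes_round1[picked] += 2

    cleaned_counter = Counter()
    for field in sorted(base_extraction_count):
        if votes_round1[field] > 1:
            cleaned_counter[field] = votes_round1[field] * base_extraction_count[field]
        else:
            cleaned_counter[field] = base_extraction_count[field]

    print(cleaned_counter)  # Counter({'name': 16, 'gpa': 10, 'rating': 2})
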
/evaporate/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import Counter, defaultdict
4 |
5 | from manifest import Manifest
6 | from evaporate.configs import get_args
7 | from evaporate.prompts import Step
8 | from openai import OpenAI
9 |
10 | cur_idx = 0
11 | TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
12 | # If using Together AI, set the TOGETHER_API_KEY environment variable to your API key.
13 |
14 |
15 | def together_call(prompt, model, streaming = False, max_tokens = 1024):
16 | client = OpenAI(
17 | api_key=TOGETHER_API_KEY,
18 | base_url='https://api.together.xyz',
19 |
20 | )
21 | messages = [{
22 | "role": "system",
23 | "content": "You are an AI assistant",
24 | }, {
25 | "role": "user",
26 | "content": prompt,
27 | }]
28 | chat_completion = client.chat.completions.create(messages=messages,
29 | model=model,
30 | max_tokens=max_tokens,
31 | #response_format={ "type": "json_object" },
32 | stream=streaming)
33 | response = chat_completion.choices[0].message.content
34 | return response
35 |
36 | def apply_prompt(step : Step, max_toks = 50, do_print=False, manifest=None, overwrite_cache=False):
37 | global cur_idx
38 | manifest_lst = manifest.copy()
39 | if len(manifest) == 1:
40 | manifest = manifest_lst[0]
41 | else:
42 | manifest = manifest_lst[cur_idx]
43 |
44 | # sometimes we want to rotate keys
45 |     cur_idx = cur_idx + 1
46 |     if cur_idx >= len(manifest_lst):  # wrap around so every key gets used
47 |         cur_idx = 0
48 |
49 | prompt = step.prompt
50 | response, num_tokens = get_response(
51 | prompt,
52 | manifest,
53 | max_toks = max_toks,
54 | overwrite=overwrite_cache,
55 | stop_token="---"
56 | )
57 | step.response = response
58 | if do_print:
59 | print(response)
60 | return response, num_tokens
61 |
62 |
63 | def get_file_attribute(attribute):
64 | attribute = attribute.lower()
65 | attribute = attribute.replace("/", "_").replace(")", "").replace("-", "_")
66 | attribute = attribute.replace("(", "").replace(" ", "_")
67 | if len(attribute) > 30:
68 | attribute = attribute[:30]
69 | return attribute
70 |
71 |
72 | def get_all_files(data_dir):
73 | files = []
74 | for file in os.listdir(data_dir):
75 | if os.path.isfile(os.path.join(data_dir, file)):
76 | files.append(os.path.join(data_dir, file))
77 | else:
78 | files.extend(get_all_files(os.path.join(data_dir, file)))
79 | return files
80 |
81 |
82 | def get_directory_hierarchy(data_dir):
83 | if not data_dir.endswith("/") and os.path.isdir(data_dir):
84 | data_dir = data_dir + "/"
85 | directories2subdirs = defaultdict(list)
86 | for file in os.listdir(data_dir):
87 | new_dir = os.path.join(data_dir, file)
88 | if not new_dir.endswith("/") and os.path.isdir(new_dir):
89 | new_dir = new_dir + "/"
90 | if os.path.isdir(new_dir):
91 | directories2subdirs[data_dir].append(new_dir)
92 | if os.listdir(new_dir):
93 | more_subdirs = get_directory_hierarchy(new_dir)
94 | for k, v in more_subdirs.items():
95 | directories2subdirs[k].extend(v)
96 | else:
97 | directories2subdirs[new_dir] = []
98 | else:
99 | directories2subdirs[data_dir].append(new_dir)
100 | return directories2subdirs
101 |
102 |
103 | def get_unique_file_types(files):
104 | suffix2file = {}
105 | suffix2count = Counter()
106 | for file in files:
107 | suffix = file.split(".")[-1]
108 | if not suffix:
109 | suffix = "txt"
110 | suffix2count[suffix] += 1
111 | if suffix not in suffix2file:
112 | suffix2file[suffix] = file
113 | return suffix2file, suffix2count
114 |
115 |
116 | def get_structure(dataset_name, profiler_args):
117 | args = get_args(profiler_args)
118 | if not os.path.exists(args.cache_dir):
119 | os.makedirs(args.cache_dir)
120 |
121 | if not os.path.exists(args.generative_index_path):
122 | os.makedirs(args.generative_index_path)
123 |
127 | # all files
128 | cache_path = f"{args.cache_dir}/all_files.json"
129 | if not os.path.exists(cache_path) or args.overwrite_cache:
130 | files = get_all_files(args.data_dir)
131 | with open(cache_path, "w") as f:
132 | json.dump(files, f)
133 | else:
134 | with open(cache_path) as f:
135 | files = json.load(f)
136 |
137 | # all directories
138 | cache_path = f"{args.cache_dir}/all_dirs.json"
139 | if not os.path.exists(cache_path) or args.overwrite_cache:
140 | directory_hierarchy = get_directory_hierarchy(args.data_dir)
141 | with open(cache_path, "w") as f:
142 | json.dump(directory_hierarchy, f)
143 | else:
144 | with open(cache_path) as f:
145 | directory_hierarchy = json.load(f)
146 |
147 | suffix2file, suffix2count = get_unique_file_types(files)
148 | file_examples = "\n".join(list(suffix2file.values()))
149 | file_types = ", ".join((suffix2file.keys()))
150 | return directory_hierarchy, files, file_examples, file_types, args
151 |
152 |
153 | def get_files_in_group(dir_path):
154 | file_group = []
155 |     for root, _dirs, files in os.walk(dir_path, topdown=True):
156 |         files = [f"{root}/{f}" for f in files]
157 |         file_group.extend(files)
158 |     print(f"Working with a sample size of {len(file_group)} files.")
159 | return file_group
160 |
161 |
162 | # MANIFEST
163 | def get_manifest_sessions(MODELS, MODEL2URL=None, KEYS=[]):
164 | manifest_sessions = defaultdict(list)
165 | for model in MODELS:
166 | if any(kwd in model for kwd in ["davinci", "curie", "babbage", "ada", "cushman"]):
167 | if not KEYS:
168 | raise ValueError("You must provide a list of keys to use these models.")
169 | for key in KEYS:
170 | manifest, model_name = get_manifest_session(
171 | client_name="openai",
172 | client_engine=model,
173 | client_connection=key,
174 | )
175 | manifest_sessions[model].append(manifest)
176 | elif any(kwd in model for kwd in ["gpt-4", "gpt-3.5"]):
177 | if not KEYS:
178 | raise ValueError("You must provide a list of keys to use these models.")
179 | for key in KEYS:
180 | manifest, model_name = get_manifest_session(
181 | client_name="openaichat",
182 | client_engine=model,
183 | client_connection=key,
184 | )
185 | manifest_sessions[model].append(manifest)
186 | else:
187 |             if MODEL2URL is None or model not in MODEL2URL:
188 | manifest = {}
189 | manifest["__name"] = model
190 | print("using together AI")
191 | else:
192 | print("using huggingface")
193 | manifest, model_name = get_manifest_session(
194 | client_name="huggingface",
195 | client_engine=model,
196 | client_connection=MODEL2URL[model],
197 | )
198 | manifest_sessions[model].append(manifest)
199 | return manifest_sessions
200 |
201 |
202 | def get_manifest_session(
203 | client_name="huggingface",
204 | client_engine=None,
205 | client_connection="http://127.0.0.1:5000",
206 | cache_connection=None,
207 | temperature=0,
208 | top_p=1.0,
209 | ):
210 | if client_name == "huggingface" and temperature == 0:
211 | params = {
212 | "temperature": 0.001,
213 | "do_sample": False,
214 | "top_p": top_p,
215 | }
216 | elif client_name in {"openai", "ai21", "openaichat"}:
217 | params = {
218 | "temperature": temperature,
219 | "top_p": top_p,
220 | "engine": client_engine,
221 | }
222 | else:
223 | raise ValueError(f"{client_name} is not a valid client name")
224 |
225 | cache_params = {
226 | "cache_name": "sqlite",
227 | "cache_connection": cache_connection,
228 | }
229 |
230 | manifest = Manifest(
231 | client_name=client_name,
232 | client_connection=client_connection,
233 | **params,
234 | **cache_params,
235 | )
236 |
237 | params = manifest.client_pool.get_current_client().get_model_params()
238 | model_name = params["model_name"]
239 | if "engine" in params:
240 | model_name += f"_{params['engine']}"
241 | return manifest, model_name
242 |
243 |
244 | def get_response(
245 | prompt,
246 | manifest,
247 | overwrite=False,
248 | max_toks=10,
249 | stop_token=None,
250 | gold_choices=[],
251 | verbose=False,
252 | ):
253 | prompt = prompt.strip()
254 | if gold_choices:
255 | gold_choices = [" " + g.strip() for g in gold_choices]
256 | if type(manifest) == dict and manifest["__name"] != "openai":
257 | response = together_call(prompt, manifest["__name"])
258 |             num_tokens, log_prob = 0, None  # Together AI path reports no token usage or logprob
259 | else:
260 | response_obj = manifest.run(
261 | prompt,
262 | gold_choices=gold_choices,
263 | overwrite_cache=overwrite,
264 | return_response=True,
265 | )
266 | response_obj = response_obj.get_json_response()["choices"][0]
267 | log_prob = response_obj["text_logprob"]
268 | response = response_obj["text"]
269 | num_tokens = response_obj['usage']['total_tokens']
270 | else:
271 | if type(manifest) == dict and manifest["__name"] != "openai":
272 | response = together_call(prompt, manifest["__name"])
273 | num_tokens = 0
274 | else:
275 | response_obj = manifest.run(
276 | prompt,
277 | max_tokens=max_toks,
278 | stop_token=stop_token,
279 | overwrite_cache=overwrite,
280 | return_response=True
281 | )
282 | num_tokens = -1
283 | try:
284 | num_tokens = response_obj.get_usage_obj().usages[0].total_tokens
285 |             except Exception:
286 |                 num_tokens = 0
287 |                 print("Failed to get total tokens used")
288 | response_obj = response_obj.get_json_response()
289 | response = response_obj["choices"][0]["text"]
290 | stop_token = "---"
291 | response = response.strip().split(stop_token)[0].strip() if stop_token else response.strip()
292 | log_prob = None
293 | if verbose:
294 | print("\n***Prompt***\n", prompt)
295 | print("\n***Response***\n", response)
296 | if log_prob:
297 | return response, log_prob
298 | return response, num_tokens
299 |
300 |
--------------------------------------------------------------------------------
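
A minimal sketch of how the session helpers above compose, assuming an OpenAI-style model; the key string and prompt are placeholders:

    from evaporate.utils import get_manifest_sessions, get_response

    # one Manifest session per (model, key) pair; "davinci" routes to the openai client
    sessions = get_manifest_sessions(["text-davinci-003"], KEYS=["sk-PLACEHOLDER"])
    manifest = sessions["text-davinci-003"][0]

    # returns the completion text and the number of tokens the call consumed
    response, num_tokens = get_response(
        "Extract the device name.\nDocument: ...\nDevice name:",
        manifest,
        max_toks=20,
        stop_token="---",
    )
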
/evaporate/evaluate_profiler.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, Counter
2 | import numpy as np
3 | from evaporate.prompts import (PICK_VALUE_CONTEXT, Step,)
4 | from evaporate.utils import apply_prompt
5 |
6 |
7 | def clean_comparison(responses, field):
8 | clean_responses = []
9 | if type(responses) == str:
10 | responses = [responses]
11 | for response in responses:
12 | response = response.lower()
13 | field = field.lower()
14 | field_reformat = field.replace("_", "-")
15 |
16 | for char in ["'", field, field_reformat, ":", "<", ">", '"', "none"]:
17 | response = response.replace(char, " ")
18 | for char in [",", ".", "?", "!", ";", "(", ")", "[", "]", "{", "}", "-", "none", "\n", "\t", "\r"]:
19 | response = response.replace(char, " ")
20 | response = response.replace(" ", " ")
21 | response = response.split()
22 | response = [r.strip() for r in response]
23 | response = [r for r in response if r]
24 | response = ' '.join(response)
25 | clean_responses.append(response)
26 | clean_responses = ", ".join(clean_responses)
27 | return clean_responses
28 |
29 |
30 | def normalize_value_type(metadata, attribute):
31 | # make everything a list of strings since functions can return diverse types
32 | cleaned_items = []
33 | if type(metadata) == str:
34 | metadata = [metadata]
35 | for item in metadata:
36 | if type(item) == list:
37 | item = [str(i) for i in item]
38 | item = ", ".join(item)
39 | elif type(item) == tuple:
40 | item = list(item)
41 | item = [str(i) for i in item]
42 | item = ", ".join(item)
43 | elif item is None:
44 | item = ''
45 | elif type(item) != str:
46 | item = [str(item)]
47 | item = ", ".join(item)
48 | if item:
49 | cleaned_items.append(item)
50 | return cleaned_items
51 |
52 |
53 | def pick_a_gold_label(golds, attribute="", manifest_session=None, overwrite_cache=False):
54 | """
55 | To counteract the large model hallucinating on various chunks affecting the evaluation of good functions.
56 | """
57 |
58 | pred_str = "- " + "\n- ".join(golds)
59 |
60 | prompt_template = PICK_VALUE_CONTEXT[0]
61 | prompt = prompt_template.format(pred_str=pred_str, attribute=attribute)
62 | try:
63 | check, num_toks = apply_prompt(
64 | Step(prompt),
65 | max_toks=100,
66 | manifest=manifest_session,
67 | overwrite_cache=overwrite_cache
68 | )
69 |     except Exception:
70 | return golds, 0
71 | check = check.split("\n")
72 | check = [c for c in check if c]
73 | if check:
74 | if "none" in check[0].lower():
75 | check = golds
76 | else:
77 | check = check[0]
78 | return check, num_toks
79 |
80 |
81 | def text_f1(
82 | preds=[],
83 | golds=[],
84 | extraction_fraction=1.0,
85 | attribute=None,
86 | extraction_fraction_thresh=0.8,
87 | use_abstension=True,
88 | ):
89 | """Compute average F1 of text spans.
90 | Taken from Squad without prob threshold for no answer.
91 | """
92 | total_f1 = 0
93 | total_recall = 0
94 | total_prec = 0
95 | f1s = []
96 | total = 0
97 |
98 | if extraction_fraction >= extraction_fraction_thresh and use_abstension:
99 | new_preds = []
100 | new_golds = []
101 | for pred, gold in zip(preds, golds):
102 | if pred:
103 | new_preds.append(pred)
104 | new_golds.append(gold)
105 | preds = new_preds
106 | golds = new_golds
107 | if not preds:
108 | return 0.0, 0.0
109 | for pred, gold in zip(preds, golds):
110 | if type(pred) == str:
111 | pred_toks = pred.split()
112 | else:
113 | pred_toks = pred
114 |         if type(gold) == str:
115 |             gold_toks_list = [gold.split()]
116 |         else:
117 |             # gold may already be a list of token lists
118 |             gold_toks_list = gold
119 |
120 | if type(gold_toks_list) == list and gold_toks_list:
121 | for gold_toks in gold_toks_list:
122 |
123 |                 # If both lists have length 1, split further to handle cases like:
124 |                 # ["a b"], ["a"] -> ["a", "b"], ["a"]
125 | if len(gold_toks) == 1 and len(pred_toks) == 1:
126 | gold_toks = gold_toks[0].split()
127 | pred_toks = pred_toks[0].split()
128 |
129 | common = Counter(pred_toks) & Counter(gold_toks)
130 | num_same = sum(common.values())
131 | if len(gold_toks) == 0 or len(pred_toks) == 0:
132 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
133 | total_f1 += int(gold_toks == pred_toks)
134 | f1s.append(int(gold_toks == pred_toks))
135 | total_recall += int(gold_toks == pred_toks)
136 | elif num_same == 0:
137 | total_f1 += 0
138 | f1s.append(0)
139 | else:
140 | precision = 1.0 * num_same / len(pred_toks)
141 | recall = 1.0 * num_same / len(gold_toks)
142 | f1 = (2 * precision * recall) / (precision + recall)
143 | total_f1 += f1
144 | total_recall += recall
145 | total_prec += precision
146 | f1s.append(f1)
147 |
148 | total += 1
149 | if not total:
150 | return 0.0, 0.0
151 | f1_avg = total_f1 / total
152 | f1_median = np.percentile(f1s, 50)
153 | return f1_avg, f1_median
154 |
155 |
156 | def evaluate(
157 | all_extractions:list,
158 | gold_key:str,
159 | field:str,
160 | manifest_session=None,
161 | overwrite_cache=False,
162 | combiner_mode='mv',
163 | extraction_fraction_thresh=0.8,
164 | use_abstension=True,
165 | ):
166 | normalized_field_name = field
167 | for char in ["'", ":", "<", ">", '"', "_", "-", " ", "none"]:
168 | normalized_field_name = normalized_field_name.replace(char, "")
169 |
170 | key2golds = defaultdict(list)
171 | key2preds = defaultdict(list)
172 | total_tokens_prompted = 0
173 |
174 | # handle FM golds on D_eval
175 | gold_file2metadata = all_extractions[gold_key]
176 | cleaned_gold_metadata = {}
177 | for filepath, gold_metadata in gold_file2metadata.items():
178 | gold_metadata = normalize_value_type(gold_metadata, field)
179 | if len(gold_metadata) > 1:
180 | gold_metadata, num_toks = pick_a_gold_label(
181 | gold_metadata,
182 | attribute=field,
183 | manifest_session=manifest_session,
184 | overwrite_cache=overwrite_cache
185 | )
186 | total_tokens_prompted += num_toks
187 | gold_metadata = clean_comparison(gold_metadata, field)
188 | cleaned_gold_metadata[filepath] = gold_metadata
189 | # handle function preds on D_eval
190 | for i, (key, file2metadata) in enumerate(all_extractions.items()):
191 | if key == gold_key:
192 | continue
193 | for filepath, metadata in file2metadata.items():
194 | gold_metadata = cleaned_gold_metadata[filepath]
195 | pred_metadata = normalize_value_type(metadata, field)
196 | pred_metadata = clean_comparison(pred_metadata, field)
197 | key2golds[key].append(gold_metadata)
198 | key2preds[key].append(pred_metadata)
199 |
200 |     # Handling abstentions
201 |
202 | metrics = {}
203 |     for key, golds in key2golds.items():
204 |         num_extractions = 0
205 |         for gold in golds:
206 |             if gold and gold.lower() != 'none':
207 |                 num_extractions += 1
208 | extraction_fraction = float(num_extractions) / float(len(key2golds[key]))
209 | if combiner_mode == "top_k":
210 | # Don't use the extraction fraction in the naive setting for scoring
211 | extraction_fraction = 0.0
212 | #print(f"Extraction fraction: {extraction_fraction}")
213 | preds = key2preds[key]
214 | f1, f1_med = text_f1(
215 | preds, golds,
216 | extraction_fraction=extraction_fraction,
217 | attribute=field,
218 | extraction_fraction_thresh=extraction_fraction_thresh,
219 | use_abstension=use_abstension,
220 | )
221 | priorf1, priorf1_med = text_f1(preds, golds, extraction_fraction=0.0, attribute=field)
222 | metrics[key] = {
223 | "average_f1": f1,
224 | "median_f1": f1_med,
225 | "extraction_fraction": extraction_fraction,
226 | "prior_average_f1": priorf1,
227 | "prior_median_f1": priorf1_med,
228 | }
229 |
230 | return metrics, key2golds, total_tokens_prompted
231 |
232 |
233 | def get_topk_scripts_per_field(
234 | script2metrics,
235 | function_dictionary,
236 | all_extractions,
237 | gold_key='',
238 | k=3,
239 | do_end_to_end=False,
240 | keep_thresh = 0.5,
241 | cost_thresh = 1,
242 | combiner_mode='mv',
243 | ):
244 | script2avg = dict(
245 | sorted(script2metrics.items(),
246 | reverse=True,
247 | key=lambda x: (x[1]['average_f1'], x[1]['median_f1']))
248 | )
249 |
250 | top_k_scripts = [k for k, v in script2avg.items() if k != gold_key]
251 | top_k_values = [
252 | max(v['average_f1'], v['median_f1']) for k, v in script2avg.items() if k != gold_key
253 | ]
254 | if not top_k_values:
255 | return []
256 |
257 | best_value = top_k_values[0]
258 | best_script = top_k_scripts[0]
259 | if best_value < keep_thresh and do_end_to_end:
260 | return []
261 |
262 | filtered_fn_scripts = {
263 | k:v for k, v in script2metrics.items() if (
264 | v['average_f1'] >= keep_thresh or v['median_f1'] >= keep_thresh
265 | ) and "function" in k
266 | }
267 | top_k_fns = []
268 | num_fns = 0
269 | if filtered_fn_scripts:
270 | script2avg = dict(
271 | sorted(filtered_fn_scripts.items(),
272 | reverse=True,
273 | key=lambda x: (x[1]['average_f1'], x[1]['median_f1']))
274 | )
275 |
276 | top_k_fns = [
277 | k for k, v in script2avg.items() if k != gold_key and abs(
278 | max(v['average_f1'], v['median_f1'])-best_value
279 | ) < cost_thresh
280 | ]
281 | num_fns = len(top_k_fns)
282 |
283 | if num_fns:
284 | top_k_scripts = top_k_scripts[0:min(k, num_fns)]
285 | else:
286 | return []
287 |
288 | # construct final set of functions
289 | final_set = []
290 | for key in top_k_scripts:
291 | if key in top_k_fns:
292 | final_set.append(key)
293 |
294 | if len(final_set) > k:
295 | final_set = final_set[:k]
296 | if not final_set and not do_end_to_end:
297 | return [top_k_scripts[0]]
298 |
299 | # print results
300 | # print(f"Top {k} scripts:")
301 | # for script in final_set:
302 | # print(f"- {script}; Score: {script2metrics[script]}")
303 | print(f"Best script overall: {best_script}; Score: {script2metrics[best_script]}")
304 | return final_set
305 |
306 |
--------------------------------------------------------------------------------
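
A worked example of the token-level F1 implemented above (toy strings, not from the benchmark): pred "acme widget" and gold "acme widget pro" share two tokens, so precision = 2/2, recall = 2/3, and F1 = 2 * 1 * (2/3) / (1 + 2/3) = 0.8.

    from evaporate.evaluate_profiler import text_f1

    f1_avg, f1_median = text_f1(
        preds=["acme widget"],
        golds=["acme widget pro"],
        extraction_fraction=0.0,  # below the threshold, so abstention filtering is off
    )
    print(f1_avg, f1_median)  # -> approximately 0.8 0.8
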
/evaporate/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | from tqdm import tqdm
3 | from evaporate.run_profiler import prerun_profiler, identify_attributes, get_attribute_function
4 | from evaporate.profiler import get_model_extractions
5 | from evaporate.configs import set_profiler_args
6 | from evaporate.evaluate_synthetic_utils import text_f1, get_file_attribute
7 | from evaporate.evaluate_profiler import pick_a_gold_label, evaluate, get_topk_scripts_per_field
8 | from evaporate.retrieval import get_most_similarity
9 | from evaporate.profiler import get_functions, apply_final_profiling_functions,apply_final_ensemble,combine_extractions
10 | import os
11 | import json
12 | from collections import defaultdict, Counter
13 | import numpy as np
14 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
15 |
16 |
17 | class EvaporateData:
18 | def __init__(self, profiler_args):
19 | self.GOLD_MODEL = profiler_args["direct_extract_model"]
20 | profiler_args["GOLD_KEY"] = "gold_extraction_file"
21 | self.GOLD_KEY = "gold_extraction_file"
22 |
23 |         self.profiler_args = set_profiler_args(profiler_args)
24 | self.data_dict = prerun_profiler(self.profiler_args)
25 | self.runtime = {}
26 | self.token_used = {}
27 | self.accuracy = {}
28 | self.attributes = []
29 | self.direct_result = {}
30 | self.function_dictionary = {}
31 | self.all_extractions = {}
32 | self.manifest_sessions = self.data_dict["manifest_sessions"]
33 | self.selected_func_key = {}
34 | self.extract_result = {}
35 | self.all_metrics = None
36 | self.gold_extractions = self.data_dict["gold_extractions"]
37 |
38 |     def save_results(self):
39 |         pass
40 |
41 | def get_attribute(self, do_end_to_end = False):
42 | if do_end_to_end or self.profiler_args.do_end_to_end:
43 | self.attributes, total_time, num_toks, evaluation_result = identify_attributes(self.profiler_args, self.data_dict, evaluation=True)
44 | self.runtime["get_attribute"] = total_time
45 | self.token_used["get_attribute"] = num_toks
46 | self.accuracy["get_attribute"] = evaluation_result
47 | else:
48 | self.attributes = self.data_dict["gold_attributes"]
49 | return self.attributes
50 |
51 |
52 | def direct_extract(self, use_retrieval_model = True, is_getting_sample = False, gold = ""):
53 |         if not self.attributes:
54 | print("Please run get_attribute first")
55 | return
56 | files = list(self.data_dict["file2chunks"].keys())
57 | if(is_getting_sample):
58 | files = self.data_dict["sample_files"]
59 | time_begin = time.time()
60 | token_used = 0
61 | for attribute in self.attributes:
62 | if(attribute in self.direct_result):
63 | print("already extract ", attribute)
64 | #continue
65 | new_file_chunk_dict = {}
66 |             if use_retrieval_model:
67 |                 baseline_sentence = attribute + ": " + gold[attribute]
68 | for file in files:
69 | sentences = self.data_dict["file2chunks"][file]
70 | new_file_chunk_dict[file] = [get_most_similarity(baseline_sentence, sentences)]
71 | else:
72 | new_file_chunk_dict = self.data_dict["file2chunks"]
73 | extractions, num_toks, errored_out = get_model_extractions(
74 | new_file_chunk_dict,
75 | files,
76 | attribute,
77 | self.manifest_sessions[self.GOLD_MODEL],
78 | self.GOLD_MODEL,
79 | overwrite_cache=self.profiler_args.overwrite_cache,
80 | collecting_preds=True,
81 | )
82 | token_used += num_toks
83 | self.direct_result[attribute] = {}
84 | for file in extractions:
85 | golds = []
86 | for tmp in extractions[file]:
87 | golds.append( "- " + "\n- ".join(tmp))
88 | golds = "- " + "\n- ".join(golds)
89 | if(use_retrieval_model):
90 | try:
91 | self.direct_result[attribute][file] = extractions[file][0]
92 |                     except Exception:
93 |                         print("error in", attribute, file, extractions[file])
94 |                 else:
95 |                     self.direct_result[attribute][file], _ = pick_a_gold_label(golds, attribute, self.manifest_sessions[self.GOLD_MODEL])
96 | print("finish extract ", attribute)
97 | self.runtime["direct_extract"] = time.time() - time_begin
98 | self.token_used["direct_extract"] = token_used
99 | return self.direct_result, self.evaluate(self.direct_result)
100 |
101 | def get_extract_functions(self):
102 | self.runtime["get_extract_functions"] = 0
103 | self.token_used["get_extract_functions"] = 0
104 | begin_time = time.time()
105 | total_tokens_prompted = 0
106 | for attribute in self.attributes:
107 | if attribute in self.function_dictionary:
108 | print("already generate ", attribute, " function")
109 | continue
110 | self.all_extractions[attribute] = {}
111 | self.function_dictionary[attribute] = {}
112 | for model in self.profiler_args.EXTRACTION_MODELS:
113 | manifest_session = self.manifest_sessions[model]
114 | functions, function_promptsource, num_toks = get_functions(
115 | self.data_dict["file2chunks"],
116 | self.data_dict["sample_files"],
117 | {},
118 | attribute,
119 | manifest_session,
120 | overwrite_cache=self.profiler_args.overwrite_cache,
121 | )
122 | total_tokens_prompted += num_toks
123 | for fn_key, fn in functions.items():
124 | self.all_extractions[attribute][fn_key], num_function_errors = apply_final_profiling_functions(
125 | self.data_dict["file2contents"],
126 | self.data_dict["sample_files"],
127 | fn,
128 | attribute,
129 | )
130 | self.function_dictionary[attribute][fn_key] = {}
131 | self.function_dictionary[attribute][fn_key]['function'] = fn
132 | self.function_dictionary[attribute][fn_key]['promptsource'] = function_promptsource[fn_key]
133 | self.function_dictionary[attribute][fn_key]['extract_model'] = model
134 | self.runtime["get_extract_functions"] = time.time() - begin_time
135 | self.token_used["get_extract_functions"] = total_tokens_prompted
136 | return self.function_dictionary
137 |
138 | def weak_supervision(self, use_gold_key = False):
139 |         result = self.direct_result
140 | self.runtime["weak_supervision"] = 0
141 | self.token_used["weak_supervision"] = 0
142 | begin_time = time.time()
143 | total_tokens_prompted = 0
144 | for attribute in self.attributes:
145 | if attribute in self.selected_func_key:
146 | print("already weak supervision ", attribute)
147 | continue
148 | self.all_extractions[attribute]["gold_extraction_file"] = self.gold_extractions
149 | if(use_gold_key):
150 | self.GOLD_KEY = "gold_extraction_file"
151 | else:
152 |                 if not result:
153 | print("Please run direct_extract first")
154 | return
155 | self.all_extractions[attribute]["gold-key"] = result[attribute]
156 | self.GOLD_KEY = "gold-key"
157 | self.all_metrics, key2golds, num_toks = evaluate(
158 | self.all_extractions[attribute],
159 | self.GOLD_KEY,
160 | field=attribute,
161 | manifest_session=self.manifest_sessions[self.profiler_args.GOLD_KEY],
162 | overwrite_cache=self.profiler_args.overwrite_cache,
163 | combiner_mode=self.profiler_args.combiner_mode,
164 | extraction_fraction_thresh=self.profiler_args.extraction_fraction_thresh,
165 | use_abstension=self.profiler_args.use_abstension,
166 | )
167 | total_tokens_prompted += num_toks
168 | selected_keys = get_topk_scripts_per_field(
169 | self.all_metrics,
170 | self.function_dictionary[attribute],
171 | self.all_extractions,
172 | self.GOLD_KEY,
173 | k=self.profiler_args.num_top_k_scripts,
174 | do_end_to_end=self.profiler_args.do_end_to_end,
175 | combiner_mode=self.profiler_args.combiner_mode,
176 | keep_thresh = 0.00
177 | )
178 | self.selected_func_key[attribute] = selected_keys
179 | self.runtime["weak_supervision"] = time.time() - begin_time
180 | self.token_used["weak_supervision"] = total_tokens_prompted
181 | return self.selected_func_key
182 |
183 | def apply_functions(self):
184 | total_tokens_prompted = 0
185 | self.runtime["apply_functions"] = 0
186 | self.token_used["apply_functions"] = 0
187 | self.extract_result = {}
188 | begin_time = time.time()
189 | for attribute in self.attributes:
190 | print(f"Apply the scripts to the data lake and save the metadata. Taking the top {self.profiler_args.num_top_k_scripts} scripts per field.")
191 | top_k_extractions, num_toks = apply_final_ensemble(
192 | self.profiler_args.full_file_groups,
193 | self.data_dict["file2chunks"],
194 | self.data_dict['file2contents'],
195 | self.selected_func_key[attribute],
196 | self.all_metrics,
197 | attribute,
198 | self.function_dictionary[attribute],
199 | data_lake=self.profiler_args.data_lake,
200 | manifest_sessions=self.manifest_sessions,
201 | function_cache=True,
202 | MODELS=self.profiler_args.EXTRACTION_MODELS,
203 | overwrite_cache=self.profiler_args.overwrite_cache,
204 | do_end_to_end=self.profiler_args.do_end_to_end,
205 | )
206 | total_tokens_prompted += num_toks
207 |
208 | file2metadata, num_toks = combine_extractions(
209 | self.profiler_args,
210 | top_k_extractions,
211 | self.all_metrics,
212 | combiner_mode=self.profiler_args.combiner_mode,
213 | train_extractions=self.all_extractions[attribute],
214 | attribute=attribute,
215 | gold_key = self.GOLD_KEY,
216 | extraction_fraction_thresh=self.profiler_args.extraction_fraction_thresh,
217 | )
218 | total_tokens_prompted += num_toks
219 | self.extract_result[attribute] = file2metadata
220 | self.runtime["apply_functions"] = time.time() - begin_time
221 | self.token_used["apply_functions"] = total_tokens_prompted
222 | return self.extract_result, self.evaluate(self.extract_result)
223 |
224 | def evaluate(self, result):
225 | f1_pred = {}
226 | for attribute in self.attributes:
227 | attribute2 = get_file_attribute(attribute)
228 | preds = []
229 | golds = []
230 | for file in result[attribute]:
231 | if(file in self.gold_extractions.keys()):
232 | preds.append(result[attribute][file])
233 | golds.append(self.gold_extractions[file][attribute])
234 | try:
235 | f1_pred[attribute] = text_f1(preds, golds)
236 |             except Exception:
237 |                 print("failed to score attribute:", attribute)
238 |         # collect the per-attribute f1 scores into an array
239 | f1_result = np.array(list(f1_pred.values()))
240 | return {"mean":f1_result.mean(), "std":f1_result.std(),"result":f1_pred}
--------------------------------------------------------------------------------
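
A hypothetical driver for the EvaporateData pipeline above. The accepted profiler_args fields are defined by set_profiler_args in configs.py, so every key below besides direct_extract_model is an assumption, and the paths are placeholders:

    from evaporate.main import EvaporateData

    profiler_args = {
        "direct_extract_model": "gpt-4",
        "data_lake": "fda_510ks",
        "data_dir": "/data/evaporate/fda-ai-pmas/510k",
        "base_data_dir": "/data/evaporate/data/fda_510ks",
        "combiner_mode": "ws",
        "num_top_k_scripts": 10,
    }

    pipeline = EvaporateData(profiler_args)
    attributes = pipeline.get_attribute()         # ClosedIE: reuse the gold attributes
    functions = pipeline.get_extract_functions()  # synthesize candidate extraction functions
    selected = pipeline.weak_supervision(use_gold_key=True)
    results, scores = pipeline.apply_functions()  # apply, aggregate, and score

weak_supervision() scores every candidate function against the gold extractions and keeps the top num_top_k_scripts per attribute; apply_functions() then runs the survivors over the full lake and combines their outputs with the configured combiner_mode.
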
/evaporate/weak_supervision/binary_deps.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | from itertools import chain, product, combinations
4 | from scipy.sparse import issparse
5 | import more_itertools
6 | import torch
7 |
8 |
9 | class DependentPGM:
10 | """
11 | This class describes a PGM learned from labeled data with specified edge structure.
12 |
13 | Args:
14 | edges: list of edges that are dependent
15 | train_votes: n x m array of votes in {0, 1}
16 | train_gold: n array of true labels in {0, 1}
17 | """
18 | def __init__(
19 | self, edges, train_votes, train_gold, abstains = False, classes = [0, 1], abstain_value = -1) -> None:
20 | """
21 | Initialize the PGM by computing its junction tree factorization (c_tree and c_data)
22 | and by computing individual LF accuracy and class balance.
23 | """
24 |
25 | self.edges = edges
26 | self.train_votes = train_votes
27 | self.train_gold = train_gold
28 |
29 | self.classes = classes
30 | self.k = len(classes)
31 | assert len(np.unique(self.train_gold)) == self.k
32 |
33 | self.abstains = abstains
34 | assert len(np.unique(self.train_votes)) == int(abstains) + self.k
35 | self.abstain_value = abstain_value
36 |
37 | self.n, self.m = self.train_votes.shape
38 |
39 | self.nodes = np.arange(self.m)
40 | self.higher_order = len(edges) != 0
41 |
42 | # construct data structures containing dependency graph information (maximal cliques and separator sets)
43 | self._set_clique_tree()
44 | self._set_clique_data()
45 |
46 | # compute LF accuracies and class balance
47 | self._get_accs_and_cb()
48 |
49 | def _get_scaled(self):
50 | if self.classes == [0, 1]:
51 | self.train_votes_scaled = 2*self.train_votes - 1
52 | self.train_gold_scaled = 2*self.train_gold - 1
53 | if self.abstains:
54 | self.train_votes_scaled[self.train_votes == self.abstain_value] = 0
55 | else:
56 | self.train_votes_scaled = self.train_votes
57 | self.train_gold_scaled = self.train_gold
58 |
59 |
60 |
61 |
62 |
63 | def _set_clique_tree(self):
64 | G1 = nx.Graph()
65 | G1.add_nodes_from(self.nodes)
66 | G1.add_edges_from(self.edges)
67 |
68 | # Check if graph is chordal
69 | # TODO: Add step to triangulate graph if not
70 | if not nx.is_chordal(G1):
71 | raise nx.NetworkXError("Graph triangulation not implemented.")
72 |
73 | # Create maximal clique graph G2
74 | # Each node is a maximal clique C_i
75 | # Let w = |C_i \cap C_j|; C_i, C_j have an edge with weight w if w > 0
76 | G2 = nx.Graph()
77 | for i, c in enumerate(nx.chordal_graph_cliques(G1)):
78 | G2.add_node(i, members=c)
79 | for i in G2.nodes():
80 | for j in G2.nodes():
81 | S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"])
82 | w = len(S)
83 | if w > 0:
84 | G2.add_edge(i, j, weight=w, members=S)
85 |
86 |         self.c_tree = nx.maximum_spanning_tree(G2)
87 |         # a *maximum* spanning tree of G2 is used so that the separator sets are maximal
88 |
89 | def _set_clique_data(self):
90 | # Create a helper data structure which maps cliques (as tuples of member
91 | # sources) --> {start_index, end_index, maximal_cliques}, where
92 | # the last value is a set of indices in this data structure
93 | self.c_data = dict()
94 | for i in range(self.m):
95 | self.c_data[i] = {
96 | "vertices": [i],
97 | "max_cliques": set( # which max clique i belongs to
98 | [
99 | j
100 | for j in self.c_tree.nodes()
101 | if i in self.c_tree.nodes[j]["members"]
102 | ]
103 | ),
104 | }
105 |
106 | # Get the higher-order clique statistics based on the clique tree
107 | # First, iterate over the maximal cliques (nodes of c_tree) and
108 | # separator sets (edges of c_tree)
109 | if self.higher_order:
110 | counter = 0
111 | for item in chain(self.c_tree.nodes(), self.c_tree.edges()):
112 | if isinstance(item, int):
113 | C = self.c_tree.nodes[item]
114 | C_type = "node"
115 | elif isinstance(item, tuple):
116 | C = self.c_tree[item[0]][item[1]]
117 | C_type = "edge"
118 | else:
119 | raise ValueError(item)
120 | members = list(C["members"])
121 | nc = len(members)
122 |
123 | # Else add one column for each possible value
124 | if nc != 1:
125 | # Add to self.c_data as well
126 | #idx = counter + m
127 | self.c_data[tuple(members)] = {
128 | "vertices": members,
129 | "max_cliques": set([item]) if C_type == "node" else set(item),
130 | }
131 | counter += 1
132 |
133 |
134 | def _get_accs_and_cb(self):
135 | classes = [0, 1]
136 | self.gold_idxs = [np.where(self.train_gold == c)[0] for c in classes]
137 |
138 |         self.accs = np.zeros((self.m, 2))  # accs[p, i] = Pr(prompt_p = 1 | y = i)
139 | for p in range(self.m):
140 | for i in classes:
141 | self.accs[p, i] = len(np.where(self.train_votes[self.gold_idxs[i], p] == 1)[0]) / len(self.gold_idxs[i])
142 |
143 | self.accs = np.clip(self.accs, 0.0001, 0.9999)
144 | self.balance = len(self.gold_idxs[1]) / self.n
145 |
146 | def get_clique_probs(self, idxs, vals, y):
147 | """
148 | Computes marginal probability over voters indexed by idx, Pr(votes_idxs = vals | y).
149 | """
150 | truth_matrix = np.ones(len(self.gold_idxs[y])).astype(bool)
151 | for i, lf in enumerate(idxs):
152 | truth_matrix = np.logical_and(truth_matrix, self.train_votes[self.gold_idxs[y], lf] == vals[i])
153 |
154 |         if truth_matrix.sum() == 0:
155 |             return 0.00001
156 |         return truth_matrix.sum() / len(self.gold_idxs[y])
157 |
158 |
159 | def get_cond_probs(self, votes, y):
160 | """
161 | Computes the probability Pr(votes | y).
162 | """
163 | pr_y = self.balance if y == 1 else 1 - self.balance
164 | prod = pr_y
165 |
166 | for i in self.c_tree.nodes():
167 | node = self.c_tree.nodes[i]
168 | members = list(node['members'])
169 | if len(members) == 1:
170 | v = members[0]
171 | print(f"multiplying by {votes[v] * self.accs[v, y]}")
172 | prod *= votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y])
173 | else:
174 | print(members)
175 | print(f"multiplying by {self.get_clique_probs(members, votes[members], y)}")
176 |
177 | prod *= self.get_clique_probs(members, votes[members], y)
178 |
179 | for i in self.c_tree.edges():
180 | edge = self.c_tree.edges[i]
181 | members = list(edge['members'])
182 | if len(members) == 1:
183 | v = members[0]
184 | deg = len(self.c_data[v]['max_cliques'])
185 | prod /= (votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y]))**(deg-1)
186 |
187 | print(members)
188 | print(f"Dividing by {votes[v] * self.accs[v, y] + (1 - votes[v]) * (1 - self.accs[v, y])} to the {deg - 1} power")
189 |
190 | else:
191 | deg = len(self.c_data[tuple(members)]['max_cliques'])
192 | prod /= (self.get_clique_probs(members, votes[members], y))**(deg-1)
193 |
194 | print(members)
195 | print(f"Dividing by {self.get_clique_probs(members, votes[members], y)} to the {deg - 1} power")
196 |
197 | print(prod)
198 | return prod
199 |
200 | def get_probs(self, votes):
201 | """
202 | Computes the probability Pr(y = 1 | votes).
203 | """
204 | pos = self.get_cond_probs(votes, 1)
205 | neg = self.get_cond_probs(votes, 0)
206 | if pos == 0:
207 | return 0
208 | else:
209 | return pos / (pos + neg)
210 |
211 | def evaluate(self, test_votes, test_gold):
212 | """
213 | Using our learned PGM, output rounded estimates of Pr(y = 1 | votes) and computes its accuracy.
214 |
215 | Args:
216 | test_votes: vote array to perform inference on in {0, 1}
217 | test_gold: true labels to compare to in {0, 1}
218 | """
219 | n_test = len(test_votes)
220 |
221 | output_rounded = np.zeros(n_test)
222 | output_probs = np.zeros(n_test)
223 | err = 0
224 | for i in range(n_test):
225 | output_probs[i] = self.get_probs(test_votes[i])
226 | output_rounded[i] = np.round(output_probs[i])
227 | err += output_rounded[i] != test_gold[i]
228 |
229 | accuracy = 1 - err / n_test
230 |
231 | return output_probs, output_rounded, accuracy
232 |
233 |
234 | def is_triangulated(nodes, edges):
235 | """
236 |     Returns whether a graph is triangulated (i.e. chordal), so that a junction tree factorization exists.
237 | """
238 | G1 = nx.Graph()
239 | G1.add_nodes_from(nodes)
240 | G1.add_edges_from(edges)
241 | return nx.is_chordal(G1)
242 |
243 |
244 | def structure_learning(m, votes, gold, acc_theta, classes = [0, 1], l1_lambda=0.2):
245 | """
246 | Structure learning algorithm (Ising model selection) from Ravikumar (2010).
247 |
248 | Args:
249 | - votes: n_train x m array of training votes
250 | - gold: n_train array of gold labels on the training data
251 | - acc_theta: E[vote_i y] (where vote and y are scaled to [-1, 1]). This is a scaled version of accuracy that we will initialize some of the
252 | parameters in our PGM with in order to specify that we don't want to optimize over the edges between votes and y.
253 | We only are learning edges among votes!
254 | - classes: the list of classes the data can take on.
255 | - l1_lambda: l1 regularization strength
256 | """
257 |
258 | # scale the data
259 | classes = np.sort(np.unique(gold))
260 | vote_classes = np.sort(np.unique(votes))
261 | if 0 in classes and 1 in classes:
262 | votes_scaled = 2*votes - 1
263 | gold_scaled = 2*gold - 1
264 | if len(vote_classes) == len(classes) + 1:
265 | votes_scaled[votes == -1] = 0
266 | else:
267 | votes_scaled = votes
268 | gold_scaled = gold
269 |
270 | acc_theta = torch.from_numpy(acc_theta).type(torch.FloatTensor)
271 | all_thetas = np.zeros((m, m)) # learned thetas from alg
272 |
273 |     # For each prompt, we fit a logistic regression with prompt_i's output as the response variable and all other prompt outputs as the covariates.
274 |     # big_theta is a vector of weights denoting dependence on each other prompt (0 means independence).
275 | for v in range(m):
276 | print(f"Learning neighborhood of vertex {v}.")
277 | if len(classes) == 2:
278 | big_theta = learn_neighborhood(m, v, votes_scaled, gold_scaled, acc_theta, l1_lambda)
279 | else:
280 | big_theta = learn_neighborhood_multi(m, v, votes_scaled, gold_scaled, acc_theta, l1_lambda, classes)
281 | all_thetas[v] = big_theta
282 |
283 | return all_thetas
284 |
285 |
286 | # v is the vertex whose neighborhood graph we are estimating
287 | def learn_neighborhood(m, vertex, votes, gold, accs, l1_lambda, epochs = 50000):
288 | """
289 | Learn the neighborhood graph for a vertex.
290 |
291 | Args:
292 | - m: number of prompts
293 | - vertex: the index of the prompt we are selecting as the response variable
294 | - votes: votes on training data
295 | - gold: gold label of training data
296 | - accs: training accuracies of each prompt we use to initialize the PGM parameters with
297 | - l1_lambda: regularization strength
298 | - epochs: number of iterations
299 | """
300 | n = len(gold)
301 | vote_y = np.concatenate((votes, gold.reshape(n, 1)), axis=1)
302 |
303 | xr = vote_y[:, vertex]
304 | x_notr = np.delete(vote_y, vertex, axis=1)
305 |
306 | xr = torch.from_numpy(xr).type(torch.FloatTensor)
307 | x_notr = torch.from_numpy(x_notr).type(torch.FloatTensor)
308 |
309 |
310 | theta = torch.zeros(m) # last index is for accuracy between vertex and y
311 | theta[m - 1] = accs[vertex] # initialize this to be the train accuracy. We do want this to be an optimizable variable still though.
312 | theta.requires_grad_()
313 |
314 | optimizer = torch.optim.SGD([theta], lr=0.0001)
315 | for t in range(epochs):
316 | optimizer.zero_grad()
317 |
318 | # logistic regression from Ravikumar et al
319 | fx = (torch.log(torch.exp(torch.matmul(x_notr, theta))
320 | + torch.exp(-torch.matmul(x_notr, theta))).mean())
321 | loss = fx - torch.multiply(xr, x_notr.T).mean(dim=1).dot(theta) + l1_lambda * torch.linalg.vector_norm(theta[:m], ord=1)
322 |
323 | loss.backward()
324 | optimizer.step()
325 |
326 | #if t % 1000 == 0:
327 | # print(f"Loss: {loss}")
328 |
329 | big_theta = np.concatenate([theta.detach().numpy()[:vertex], [0], theta.detach().numpy()[vertex:m - 1]])
330 | return big_theta
331 |
332 | # v is the vertex whose neighborhood graph we are estimating
333 | def learn_neighborhood_multi(m, vertex, votes, gold, accs, l1_lambda, classes, epochs = 50000):
334 | # votes: in range {0, ... k}
335 | n = len(gold)
336 | vote_y = np.concatenate((votes, gold.reshape(n, 1)), axis=1)
337 |
338 | xr = vote_y[:, vertex]
339 | x_notr = np.delete(vote_y, vertex, axis=1)
340 |
341 | xr = torch.from_numpy(xr).type(torch.FloatTensor)
342 | x_notr = torch.from_numpy(x_notr).type(torch.FloatTensor)
343 |
344 |
345 | theta = torch.zeros(m) # last index is for accuracy between vertex and y
346 | theta[m - 1] = accs[vertex] # initialize this
347 | theta.requires_grad_()
348 |
349 | optimizer = torch.optim.SGD([theta], lr=0.0001)
350 | for t in range(epochs):
351 | optimizer.zero_grad()
352 |
353 | # logistic regression from Ravikumar et al
354 | mu = 0
355 | for i in range(x_notr.shape[1]):
356 | # mu = \sum_i theta_i * \sum_data sign{x_r = x_i}
357 | mu += (2*(xr == x_notr[:, i])-1).type(torch.FloatTensor).mean() * theta[i]
358 |
359 | fx = 0
360 | for k in classes:
361 | # \sum_y exp( \sum_i theta_i sign(x_i = y)) "normalization"
362 | fx += torch.exp(torch.matmul((2*(x_notr == k)-1).type(torch.FloatTensor), theta)).mean()
363 |
364 | loss = fx - mu + l1_lambda * torch.linalg.vector_norm(theta[:m], ord=1)
365 |
366 | loss.backward()
367 | optimizer.step()
368 |
369 | #if t % 1000 == 0:
370 | # print(f"Loss: {loss}")
371 |
372 | big_theta = np.concatenate([theta.detach().numpy()[:vertex], [0], theta.detach().numpy()[vertex:m - 1]])
373 | return big_theta
374 |
375 | def main():
376 | # load data
377 | vote_arr_train = np.load('./data/youtube-spam/train_votes.npy').T
378 | vote_arr_test = np.load('./data/youtube-spam/test_votes.npy').T
379 | gold_arr_train = np.load('./data/youtube-spam/train_gold.npy').T
380 | gold_arr_test = np.load('./data/youtube-spam/test_gold.npy').T
381 |
382 | # vote_arr_train = np.concatenate((vote_arr_train[:, 0: 2], vote_arr_train[:, 4:]), axis=1)
383 | # vote_arr_test = np.concatenate((vote_arr_test[:, 0: 2], vote_arr_test[:, 4:]), axis=1)
384 |
385 | n_train, num_prompts = vote_arr_train.shape
386 |
387 |
388 | # make validation set
389 | np.random.seed(4)
390 | val_idxs = np.random.choice(np.arange(n_train), size= 28, replace=False)
391 | vote_arr_val = vote_arr_train[val_idxs, :]
392 | vote_arr_train = np.delete(vote_arr_train, val_idxs, axis=0)
393 |
394 | gold_arr_val = gold_arr_train[val_idxs]
395 | gold_arr_train = np.delete(gold_arr_train, val_idxs)
396 |
397 |     nodes = np.arange(num_prompts)
398 |
399 |
400 | # specify edgeset
401 | # edges =[(0, 1)]
402 | #model = DependentPGM(edges, vote_arr_train, gold_arr_train)
403 | #probs, output, acc = model.evaluate(vote_arr_test, gold_arr_test)
404 | #print(acc)
405 |
406 |
407 | # Brute-force iteration through a bunch of edges
408 | all_edges = list(combinations(nodes, 2))
409 | small_edgesets = list(more_itertools.powerset(all_edges))
410 | #small_edgesets = list(combinations(all_edges, 0)) + list(combinations(all_edges, 1)) + list(combinations(all_edges, 2)) + list(combinations(all_edges, 3))
411 | scores = np.zeros(len(small_edgesets))
412 |
413 | for i, edgeset in enumerate(small_edgesets):
414 | if len(edgeset) > 4:
415 | break
416 | if not is_triangulated(nodes, edgeset):
417 | continue
418 | model = DependentPGM(edgeset, vote_arr_train, gold_arr_train)
419 |
420 | probs, output, scores[i] = model.evaluate(vote_arr_val, gold_arr_val)
421 | if i % 100 == 0:
422 | print(f"Edgeset: {edgeset} \n score: {scores[i]}")
423 |
424 | print(f"Best edgeset score: {scores.max()}")
425 | print(f"Best edgeset: {small_edgesets[scores.argmax()]}")
426 |
427 | edges = small_edgesets[scores.argmax()]
428 |
429 | vote_arr_train = np.concatenate((vote_arr_train, vote_arr_val))
430 | gold_arr_train = np.concatenate((gold_arr_train, gold_arr_val))
431 |
432 | model = DependentPGM(edges, vote_arr_train, gold_arr_train)
433 | probs, output, acc = model.evaluate(vote_arr_test, gold_arr_test)
434 | print(f"Final model accuracy: {acc}")
435 |
436 |
437 | if __name__ == "__main__":
438 | main()
439 |
--------------------------------------------------------------------------------
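
A toy run of DependentPGM on synthetic votes (illustration only; the import path assumes this repo's package layout, and note that get_cond_probs prints verbose debug output on every inference):

    import numpy as np
    from evaporate.weak_supervision.binary_deps import DependentPGM

    rng = np.random.default_rng(0)
    gold = rng.integers(0, 2, size=200)  # true labels in {0, 1}
    # three voters that each agree with gold ~80% of the time (votes are n x m)
    votes = np.stack(
        [np.where(rng.random(200) < 0.8, gold, 1 - gold) for _ in range(3)],
        axis=1,
    )

    # declare voters 0 and 1 dependent; a single edge is trivially chordal
    model = DependentPGM([(0, 1)], votes[:150], gold[:150])
    probs, rounded, acc = model.evaluate(votes[150:], gold[150:])
    print(f"held-out accuracy: {acc:.2f}")
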
/evaporate/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import datetime
4 |
5 | def get_run_string(
6 | data_lake, today, file_groups, profiler_args, do_end_to_end,
7 | train_size, dynamicbackoff, models
8 | ):
9 | body = profiler_args.body_only # Baseline systems only operate on the HTML body
10 | model_ct = len(models)
11 | if profiler_args.use_qa_model:
12 | model_ct += 1
13 | run_string = f"dataLake{data_lake}_date{today}_fileSize{len(file_groups)}_trainSize{train_size}_numAggregate{profiler_args.num_top_k_scripts}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}_body{body}_cascading{do_end_to_end}_useBackoff{dynamicbackoff}_MODELS{model_ct}"
14 | return run_string
15 |
16 | def get_data_lake_info(args):
17 | extractions_file = None
18 |
19 |     # gather the top-level files of the data lake directory
20 |     DATA_DIR = args.data_dir
21 |     file_groups = os.listdir(args.data_dir)
22 |     if not DATA_DIR.endswith("/"):
23 |         DATA_DIR += "/"
24 |     file_groups = [f"{DATA_DIR}{file_group}" for file_group in file_groups if not file_group.startswith(".")]
25 |     full_file_groups = file_groups.copy()
26 |     extractions_file = args.gold_extractions_file
27 |     parser = "txt"
28 | 
29 |     return file_groups, extractions_file, parser, full_file_groups
30 |
31 | #default args for the experiment
32 | def get_experiment_args():
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument(
35 | "--data_lake",
36 | type=str,
37 | help="Name of the data lake to operate over. Must be in configs.py"
38 | )
39 | parser.add_argument(
40 | "--data_dir",
41 | type=str,
42 | help="Path to raw data-lake documents",
43 | )
44 | parser.add_argument(
45 | "--base_data_dir",
46 | type=str,
47 | default="/",
48 | help="Path to save intermediate result",
49 | )
50 | parser.add_argument(
51 | "--cache_dir",
52 | type=str,
53 | default= "",
54 | help = "Path to cache intermediate files during system execution",
55 | )
56 | parser.add_argument(
57 | "--generative_index_path",
58 | type=str,
59 | default= "",
60 | help = "Path to store the generated structured view of the data lake",
61 | )
62 | parser.add_argument(
63 | "--gold_extractions_file",
64 | type=str,
65 | default= "",
66 | help = "Path to store the generated structured view of the data lake",
67 | )
68 | parser.add_argument(
69 | "--do_end_to_end",
70 | action='store_true',
71 | default=False,
72 | help="True for OpenIE, False for ClosedIE"
73 | )
74 |
75 | parser.add_argument(
76 | "--num_attr_to_cascade",
77 | type=int,
78 | default=35,
79 | help="Number of attributes to generate functions for"
80 | )
81 |
82 | parser.add_argument(
83 | "--num_top_k_scripts",
84 | type=int,
85 | default=10,
86 | help="Number of generated functions to combine over for each attribute"
87 | )
88 |
89 | parser.add_argument(
90 | "--train_size",
91 | type=int,
92 | default=10,
93 | help="Number of files to prompt on"
94 | )
95 |
96 | parser.add_argument(
97 | "--combiner_mode",
98 | type=str,
99 | default='ws',
100 | help="Combiner mode for combining the outputs of the generated functions",
101 | choices=['ws', 'mv', 'top_k']
102 | )
103 |
104 | parser.add_argument(
105 | "--use_dynamic_backoff",
106 | action="store_true",
107 | default=True,
108 | help="Whether to generate functions or do Evaporate-Direct",
109 | )
110 |
111 | parser.add_argument(
112 | "--KEYS",
113 | type=str,
114 | default=[],
115 | help="List of keys to use the model api",
116 | nargs='*'
117 | )
118 | parser.add_argument(
119 | "--MODELS",
120 |         type=str, nargs='*',
121 | default=["gpt-4"],
122 | help="List of models to use for the extraction step"
123 | )
124 | parser.add_argument(
125 | "--EXTRACTION_MODELS",
126 |         type=str, nargs='*',
127 | default=["gpt-4"],
128 | help="List of models to use for the extraction step"
129 | )
130 | parser.add_argument(
131 | "--MODEL2URL",
132 | type=str,
133 | default={},
134 | help="Dict mapping models to their urls"
135 | )
136 | parser.add_argument(
137 | "--use_qa_model",
138 | action="store_true",
139 | default=False,
140 | help="Whether to use a QA model for the extraction step"
141 | )
142 | parser.add_argument(
143 | "--GOLD_KEY",
144 | type=str,
145 | default="gpt-4",
146 | help="Key to use for the gold standard"
147 | )
148 |
149 | parser.add_argument(
150 | "--eval_size",
151 | type=int,
152 | default=15,
153 | )
154 |
155 | parser.add_argument(
156 | "--max_chunks_per_file",
157 | type=int,
158 | default=-1,
159 | )
160 |
161 | parser.add_argument(
162 | "--chunk_size",
163 | type=int,
164 | default=3000,
165 | )
166 |
167 | parser.add_argument(
168 | "--extraction_fraction_thresh",
169 |         type=float,
170 |         default=0.9,
171 |         help="threshold for the abstention approach",
172 | )
173 |
174 | parser.add_argument(
175 | "--remove_tables",
176 | action="store_true",
177 | default=False,
178 | help="Remove tables from the html files?",
179 | )
180 |
181 | parser.add_argument(
182 | "--body_only",
183 | action="store_true",
184 | default=False,
185 | help="Only use HTML body",
186 | )
187 |
188 | parser.add_argument(
189 | "--max_metadata_fields",
190 | type=int,
191 | default=15,
192 | )
193 |
194 | parser.add_argument(
195 | "--overwrite_cache",
196 | action='store_true',
197 | default=False,
198 | help="overwrite the manifest cache"
199 | )
200 | parser.add_argument(
201 | "--swde_plus",
202 | action="store_true",
203 | default=False,
204 | help="Whether to use the extended SWDE dataset to measure OpenIE performance",
205 | )
206 |
207 | parser.add_argument(
208 | "--schema_id_sizes",
209 | type=int,
210 | default=0,
211 | help="Number of documents to use for schema identification stage, if it differs from extraction",
212 | )
213 |
214 | parser.add_argument(
215 | "--slice_results",
216 | action="store_true",
217 | default=False,
218 | help="Whether to measure the results by attribute-slice",
219 | )
220 |
221 | parser.add_argument(
222 | "--fn_generation_prompt_num",
223 | type=int,
224 | default=-1,
225 | help="For ablations on function generation with diversity, control which prompt we use. Default is all.",
226 | )
227 |
228 | parser.add_argument(
229 | "--upper_bound_fns",
230 | action="store_true",
231 | default=False,
232 | help="For ablations that select functions using ground truth instead of the FM.",
233 | )
234 |
235 | parser.add_argument(
236 | "--use_alg_filtering",
237 | type=str,
238 | default=True,
239 | help="Whether to filter functions based on quality.",
240 | )
241 |
242 | parser.add_argument(
243 | "--use_abstension",
244 | type=str,
245 | default=True,
246 | help="Whether to use the abstensions approach.",
247 | )
248 |
249 | parser.add_argument(
250 | "--set_dicts",
251 | type=str,
252 | default='',
253 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.",
254 | )
255 |
256 | parser.add_argument(
257 | "--topic",
258 |         type=str,
259 | default=[],
260 | help="Topic of the data lake",
261 | )
262 | experiment = parser.parse_args()
263 | return experiment
264 |
265 | #get args related to data storage paths
266 | def get_args(profiler_args):
267 | if(profiler_args.generative_index_path == ""):
268 | profiler_args.generative_index_path = os.path.join(profiler_args.base_data_dir, "generative_indexes/", profiler_args.data_lake)
269 | if(profiler_args.cache_dir == ""):
270 | profiler_args.cache_dir = os.path.join(profiler_args.base_data_dir, "cache/", profiler_args.data_lake)
271 | if(profiler_args.gold_extractions_file == ""):
272 | profiler_args.gold_extractions_file = os.path.join(profiler_args.base_data_dir, "gold_extractions.json" )
273 | parser = argparse.ArgumentParser(
274 | "LLM explorer.",
275 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
276 | )
277 |
278 | parser.add_argument(
279 | "--overwrite_cache",
280 | type=bool,
281 | default=profiler_args.overwrite_cache,
282 | help="Whether to overwrite the caching for prompts."
283 | )
284 |
285 | parser.add_argument(
286 | "--data_lake",
287 | type=str,
288 | default=profiler_args.data_lake,
289 | help="Name of the data lake"
290 | )
291 |
292 | parser.add_argument(
293 | "--data_dir",
294 | type=str,
295 | default = profiler_args.data_dir,
296 | help="Path to raw data-lake documents",
297 | )
298 |
299 | parser.add_argument(
300 | "--generative_index_path",
301 | type=str,
302 | default = profiler_args.generative_index_path,
303 | help="Path to store the generated structured view of the data lake",
304 | )
305 |
306 | parser.add_argument(
307 | "--cache_dir",
308 | type=str,
309 | default=profiler_args.cache_dir,
310 | help="Path to cache intermediate files during system execution",
311 | )
312 |
313 | parser.add_argument(
314 | "--set_dicts",
315 | type=str,
316 | default=profiler_args.set_dicts,
317 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.",
318 | )
319 |
320 | parser.add_argument(
321 | "--gold_extractions_file",
322 | type=str,
323 | default=profiler_args.gold_extractions_file,
324 | help="Path to store the generated structured view of the data lake",
325 | )
326 | parser.add_argument(
327 | "--topic",
328 |         type=str,
329 | default=profiler_args.topic,
330 | help="Topic of the data lake",
331 | )
332 |
333 | args = parser.parse_args(args=[])
334 | return args
335 |
336 | #for notebook settings
337 | def set_profiler_args(information):
338 | parser = argparse.ArgumentParser(
339 | "LLM explorer.",
340 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
341 | )
342 | parser.add_argument(
343 | "--data_lake",
344 | type=str,
345 | help="Name of the data lake to operate over. Must be in configs.py"
346 | )
347 | parser.add_argument(
348 | "--data_dir",
349 | type=str,
350 | help="Path to raw data-lake documents",
351 | )
352 | parser.add_argument(
353 | "--base_data_dir",
354 | type=str,
355 | default="/",
356 | help="Path to save intermediate result",
357 | )
358 | parser.add_argument(
359 | "--cache_dir",
360 | type=str,
361 | default= "",
362 | help = "Path to cache intermediate files during system execution",
363 | )
364 | parser.add_argument(
365 | "--generative_index_path",
366 | type=str,
367 | default= "",
368 | help = "Path to store the generated structured view of the data lake",
369 | )
370 | parser.add_argument(
371 | "--gold_extractions_file",
372 | type=str,
373 | default= "",
374 | help = "Path to store the generated structured view of the data lake",
375 | )
376 |
377 | parser.add_argument(
378 | "--num_attr_to_cascade",
379 | type=int,
380 | default=35,
381 | help="Number of attributes to generate functions for"
382 | )
383 |
384 | parser.add_argument(
385 | "--num_top_k_scripts",
386 | type=int,
387 | default=10,
388 | help="Number of generated functions to combine over for each attribute"
389 | )
390 |
391 |
392 | parser.add_argument(
393 | "--combiner_mode",
394 | type=str,
395 | default='ws',
396 | help="Combiner mode for combining the outputs of the generated functions",
397 | choices=['ws', 'mv', 'top_k']
398 | )
399 |
400 | parser.add_argument(
401 | "--use_dynamic_backoff",
402 | action="store_true",
403 | default=True,
404 | help="Whether to generate functions or do Evaporate-Direct",
405 | )
406 |
407 | parser.add_argument(
408 | "--KEYS",
409 | type=str,
410 | default=[],
411 | help="List of keys to use the model api",
412 | nargs='*'
413 | )
414 | parser.add_argument(
415 | "--MODELS",
416 | type=str, nargs='*',
417 | default=["gpt-4"],
418 | help="List of models to use in the pipeline"
419 | )
420 | parser.add_argument(
421 | "--EXTRACTION_MODELS",
422 | type=str, nargs='*',
423 | default=["gpt-4"],
424 | help="List of models to use for the extraction step"
425 | )
426 | parser.add_argument(
427 | "--MODEL2URL",
428 | type=str,
429 | default={},
430 | help="Dict mapping model names to their URLs"
431 | )
432 | parser.add_argument(
433 | "--use_qa_model",
434 | action="store_true",
435 | default=False,
436 | help="Whether to use a QA model for the extraction step"
437 | )
438 | parser.add_argument(
439 | "--GOLD_KEY",
440 | type=str,
441 | default="gpt-4",
442 | help="Key to use for the gold standard"
443 | )
444 |
445 | parser.add_argument(
446 | "--eval_size",
447 | type=int,
448 | default=15,
449 | )
450 |
451 | parser.add_argument(
452 | "--chunk_size",
453 | type=int,
454 | default=3000,
455 | )
456 |
457 | parser.add_argument(
458 | "--max_chunks_per_file",
459 | type=int,
460 | default=-1,
461 | )
462 |
463 |
464 | parser.add_argument(
465 | "--extraction_fraction_thresh",
466 | type=float,
467 | default=0.9,
468 | help="threshold for the abstentions approach",
469 | )
470 |
471 | parser.add_argument(
472 | "--remove_tables",
473 | action="store_true",
474 | default=False,
475 | help="Remove tables from the html files?",
476 | )
477 |
478 | parser.add_argument(
479 | "--body_only",
480 | action="store_true",
481 | default=False,
482 | help="Only use HTML body",
483 | )
484 |
485 | parser.add_argument(
486 | "--max_metadata_fields",
487 | type=int,
488 | default=15,
489 | )
490 |
491 | parser.add_argument(
492 | "--overwrite_cache",
493 | action='store_true',
494 | default=False,
495 | help="overwrite the manifest cache"
496 | )
497 | parser.add_argument(
498 | "--swde_plus",
499 | action="store_true",
500 | default=False,
501 | help="Whether to use the extended SWDE dataset to measure OpenIE performance",
502 | )
503 |
504 | parser.add_argument(
505 | "--schema_id_sizes",
506 | type=int,
507 | default=0,
508 | help="Number of documents to use for schema identification stage, if it differs from extraction",
509 | )
510 |
511 | parser.add_argument(
512 | "--slice_results",
513 | action="store_true",
514 | default=False,
515 | help="Whether to measure the results by attribute-slice",
516 | )
517 |
518 | parser.add_argument(
519 | "--fn_generation_prompt_num",
520 | type=int,
521 | default=-1,
522 | help="For ablations on function generation with diversity, control which prompt we use. Default is all.",
523 | )
524 |
525 | parser.add_argument(
526 | "--upper_bound_fns",
527 | action="store_true",
528 | default=False,
529 | help="For ablations that select functions using ground truth instead of the FM.",
530 | )
531 |
532 | parser.add_argument(
533 | "--use_alg_filtering",
534 | action='store_true',
535 | default=False,
536 | help="Whether to filter functions based on quality.",
537 | )
538 |
539 | parser.add_argument(
540 | "--use_abstension",
541 | action='store_true',
542 | default=False,
543 | help="Whether to use the abstentions approach.",
544 | )
545 |
546 | parser.add_argument(
547 | "--do_end_to_end",
548 | action='store_true',
549 | default=False,
550 | help="True for OpenIE, False for ClosedIE"
551 | )
552 |
553 | parser.add_argument(
554 | "--set_dicts",
555 | type=str,
556 | default='',
557 | help="Alternate valid names for the SWDE attributes as provided in the benchmark.",
558 | )
559 |
560 | parser.add_argument(
561 | "--topic",
562 | type=str, nargs='*',
563 | default=[],
564 | help="Topic of the data lake",
565 | )
566 |
567 | args = parser.parse_args(args=[])
568 | for key, value in information.items():
569 | setattr(args, key, value)
570 | today = datetime.datetime.today().strftime("%m%d%Y")
571 | file_groups, extractions_file, parser, full_file_groups = get_data_lake_info(args)
572 | setattr(args, "file_groups", file_groups)
573 | setattr(args, "extractions_file", extractions_file)
574 | setattr(args, "parser", parser)
575 | setattr(args, "full_file_groups", full_file_groups)
576 | if("train_size" not in args):
577 | args.train_size = 10
578 | if "use_dynamic_backoff" not in args:
579 | args.use_dynamic_backoff = True
580 | setattr(args, "run_string", get_run_string(args.data_lake, today, args.full_file_groups, args, args.do_end_to_end, args.train_size, args.use_dynamic_backoff, args.EXTRACTION_MODELS))
581 | if(args.generative_index_path == ""):
582 | args.generative_index_path = os.path.join(args.base_data_dir, "generative_indexes/", args.data_lake)
583 | if(args.cache_dir == ""):
584 | args.cache_dir = os.path.join(args.base_data_dir, "cache/", args.data_lake)
585 | if(args.gold_extractions_file == ""):
586 | args.gold_extractions_file = os.path.join(args.base_data_dir, "gold_extractions.json" )
587 | return args
588 |
589 |
590 |
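A minimal usage sketch of the helpers above (not part of the repo): set_profiler_args merges a plain dict over the parser defaults, wires in the file groups, and derives the cache and index paths. The data-lake name, paths, and key below are placeholders.

```python
# Hedged sketch: build the merged args namespace from a dict, as the notebooks do.
information = {
    "data_lake": "fda_510ks",
    "data_dir": "/data/evaporate/fda-ai-pmas/510k",                    # placeholder
    "base_data_dir": "/data/evaporate/data/fda_510ks",                 # placeholder
    "gold_extractions_file": "/data/evaporate/data/fda_510ks/table.json",
    "combiner_mode": "ws",
    "KEYS": ["YOUR_API_KEY"],                                          # placeholder
    "MODELS": ["gpt-4"],
    "EXTRACTION_MODELS": ["gpt-4"],
}

args = set_profiler_args(information)
print(args.run_string)             # encodes data lake, date, sizes, and settings
print(args.cache_dir)              # derived: <base_data_dir>/cache/<data_lake>
print(args.generative_index_path)  # derived: <base_data_dir>/generative_indexes/<data_lake>
```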
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from evaporate.main import EvaporateData\n",
20 | "import json"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 3,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "MANIFEST_URL = \"http://127.0.0.1:5000\" # please make sure that a local manifest session is running with your model at this address\n",
30 | "DATA_DIR = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k\"\n",
31 | "GOLD_PATH = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/table.json\"\n",
32 | "BASE_DATA_DIR = \"/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/\"\n",
33 | "model = \"mistralai/Mistral-7B-Instruct-v0.2\""
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "# Set Up"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 4,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "name": "stdout",
50 | "output_type": "stream",
51 | "text": [
52 | "using together\n"
53 | ]
54 | },
55 | {
56 | "name": "stderr",
57 | "output_type": "stream",
58 | "text": [
59 | "Chunking files: 100%|██████████| 100/100 [00:00<00:00, 1089.38it/s]\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "profiler_args= {\n",
65 | " \"data_lake\": \"fda\", \n",
66 | " \"combiner_mode\": \"mv\", \n",
67 | " \"do_end_to_end\": False,\n",
68 | " \"KEYS\": [\"\"], \n",
69 | " \"MODELS\":[model],\n",
70 | " \"EXTRACTION_MODELS\": [model],\n",
71 | " \"MODEL2URL\" : {},\n",
72 | " \"data_dir\": f\"{DATA_DIR}\", \n",
73 | " \"chunk_size\": 3000,\n",
74 | " \"base_data_dir\": f\"{BASE_DATA_DIR}\", \n",
75 | " \"gold_extractions_file\": GOLD_PATH,\n",
76 | " \"direct_extract_model\": model\n",
77 | "}\n",
78 | "\n",
79 | "evaporate = EvaporateData(profiler_args)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "# Get Schema Attribute"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 5,
92 | "metadata": {},
93 | "outputs": [
94 | {
95 | "name": "stdout",
96 | "output_type": "stream",
97 | "text": [
98 | "['purpose for submission', 'type of test', 'classification', 'product code', 'panel', 'indications for use', 'predicate device name', 'proposed labeling', 'conclusion', '510k number', 'applicant', 'predicate 510k number', 'proprietary and established names', 'regulation section', 'measurand', 'intended use']\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "attributes = evaporate.get_attribute(do_end_to_end=False)\n",
104 | "print(attributes)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "# Get Direct Attribute"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 6,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "name": "stderr",
121 | "output_type": "stream",
122 | "text": [
123 | "Extracting attribute purpose for submission using LM: 100%|██████████| 10/10 [00:28<00:00, 2.80s/it]\n"
124 | ]
125 | },
126 | {
127 | "name": "stdout",
128 | "output_type": "stream",
129 | "text": [
130 | "finish extract purpose for submission\n"
131 | ]
132 | },
133 | {
134 | "name": "stderr",
135 | "output_type": "stream",
136 | "text": [
137 | "Extracting attribute type of test using LM: 100%|██████████| 10/10 [00:14<00:00, 1.41s/it]"
138 | ]
139 | },
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "finish extract type of test\n",
145 | "{'purpose for submission': {'/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K150526.txt': ['The given context is a section from a Decision Memorandum for a 510(k) submission. In this section', 'the purpose for submission is stated as New Device. This indicates that the submission is for a new medical device that has not been previously cleared or approved by the FDA.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151046.txt': ['To obtain substantial equivalence determination for the illumigene® HSV 1&2 DNA Amplification Assay'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K180886.txt': ['To obtain a substantial equivalence for the addition of Delafloxacin at concentrations of 0.002-32 µg/mL for susceptibility testing of non-fastidious Gram negative organisms'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K181525.txt': ['Purpose for Submission: New Device'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151265.txt': ['New Submission'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K171641.txt': ['This is a new 510(k) application for the determination of Substantial Equivalence for the Mesa Biotech Accula Flu A/Flu B Test and associated instrument. Mesa Biotech', 'Inc. has submitted a combined 510(k) and CLIA waiver package for dual review.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K182472.txt': ['To obtain a Substantial Equivalence determination for the Cepheid Xpert GBS LB Control Panel for use with the Cepheid Xpert GBS LB Assay on the GeneXpert Instrument System.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K161714.txt': ['Answer: The Immunalysis Barbiturates Urine Enzyme Immunoassay is a homogeneous enzyme immunoassay with a cutoff of 200 ng/mL. The assay is intended for use in laboratories for the qualitative and semi-quantitative analysis of Barbiturates in human urine with automated clinical chemistry analyzers. This assay is calibrated against Secobarbital. This in vitro diagnostic device is for prescription use only. The semi-quantitative mode is for purposes of enabling laboratories to determine an appropriate dilution of the specimen for confirmation by a confirmatory method such as Gas Chromatography/ Mass Spectrometry (GC-MS) or Liquid Chromatography/ Tandem Mass Spectrometry (LC-MS/MS) or permitting laboratories to establish quality control procedures. The Immunalysis Barbiturates Urine Enzyme Immunoassay provides only a preliminary analytical test result. A more specific alternate chemical method must be used in order to obtain a confirmed analytical result. GC-MS or LC-MS/MS is the preferred confirmatory method. Clinical consideration and professional judgment should be applied to any drug of abuse test result', 'particularly when preliminary positive results are used. 
The Immunalysis Multi-Drug Calibrators: The Immunalysis Multi-Drug Calibrators are intended for in vitro diagnostic use for the calibration of assays for the following analytes: Benzoylecgonine', 'Methamphetamine', 'Morphine', 'PCP', 'Secobarbital and Oxazepam. The calibrators are designed for prescription use with immunoassays.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K162042.txt': ['2. Indication(s) for use:'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K170974.txt': ['Purpose for Submission: New device']}, 'type of test': {'/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K150526.txt': ['Type of Test: Quantitative', 'turbidimetric'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151046.txt': ['Type of Test: Qualitative in vitro diagnostic device for the direct detection and differentiation of HSV-1 and HSV-2 DNA in cutaneous and mucocutaneous lesion specimens from symptomatic patients suspected of Herpetic infections.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K180886.txt': ['Type of Test: Quantitative Antimicrobial Susceptibility Test growth based detection'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K181525.txt': ['In the provided context', 'the type of test is mentioned as Quantitative immunoturbidimetric assay. This information is important for understanding the nature of the test being performed and the technology used in the analysis. Immunoturbidimetry is a common analytical technique used in clinical chemistry and immunology to measure the concentration of various substances in a sample', 'such as proteins or enzymes', 'by detecting the turbidity or cloudiness caused by the interaction between the analyte and specific antibodies or reagents. 
In this case', 'the test is being used to quantitatively determine the free protein S antigen in human plasma.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K151265.txt': ['Type of Test: Quantitative Amperometric Assay; glucose dehydrogenase - flavin adenine dinucleotide (GDH-FAD)'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K171641.txt': ['Type of Test: RT-PCR amplification followed by hybridization and colorimetric visualization of amplified products on a test strip'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K182472.txt': ['Type of Test: The Cepheid Xpert GBS LB Control Panel is intended for use as external quality control materials to monitor the performance of in vitro laboratory nucleic acid testing procedures for the qualitative detection of Group B Streptococcus (GBS) performed with the Cepheid Xpert GBS LB Assay on the GeneXpert Instrument System.'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K161714.txt': ['Homogenous Enzyme Immunoassay', 'Qualitative and Semi-quantitative'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K162042.txt': ['Type of Test: Quantitative', 'mid-infrared (MIR) spectrophotometric assay'], '/Users/xiqiao/Documents/__research/evaporate-clean/data/evaporate/data/fda_510ks/data/evaporate/fda-ai-pmas/510k/K170974.txt': ['Quantitative and Semi-quantitative Flow Cytometric Immunoassays']}}\n",
146 | "{'mean': 0.723529274358417, 'std': 0.08378645930419634, 'result': {'purpose for submission': (0.606440811221299, 0.7472527472527473), 'type of test': (0.7014743670578509, 0.8389491719017703)}}\n"
147 | ]
148 | },
149 | {
150 | "name": "stderr",
151 | "output_type": "stream",
152 | "text": [
153 | "\n"
154 | ]
155 | }
156 | ],
157 | "source": [
158 | "with open(GOLD_PATH, \"r\") as f:\n",
159 | " gold = json.load(f)\n",
160 | " gold = gold[\"/data/evaporate/fda-ai-pmas/510k/K151917.txt\"]\n",
161 | "#state reference dict for using retrieval model\n",
162 | "direct_attribute, direct_eval = evaporate.direct_extract(is_getting_sample = True, gold = gold, use_retrieval_model= True)\n",
163 | "print(direct_attribute)\n",
164 | "print(direct_eval)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "# Get Attribute functions"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 7,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "name": "stderr",
181 | "output_type": "stream",
182 | "text": [
183 | "Generating functions for attribute purpose for submission: 100%|██████████| 10/10 [07:08<00:00, 42.83s/it]\n"
184 | ]
185 | },
186 | {
187 | "name": "stdout",
188 | "output_type": "stream",
189 | "text": [
190 | "Timeout 0\n",
191 | "Timeout\n",
192 | "Timeout 0\n",
193 | "Timeout\n"
194 | ]
195 | },
196 | {
197 | "name": "stderr",
198 | "output_type": "stream",
199 | "text": [
200 | "Generating functions for attribute type of test: 100%|██████████| 10/10 [07:07<00:00, 42.80s/it]\n"
201 | ]
202 | },
203 | {
204 | "name": "stdout",
205 | "output_type": "stream",
206 | "text": [
207 | "Timeout 0\n",
208 | "Timeout\n",
209 | "Timeout 0\n",
210 | "Timeout\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "extract_functions = evaporate.get_extract_functions()"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "# Weak Supervision (select the best method)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 8,
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "Best script overall: function_94; Score: {'average_f1': 0.10487804878048781, 'median_f1': 0.0, 'extraction_fraction': 1.0, 'prior_average_f1': 0.10487804878048781, 'prior_median_f1': 0.0}\n",
235 | "Best script overall: function_0; Score: {'average_f1': 0.1, 'median_f1': 0.0, 'extraction_fraction': 1.0, 'prior_average_f1': 0.1, 'prior_median_f1': 0.0}\n"
236 | ]
237 | }
238 | ],
239 | "source": [
240 | "selected_keys = evaporate.weak_supervision(use_gold_key=True)"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "# apply on the data lake"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 9,
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "name": "stdout",
257 | "output_type": "stream",
258 | "text": [
259 | "Apply the scripts to the data lake and save the metadata. Taking the top 10 scripts per field.\n",
260 | "Applying function function_94...\n",
261 | "Applying function function_1...\n",
262 | "Applying function function_2...\n",
263 | "Applying function function_3...\n",
264 | "Applying function function_5...\n",
265 | "Applying function function_6...\n",
266 | "Applying function function_9...\n",
267 | "Applying function function_10...\n",
268 | "Applying function function_11...\n",
269 | "Applying function function_12...\n"
270 | ]
271 | },
272 | {
273 | "name": "stderr",
274 | "output_type": "stream",
275 | "text": [
276 | "Applying key function_94: 100%|██████████| 100/100 [00:00<00:00, 234187.83it/s]\n",
277 | "Applying key function_1: 100%|██████████| 100/100 [00:00<00:00, 216871.98it/s]\n",
278 | "Applying key function_2: 100%|██████████| 100/100 [00:00<00:00, 163393.22it/s]\n",
279 | "Applying key function_3: 100%|██████████| 100/100 [00:00<00:00, 135869.91it/s]\n",
280 | "Applying key function_5: 100%|██████████| 100/100 [00:00<00:00, 184122.21it/s]\n",
281 | "Applying key function_6: 100%|██████████| 100/100 [00:00<00:00, 512125.03it/s]\n",
282 | "Applying key function_9: 100%|██████████| 100/100 [00:00<00:00, 497544.96it/s]\n",
283 | "Applying key function_10: 100%|██████████| 100/100 [00:00<00:00, 68635.31it/s]\n",
284 | "Applying key function_11: 100%|██████████| 100/100 [00:00<00:00, 158514.89it/s]\n",
285 | "Applying key function_12: 100%|██████████| 100/100 [00:00<00:00, 322638.77it/s]"
286 | ]
287 | },
288 | {
289 | "name": "stdout",
290 | "output_type": "stream",
291 | "text": [
292 | "Apply the scripts to the data lake and save the metadata. Taking the top 10 scripts per field.\n",
293 | "Applying function function_0...\n",
294 | "Applying function function_1...\n",
295 | "Applying function function_2...\n",
296 | "Applying function function_3...\n",
297 | "Applying function function_4...\n",
298 | "Applying function function_5...\n",
299 | "Applying function function_6...\n",
300 | "Applying function function_7...\n"
301 | ]
302 | },
303 | {
304 | "name": "stderr",
305 | "output_type": "stream",
306 | "text": [
307 | "\n"
308 | ]
309 | },
310 | {
311 | "name": "stdout",
312 | "output_type": "stream",
313 | "text": [
314 | "Applying function function_8...\n",
315 | "Applying function function_9...\n"
316 | ]
317 | },
318 | {
319 | "name": "stderr",
320 | "output_type": "stream",
321 | "text": [
322 | "Applying key function_0: 100%|██████████| 100/100 [00:00<00:00, 121785.83it/s]\n",
323 | "Applying key function_1: 100%|██████████| 100/100 [00:00<00:00, 40932.02it/s]\n",
324 | "Applying key function_2: 100%|██████████| 100/100 [00:00<00:00, 29112.96it/s]\n",
325 | "Applying key function_3: 100%|██████████| 100/100 [00:00<00:00, 96243.78it/s]\n",
326 | "Applying key function_4: 100%|██████████| 100/100 [00:00<00:00, 135474.94it/s]\n",
327 | "Applying key function_5: 100%|██████████| 100/100 [00:00<00:00, 35907.06it/s]\n",
328 | "Applying key function_6: 100%|██████████| 100/100 [00:00<00:00, 181179.44it/s]\n",
329 | "Applying key function_7: 100%|██████████| 100/100 [00:00<00:00, 279062.14it/s]\n",
330 | "Applying key function_8: 100%|██████████| 100/100 [00:00<00:00, 195995.51it/s]\n",
331 | "Applying key function_9: 100%|██████████| 100/100 [00:00<00:00, 172747.28it/s]"
332 | ]
333 | },
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "{'mean': 0.0027519179133103184, 'std': 0.00294827166632422, 'result': {'purpose for submission': (0.0040076716532412736, 0.0), 'type of test': (0.006999999999999999, 0.0)}}\n"
339 | ]
340 | },
341 | {
342 | "name": "stderr",
343 | "output_type": "stream",
344 | "text": [
345 | "\n"
346 | ]
347 | }
348 | ],
349 | "source": [
350 | "function_extract, function_eval = evaporate.apply_functions()\n",
351 | "print(function_eval)"
352 | ]
353 | }
354 | ],
355 | "metadata": {
356 | "kernelspec": {
357 | "display_name": "Python 3 (ipykernel)",
358 | "language": "python",
359 | "name": "python3"
360 | },
361 | "language_info": {
362 | "codemirror_mode": {
363 | "name": "ipython",
364 | "version": 3
365 | },
366 | "file_extension": ".py",
367 | "mimetype": "text/x-python",
368 | "name": "python",
369 | "nbconvert_exporter": "python",
370 | "pygments_lexer": "ipython3",
371 | "version": "3.8.18"
372 | }
373 | },
374 | "nbformat": 4,
375 | "nbformat_minor": 4
376 | }
377 |
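For reference, the notebook's workflow condensed into a plain script; a sketch under the same assumptions (a local manifest session serving the model, placeholder paths):

```python
from evaporate.main import EvaporateData

model = "mistralai/Mistral-7B-Instruct-v0.2"
profiler_args = {
    "data_lake": "fda",
    "combiner_mode": "mv",
    "do_end_to_end": False,
    "KEYS": [""],
    "MODELS": [model],
    "EXTRACTION_MODELS": [model],
    "MODEL2URL": {},
    "data_dir": "/path/to/fda-ai-pmas/510k",         # placeholder
    "chunk_size": 3000,
    "base_data_dir": "/path/to/fda_510ks/",          # placeholder
    "gold_extractions_file": "/path/to/table.json",  # placeholder
    "direct_extract_model": model,
}

evaporate = EvaporateData(profiler_args)
attributes = evaporate.get_attribute(do_end_to_end=False)  # schema identification
functions = evaporate.get_extract_functions()              # synthesize extraction functions
selected = evaporate.weak_supervision(use_gold_key=True)   # score and select functions
extractions, evaluation = evaporate.apply_functions()      # apply to the whole data lake
```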
--------------------------------------------------------------------------------
/evaporate/profiler_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | import argparse
4 | import random
5 | from bs4 import BeautifulSoup
6 | from collections import Counter, defaultdict
7 |
8 |
9 | def set_profiler_args(profiler_args):
10 |
11 | parser = argparse.ArgumentParser(
12 | "LLM profiler.",
13 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
14 | )
15 |
16 | parser.add_argument(
17 | "--chunk_size",
18 | type=int,
19 | default=5000
20 | )
21 |
22 | parser.add_argument(
23 | "--train_size",
24 | type=int,
25 | default=15,
26 | )
27 |
28 | parser.add_argument(
29 | "--eval_size",
30 | type=int,
31 | default=15,
32 | )
33 |
34 | parser.add_argument(
35 | "--max_chunks_per_file",
36 | type=int,
37 | default=-1,
38 | )
39 |
40 | parser.add_argument(
41 | "--num_top_k_scripts",
42 | type=int,
43 | default=1,
44 | help="of all the scripts we generate for the metadata fields, number to retain after scoring their qualities",
45 | )
46 |
47 | parser.add_argument(
48 | "--extraction_fraction_thresh",
49 | type=float,
50 | default=0.9,
51 | help="threshold for the abstentions approach",
52 | )
53 |
54 | parser.add_argument(
55 | "--remove_tables",
56 | type=bool,
57 | default=False,
58 | help="Remove tables from the html files?",
59 | )
60 |
61 | parser.add_argument(
62 | "--body_only",
63 | type=bool,
64 | default=False,
65 | help="Only use HTML body",
66 | )
67 |
68 | parser.add_argument(
69 | "--max_metadata_fields",
70 | type=int,
71 | default=15,
72 | )
73 |
74 | parser.add_argument(
75 | "--use_dynamic_backoff",
76 | type=bool,
77 | default=True,
78 | help="Whether to do the function generation workflow or directly extract from chunks",
79 | )
80 |
81 | parser.add_argument(
82 | "--use_qa_model",
83 | type=bool,
84 | default=False,
85 | help="Whether to apply the span-extractor QA model.",
86 | )
87 |
88 | parser.add_argument(
89 | "--overwrite_cache",
90 | type=int,
91 | default=0,
92 | help="overwrite the manifest cache"
93 | )
94 |
95 | # models to use in the pipeline
96 | parser.add_argument(
97 | "--MODELS",
98 | type=str, nargs='*',
99 | help="models to use in the pipeline"
100 | )
101 |
102 | parser.add_argument(
103 | "--KEYS",
104 | type=str, nargs='*',
105 | help="keys for openai models"
106 | )
107 |
108 | parser.add_argument(
109 | "--GOLDKEY",
110 | type=str,
111 | help="models to use in the pipeline"
112 | )
113 |
114 | parser.add_argument(
115 | "--MODEL2URL",
116 | type=dict,
117 | default={},
118 | help="models to use in the pipeline"
119 | )
120 |
121 | parser.add_argument(
122 | "--swde_plus",
123 | type=bool,
124 | default=False,
125 | help="Whether to use the extended SWDE dataset to measure OpenIE performance",
126 | )
127 |
128 | parser.add_argument(
129 | "--schema_id_sizes",
130 | type=int,
131 | default=0,
132 | help="Number of documents to use for schema identification stage, if it differs from extraction",
133 | )
134 |
135 | parser.add_argument(
136 | "--slice_results",
137 | type=bool,
138 | default=False,
139 | help="Whether to measure the results by attribute-slice",
140 | )
141 |
142 | parser.add_argument(
143 | "--fn_generation_prompt_num",
144 | type=int,
145 | default=-1,
146 | help="For ablations on function generation with diversity, control which prompt we use. Default is all.",
147 | )
148 |
149 | parser.add_argument(
150 | "--upper_bound_fns",
151 | type=bool,
152 | default=False,
153 | help="For ablations that select functions using ground truth instead of the FM.",
154 | )
155 |
156 | parser.add_argument(
157 | "--combiner_mode",
158 | type=str,
159 | default='mv',
160 | help="For ablations that select functions using ground truth instead of the FM.",
161 | )
162 |
163 | parser.add_argument(
164 | "--use_alg_filtering",
165 | type=bool,
166 | default=True,
167 | help="Whether to filter functions based on quality.",
168 | )
169 |
170 | parser.add_argument(
171 | "--use_abstension",
172 | type=bool,
173 | default=True,
174 | help="Whether to use the abstensions approach.",
175 | )
176 |
177 | args = parser.parse_args(args=[])
178 | for arg, val in profiler_args.items():
179 | setattr(args, arg, val)
180 | return args
181 |
182 |
183 | #################### GET SOME SAMPLE FILES TO SEED THE METADATA SEARCH #########################
184 |
185 | def sample_scripts(files, train_size=5):
186 | # "Train" split
187 | random.seed(0)
188 | if train_size <= len(files):
189 | sample_files = random.sample(files, train_size)
190 | else:
191 | sample_files = files
192 | sample_contents = []
193 | for sample_file in sample_files:
194 | with open(sample_file, 'r') as f:
195 | sample_contents.append(f.read())
196 | return sample_files
197 |
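sample_scripts reseeds Python's global random generator with 0 before sampling, so the "train" split is deterministic across calls. A self-contained sketch over synthetic files:

```python
import os
import tempfile

# Create a few synthetic documents so the sampled paths actually exist.
tmpdir = tempfile.mkdtemp()
files = []
for i in range(20):
    path = os.path.join(tmpdir, f"doc_{i}.txt")
    with open(path, "w") as f:
        f.write(f"document {i}")
    files.append(path)

# Same seed, same split: repeated calls return the identical sample.
assert sample_scripts(files, train_size=5) == sample_scripts(files, train_size=5)
```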
198 |
199 | #################### BOILERPLATE CHUNKING CODE, CRITICAL FOR LONG SEQUENCES ####################
200 | def chunk_file(
201 | parser, file, chunk_size=5000, mode="train", remove_tables=False, body_only=False
202 | ):
203 | content = get_file_contents(file)
204 | if "html" in parser:
205 | content, chunks = get_html_parse(
206 | content,
207 | chunk_size=chunk_size,
208 | mode=mode,
209 | remove_tables=remove_tables,
210 | body_only=body_only
211 | )
212 | else:
213 | content, chunks = get_txt_parse(content, chunk_size=chunk_size, mode=mode)
214 | return content, chunks
215 |
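chunk_file dispatches on the parser string: a parser containing "html" goes through the HTML cleaning path below, anything else through the generic text splitter. A self-contained sketch with a synthetic text file:

```python
import os
import tempfile

# Write a synthetic document, then chunk it with the plain-text parser.
path = os.path.join(tempfile.mkdtemp(), "sample.txt")
with open(path, "w") as f:
    f.write("Purpose for submission: New Device.\n" * 200)

content, chunks = chunk_file("txt", path, chunk_size=1000)
print(len(chunks), "chunks; longest:", max(len(c) for c in chunks))
```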
216 |
217 | # HTML --> CHUNKS
218 | def clean_html(content):
219 | for tag in ['script', 'style', 'svg']:
220 | content = content.split("\n")
221 | clean_content = []
222 | in_script = 0
223 | for c in content:
224 | if c.strip().strip("\t").startswith(f"<{tag}"):
225 | in_script = 1
226 | endstr = "</" + tag  # + ">"
227 | if endstr in c or "/>" in c:
228 | in_script = 0
229 | if not in_script:
230 | clean_content.append(c)
231 | content = "\n".join(clean_content)
232 | return content
233 |
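clean_html works line by line: a line opening a <script>, <style>, or <svg> tag flips a flag that suppresses output until a line containing the closing tag (or "/>") appears; note the closing-tag line itself survives. A small sketch:

```python
html = "\n".join([
    "<div>keep me</div>",
    "<script type='text/javascript'>",
    "alert('drop me');",
    "</script>",
    "<p>also kept</p>",
])
# The opener and the script body are dropped; the stray closing tag remains.
print(clean_html(html))
```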
234 |
235 | def get_flattened_items(content, chunk_size=500):
236 | flattened_divs = str(content).split("\n")
237 | flattened_divs = [ch for ch in flattened_divs if ch.strip() and ch.strip("\n").strip()]
238 |
239 | clean_flattened_divs = []
240 | for div in flattened_divs:
241 | if len(str(div)) > chunk_size:
242 | sub_divs = div.split("><")
243 | if len(sub_divs) == 1:
244 | clean_flattened_divs.append(div)
245 | else:
246 | clean_flattened_divs.append(sub_divs[0] + ">")
247 | for sd in sub_divs[1:-1]:
248 | clean_flattened_divs.append("<" + sd + ">")
249 | clean_flattened_divs.append("<" + sub_divs[-1])
250 | else:
251 | clean_flattened_divs.append(div)
252 | return clean_flattened_divs
253 |
254 |
255 | def get_html_parse(content, chunk_size=5000, mode="train", remove_tables=False, body_only=False):
256 | if remove_tables:
257 | soup = BeautifulSoup(content)
258 | tables = soup.find_all("table")
259 | for table in tables:
260 | if "infobox" not in str(table):
261 | content = str(soup)
262 | content = content.replace(str(table), "")
263 | soup = BeautifulSoup(content)
264 |
265 | if body_only:
266 | soup = BeautifulSoup(content)
267 | content = str(soup.find("body"))
268 | soup = BeautifulSoup(content)
269 |
270 | else:
271 | content = clean_html(content)
272 | clean_flattened_divs = []
273 | flattened_divs = get_flattened_items(content, chunk_size=chunk_size)
274 | for i, div in enumerate(flattened_divs):
275 | new_div = re.sub(r'style="[^"]*"', '', str(div))
276 | new_div = re.sub(r'<script.*?</script>', '', str(new_div))  # NOTE: original patterns lost to HTML mangling; script/style/svg removal assumed
277 | new_div = re.sub(r'<style.*?</style>', '', str(new_div))
278 | new_div = re.sub(r'<svg.*?</svg>', '', str(new_div))
279 | new_div = "\n".join([l for l in new_div.split("\n") if l.strip() and l.strip("\n").strip()])
280 | # new_div = BeautifulSoup(new_div) #.fsind_all("div")[0]
281 | if new_div:
282 | clean_flattened_divs.append(new_div)
283 |
284 | if mode == "eval":
285 | return content, []
286 |
287 | grouped_divs = []
288 | current_div = []
289 | current_length = 0
290 | max_length = chunk_size
291 | join_str = "\n"  # NOTE: the original referenced an undefined use_raw_text flag; raw-text joining is assumed off
292 | for div in clean_flattened_divs:
293 | str_div = str(div)
294 | len_div = len(str_div)
295 | if current_div and (current_length + len_div > max_length):
296 | grouped_divs.append(join_str.join(current_div))
297 | current_div = []
298 | current_length = 0
299 | elif not current_div and (len_div > max_length):
300 | grouped_divs.append(str_div)
301 | continue
302 | current_div.append(str_div)
303 | current_length += len_div
304 | if current_div: grouped_divs.append(join_str.join(current_div))  # flush the trailing chunk
305 | return content, grouped_divs
306 |
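Putting the HTML path together: clean the markup, flatten it into per-line items, strip style attributes and inline script/style/svg blocks, then greedily regroup lines into chunks of at most chunk_size characters. A sketch:

```python
html = (
    "<div>first section</div>\n"
    "<script>\ntrack();\n</script>\n"
    "<p>second section</p>\n"
    "<p>third section</p>"
)
content, chunks = get_html_parse(html, chunk_size=60)
for chunk in chunks:
    print(repr(chunk))  # script body removed; lines packed into <=60-char groups
```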
307 |
308 | # GENERIC TXT --> CHUNKS
309 | def get_txt_parse(content, chunk_size=5000, mode="train"):
310 | # convert to chunks
311 | if mode == "train":
312 | chunks = content.split("\n")
313 | clean_chunks = []
314 | for chunk in chunks:
315 | if len(chunk) > chunk_size:
316 | sub_chunks = chunk.split(". ")
317 | clean_chunks.extend(sub_chunks)
318 | else:
319 | clean_chunks.append(chunk)
320 |
321 | chunks = clean_chunks.copy()
322 | clean_chunks = []
323 | for chunk in chunks:
324 | if len(chunk) > chunk_size:
325 | sub_chunks = chunk.split(", ")
326 | clean_chunks.extend(sub_chunks)
327 | else:
328 | clean_chunks.append(chunk)
329 |
330 | final_chunks = []
331 | cur_chunk = []
332 | cur_chunk_size = 0
333 | for chunk in clean_chunks:
334 | if cur_chunk_size + len(chunk) > chunk_size:
335 | final_chunks.append("\n".join(cur_chunk))
336 | cur_chunk = []
337 | cur_chunk_size = 0
338 | cur_chunk.append(chunk)
339 | cur_chunk_size += len(chunk)
340 | if cur_chunk:
341 | final_chunks.append("\n".join(cur_chunk))
342 | else:
343 | final_chunks = []
344 | return content, final_chunks
345 |
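The text splitter is two-stage: newline-delimited lines longer than chunk_size are re-split on ". " and then ", ", and the resulting pieces are greedily packed into chunks of roughly chunk_size characters. A sketch:

```python
text = "short line\n" + "clause one, clause two, clause three. " * 5
content, chunks = get_txt_parse(text, chunk_size=80)
for chunk in chunks:
    print(len(chunk), repr(chunk[:40]))
```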
346 |
347 | def get_file_contents(file):
348 | text = ''
349 | if file.endswith(".swp"):
350 | return text
351 | try:
352 | with open(file) as f:
353 | text = f.read()
354 | except Exception:
355 | with open(file, "rb") as f:
356 | text = f.read().decode("utf-8", "ignore")
357 | return text
358 |
359 |
360 | def clean_metadata(field):
361 | return field.replace("\t", " ").replace("\n", " ").strip().lower()
362 |
363 |
364 | def filter_file2chunks(file2chunks, sample_files, attribute):
365 | def get_attribute_parts(attribute):
366 | for char in ["/", "-", "(", ")", "[", "]", "{", "}", ":"]:
367 | attribute = attribute.replace(char, " ")
368 | attribute_parts = attribute.lower().split()
369 | return attribute_parts
370 |
371 | # filter chunks with simple keyword search
372 | attribute_chunks = defaultdict(list)
373 | starting_num_chunks = 0
374 | ending_num_chunks = 0
375 | ending_in_sample_chunks = 0
376 | starting_in_sample_chunks = 0
377 | for file, chunks in file2chunks.items():
378 | starting_num_chunks += len(chunks)
379 | if file in sample_files:
380 | starting_in_sample_chunks += len(chunks)
381 | cleaned_chunks = []
382 | for chunk in chunks:
383 | if attribute.lower() in chunk.lower():
384 | cleaned_chunks.append(chunk)
385 | if len(cleaned_chunks) == 0:
386 | for chunk in chunks:
387 | if attribute.lower().replace(" ", "") in chunk.lower().replace(" ", ""):
388 | cleaned_chunks.append(chunk)
389 | if len(cleaned_chunks) == 0:
390 | chunk2num_word_match = Counter()
391 | for chunk_num, chunk in enumerate(chunks):
392 | attribute_parts = get_attribute_parts(attribute.lower())
393 | for wd in attribute_parts:
394 | if wd.lower() in chunk.lower():
395 | chunk2num_word_match[chunk_num] += 1
396 | # sort chunks by number of words that match
397 | sorted_chunks = sorted(chunk2num_word_match.items(), key=lambda x: x[1], reverse=True)
398 | if len(sorted_chunks) > 0:
399 | cleaned_chunks.append(chunks[sorted_chunks[0][0]])
400 | if len(sorted_chunks) > 1:
401 | cleaned_chunks.append(chunks[sorted_chunks[1][0]])
402 | ending_num_chunks += len(cleaned_chunks)
403 | num_chunks = len(cleaned_chunks)
404 | num_chunks = min(num_chunks, 2)
405 |
406 | cleaned_chunks = cleaned_chunks[:num_chunks]
407 | attribute_chunks[file] = cleaned_chunks
408 | if file in sample_files:
409 | ending_in_sample_chunks += len(attribute_chunks[file])
410 | file2chunks = attribute_chunks
411 | if ending_num_chunks == 0 or ending_in_sample_chunks == 0:
412 | print(f"Removing because no chunks for attribute {attribute} in any file")
413 | return None
414 | print(f"For attribute {attribute}\n-- Starting with {starting_num_chunks} chunks\n-- Ending with {ending_num_chunks} chunks")
415 | print(f"-- {starting_in_sample_chunks} starting chunks in sample files\n-- {ending_in_sample_chunks} chunks in sample files")
416 |
417 | return file2chunks
418 |
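filter_file2chunks keeps at most two chunks per file: an exact substring match on the attribute first, then a whitespace-insensitive match, then a per-word overlap ranking. A sketch on toy chunks (the doubled space in doc2 exercises the fallback):

```python
file2chunks = {
    "doc1.txt": ["Type of Test: Quantitative", "unrelated boilerplate"],
    "doc2.txt": ["the type  of test appears here", "more boilerplate"],
}
filtered = filter_file2chunks(file2chunks, sample_files=["doc1.txt"], attribute="type of test")
print(filtered)  # each file keeps only its matching chunk
```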
419 |
420 | def clean_function_predictions(extraction, attribute=None):
421 | if extraction is None:
422 | return ''
423 | if type(extraction) == list:
424 | if extraction and type(extraction[0]) == list:
425 | full_answer = []
426 | for answer in extraction:
427 | if type(answer) == list:
428 | dedup_list = []
429 | for a in answer:
430 | if a not in dedup_list:
431 | dedup_list.append(a)
432 | answer = dedup_list
433 | answer = [str(a).strip().strip("\n") for a in answer]
434 | full_answer.append(", ".join(answer))
435 | else:
436 | full_answer.append(answer.strip().strip("\n"))
437 | full_answer = [a.strip() for a in full_answer]
438 | extraction = ", ".join(full_answer)
439 | elif extraction and len(extraction) == 1 and extraction[0] is None:
440 | extraction = ''
441 | else:
442 | dedup_list = []
443 | for a in extraction:
444 | if a not in dedup_list:
445 | dedup_list.append(a)
446 | extraction = dedup_list
447 | extraction = [(str(e)).strip().strip("\n") for e in extraction]
448 | extraction = ", ".join(extraction)
449 | if isinstance(extraction, str) and extraction.lower() == "none":
450 | extraction = ""
451 | extraction = extraction.strip().replace("  ", " ")
452 | if attribute and extraction.lower().startswith(attribute.lower()):
453 | idx = extraction.lower().find(attribute.lower())
454 | extraction = extraction[idx+len(attribute):].strip()
455 | for char in [':', ","]:
456 | extraction = extraction.strip(char).strip()
457 | extraction = extraction.replace(",", ", ").replace("  ", " ")
458 | return extraction
459 |
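clean_function_predictions normalizes whatever a synthesized function returned (nested lists, duplicates, a leading restatement of the attribute name) into a single clean string:

```python
# Nested list with duplicates -> deduplicated, joined string.
print(clean_function_predictions([["New Device", "New Device"]], attribute="purpose"))
# -> "New Device"

# A leading attribute name plus separator is stripped.
print(clean_function_predictions("purpose: New Device", attribute="purpose"))
# -> "New Device"
```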
460 |
461 | def check_vs_train_extractions(train_extractions, final_extractions, gold_key, attribute = None):
462 | clean_final_extractions = {}
463 |
464 | gold_values = train_extractions[gold_key]
465 | modes = []
466 | start_toks = []
467 | end_toks = []
468 | for file, gold in gold_values.items():
469 | if type(gold) == dict:
470 | gold = gold[attribute]
471 | if type(gold) == list:
472 | if gold and type(gold[0]) == list:
473 | gold = [g[0] for g in gold]
474 | gold = ", ".join(gold)
475 | else:
476 | gold = ", ".join(gold)
477 | gold = gold.lower()
478 | pred = final_extractions[file].lower()
479 | if not pred or not gold:
480 | continue
481 | if ("<" in pred and "<" not in gold) or (">" in pred and ">" not in gold):
482 | check_pred = BeautifulSoup(pred).text
483 | if check_pred in gold or gold in check_pred:
484 | modes.append("soup")
485 | elif gold in pred and len(pred) > len(gold):
486 | modes.append("longer")
487 | idx = pred.index(gold)
488 | if idx > 0:
489 | start_toks.append(pred[:idx-1])
490 | end_idx = idx + len(gold)
491 | if end_idx < len(pred):
492 | end_toks.append(pred[end_idx:])
493 |
494 | def long_substr(data):
495 | substr = ''
496 | if len(data) > 1 and len(data[0]) > 0:
497 | for i in range(len(data[0])):
498 | for j in range(len(data[0])-i+1):
499 | if j > len(substr) and is_substr(data[0][i:i+j], data):
500 | substr = data[0][i:i+j]
501 | return substr
502 |
503 | def is_substr(find, data):
504 | if len(data) < 1 or len(find) < 1:
505 | return False
506 | for i in range(len(data)):
507 | if find not in data[i]:
508 | return False
509 | return True
510 |
511 | longest_end_tok = long_substr(end_toks)
512 | longest_start_tok = long_substr(start_toks)
513 | if len(set(modes)) == 1:
514 | num_golds = len(gold_values)
515 | for file, extraction in final_extractions.items():
516 | if "longer" in modes:
517 | # gold longer than pred
518 | if len(end_toks) == num_golds and longest_end_tok in extraction and extraction.count(longest_end_tok) == 1:
519 | idx = extraction.index(longest_end_tok)
520 | extraction = extraction[:idx]
521 | if len(start_toks) == num_golds and longest_start_tok in extraction and extraction.count(longest_start_tok) == 1:
522 | idx = extraction.index(longest_start_tok)
523 | extraction = extraction[idx:]
524 | elif "soup" in modes:
525 | extraction = BeautifulSoup(extraction).text
526 | clean_final_extractions[file] = extraction
527 | else:
528 | clean_final_extractions = final_extractions
529 | return clean_final_extractions
530 |
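check_vs_train_extractions compares predictions on the train files against the gold model's values; when every prediction wraps its gold value the same way, the shared prefix/suffix tokens are trimmed from all final extractions. A sketch of the shared-suffix case:

```python
train_extractions = {"gold": {"doc1.txt": "new device", "doc2.txt": "substantial equivalence"}}
final_extractions = {
    "doc1.txt": "new device (see summary)",
    "doc2.txt": "substantial equivalence (see summary)",
}
cleaned = check_vs_train_extractions(train_extractions, final_extractions, gold_key="gold")
print(cleaned)  # the shared " (see summary)" suffix is trimmed from every file
```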
--------------------------------------------------------------------------------
/evaporate/run_profiler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import json
5 | import datetime
6 | from tqdm import tqdm
7 | import pickle
8 | import argparse
9 | from collections import defaultdict, Counter
10 |
11 | from evaporate.utils import get_structure, get_manifest_sessions, get_file_attribute
12 | from evaporate.profiler_utils import chunk_file, sample_scripts
13 | from evaporate.schema_identification import identify_schema
14 | from evaporate.profiler import run_profiler  # get_file_attribute already imported from evaporate.utils above
15 | from evaporate.evaluate_synthetic import main as evaluate_synthetic_main
16 | from evaporate.configs import get_experiment_args
17 |
18 |
19 | random.seed(0)
20 | def get_data_lake_info(args, data_lake):
21 | extractions_file = None
22 |
23 | if True:
24 | DATA_DIR = args.data_dir
25 | file_groups = os.listdir(args.data_dir)
26 | if not DATA_DIR.endswith("/"):
27 | DATA_DIR += "/"
28 | file_groups = [f"{DATA_DIR}{file_group}" for file_group in file_groups if not file_group.startswith(".")]
29 | full_file_groups = file_groups.copy()
30 | extractions_file = args.gold_extractions_file
31 | parser = "txt"
32 |
33 | return file_groups, extractions_file, parser, full_file_groups
34 |
35 |
36 | def chunk_files(file_group, parser, chunk_size, remove_tables, max_chunks_per_file, body_only):
37 | file2chunks = {}
38 | file2contents = {}
39 | for file in tqdm(file_group, total=len(file_group), desc="Chunking files"):
40 | content, chunks = chunk_file(
41 | parser,
42 | file,
43 | chunk_size=chunk_size,
44 | mode="train",
45 | remove_tables=remove_tables,
46 | body_only=body_only
47 | )
48 | if max_chunks_per_file > 0:
49 | chunks = chunks[:max_chunks_per_file]
50 | file2chunks[file] = chunks
51 | file2contents[file] = content
52 | return file2chunks, file2contents
53 |
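chunk_files simply maps chunk_file over a file group, optionally truncating each file to max_chunks_per_file chunks. A self-contained sketch over a synthetic mini data lake:

```python
import os
import tempfile

# Synthesize a tiny "data lake" so the sketch runs end to end.
data_dir = tempfile.mkdtemp()
for i in range(3):
    with open(os.path.join(data_dir, f"doc_{i}.txt"), "w") as f:
        f.write("Purpose for Submission: New Device\n" * 50)

file_group = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
file2chunks, file2contents = chunk_files(
    file_group, parser="txt", chunk_size=3000,
    remove_tables=False, max_chunks_per_file=2, body_only=False,
)
print({f: len(chunks) for f, chunks in file2chunks.items()})
```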
54 |
55 | # chunking & preparing data
56 | def prepare_data(profiler_args, file_group, data_args, parser = "html"):
57 | data_lake = profiler_args.data_lake
58 | if profiler_args.body_only:
59 | body_only = profiler_args.body_only
60 | suffix = f"_bodyOnly{body_only}"
61 | else:
62 | suffix = ""
63 | # prepare the datalake: chunk all files
64 | manifest_sessions = get_manifest_sessions(profiler_args.MODELS, MODEL2URL=profiler_args.MODEL2URL, KEYS=profiler_args.KEYS)
65 | if os.path.exists(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2chunks.pkl"):
66 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2chunks.pkl", "rb") as f:
67 | file2chunks = pickle.load(f)
68 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2contents.pkl", "rb") as f:
69 | file2contents = pickle.load(f)
70 | else:
71 | file2chunks, file2contents = chunk_files(
72 | file_group,
73 | parser,
74 | profiler_args.chunk_size,
75 | profiler_args.remove_tables,
76 | profiler_args.max_chunks_per_file,
77 | profiler_args.body_only
78 | )
79 | if not os.path.exists(data_args.cache_dir):
80 | os.mkdir(data_args.cache_dir)
81 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2chunks.pkl", "wb") as f:
82 | pickle.dump(file2chunks, f)
83 | with open(f"{data_args.cache_dir}/{data_lake}_size{len(file_group)}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}{suffix}_file2contents.pkl", "wb") as f:
84 | pickle.dump(file2contents, f)
85 | return file2chunks, file2contents, manifest_sessions
86 |
87 |
88 | def get_run_string(
89 | data_lake, today, file_groups, profiler_args, do_end_to_end,
90 | train_size, dynamicbackoff, models
91 | ):
92 | body = profiler_args.body_only # Baseline systems only operate on the HTML body
93 | model_ct = len(models)
94 | if profiler_args.use_qa_model:
95 | model_ct += 1
96 | run_string = f"dataLake{data_lake}_date{today}_fileSize{len(file_groups)}_trainSize{train_size}_numAggregate{profiler_args.num_top_k_scripts}_chunkSize{profiler_args.chunk_size}_removeTables{profiler_args.remove_tables}_body{body}_cascading{do_end_to_end}_useBackoff{dynamicbackoff}_MODELS{model_ct}"
97 | return run_string
98 |
99 |
100 | def get_gold_metadata(args):
101 | # get the list of gold metadata for closed-IE runs
102 | try:
103 | with open(args.gold_extractions_file) as f:
104 | gold_file2extractions = json.load(f)
105 | except Exception:
106 | with open(args.gold_extractions_file, "rb") as f:
107 | gold_file2extractions = pickle.load(f)
108 | frequency = Counter()
109 | for file, dic in gold_file2extractions.items():
110 | for k, v in dic.items():
111 | if k != "topic_entity_name":
112 | if type(v) == str and v:
113 | frequency[k] += 1
114 | elif type(v) == list and v and v[0]:
115 | frequency[k] += 1
116 | sorted_frequency = sorted(frequency.items(), key=lambda x: x[1], reverse=True)
117 | gold_metadata = [x[0] for x in sorted_frequency]
118 | gold_attributes = [m.lower() for m in gold_metadata if m not in ['topic_entity_name']]
119 | return gold_attributes
120 |
121 |
122 | def determine_attributes_to_remove(attributes, args, run_string, num_attr_to_cascade):
123 | attributes_reordered = {}
124 | attributes_to_remove = []
125 | attributes_to_metrics = {}
126 | attribute_to_first_extractions = {}
127 | mappings_names = {}
128 | for num, attribute in enumerate(attributes):
129 | attribute = attribute.lower()
130 | file_attribute = get_file_attribute(attribute)
131 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json"):
132 | continue
133 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json"):
134 | continue
135 | if num >= num_attr_to_cascade:
136 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json")
137 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json")
138 | continue
139 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_all_metrics.json") as f:
140 | metrics = json.load(f)
141 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_top_k_keys.json") as f:
142 | selected_keys = json.load(f)
143 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") as f:
144 | file2metadata = json.load(f)
145 | if selected_keys: attributes_reordered[attribute] = metrics[selected_keys[0]]  # guard: selected_keys may be empty
146 |
147 | if selected_keys and metrics:
148 | for a, m in attributes_to_metrics.items():
149 | if attribute.lower() in a.lower() or a.lower() in attribute.lower():
150 | if m == metrics[selected_keys[0]]['average_f1']:
151 | attributes_to_remove.append(attribute)
152 | mappings_names[a] = attribute
153 | mappings_names[attribute] = a
154 | break
155 |
156 | first_extractions = [m for i, (f, m) in enumerate(file2metadata.items()) if i < 5]
157 | if any(f != "" for f in first_extractions):
158 | first_extractions = " ".join(first_extractions)
159 | for a, m in attribute_to_first_extractions.items():
160 | if m == first_extractions:
161 | attributes_to_remove.append(attribute)
162 | mappings_names[a] = attribute
163 | mappings_names[attribute] = a
164 | break
165 |
166 | if attribute in attributes_to_remove:
167 | continue
168 | if selected_keys:
169 | attributes_to_metrics[attribute] = metrics[selected_keys[0]]['average_f1']
170 | attribute_to_first_extractions[attribute] = first_extractions
171 | return attributes_to_remove, mappings_names, attributes
172 |
173 |
174 | def measure_openie_results(
175 | attributes,
176 | args,
177 | profiler_args,
178 | run_string,
179 | gold_attributes,
180 | attributes_to_remove,
181 | file_groups,
182 | mappings_names
183 | ):
184 | file2extractions = defaultdict(dict)
185 | unique_attributes = set()
186 | num_extractions2results = {}
187 | data_lake = profiler_args.data_lake
188 | for attr_num, attribute in enumerate(attributes):
189 | attribute = attribute.lower()
190 | file_attribute = get_file_attribute(attribute)
191 | if os.path.exists(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json"):
192 | if attribute in attributes_to_remove:
193 | print(f"Removing: {attribute}")
194 | os.remove(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json")
195 | continue
196 |
197 | with open(f"{args.generative_index_path}/{run_string}_{file_attribute}_file2metadata.json") as f:
198 | file2metadata = json.load(f)
199 | for file, extraction in file2metadata.items():
200 | file2extractions[file][attribute] = extraction
201 | unique_attributes.add(attribute)
202 |
203 | if file2extractions:
204 | num_extractions = len(unique_attributes)
205 | nums = [1, 2, 3, 4, len(attributes) - 1, len(gold_attributes)]
206 | if (file2extractions and (num_extractions % 5 == 0 or num_extractions in nums)) or attr_num == len(attributes) - 1:
207 | if num_extractions in num_extractions2results:
208 | continue
209 | with open(f"{args.generative_index_path}/{run_string}_file2extractions.json", "w") as f:
210 | json.dump(file2extractions, f)
211 |
212 | results = evaluate_synthetic_main(
213 | run_string,
214 | args,
215 | profiler_args,
216 | data_lake,
217 | sample_files=file_groups,
218 | stage='openie',
219 | mappings_names=mappings_names
220 | )
221 | num_extractions2results[num_extractions] = results
222 | return num_extractions2results
223 |
224 | def prerun_profiler(profiler_args):
225 | file2chunks, file2contents, manifest_sessions = prepare_data(
226 | profiler_args, profiler_args.full_file_groups, profiler_args, profiler_args.parser
227 | )
228 | manifest_sessions = {
229 | k: v for k, v in manifest_sessions.items() if k in profiler_args.MODELS
230 | }
231 | gold_attributes = get_gold_metadata(profiler_args)
232 | try:
233 | with open(profiler_args.gold_extractions_file) as f:
234 | gold_extractions_tmp = json.load(f)
235 | except Exception:
236 | with open(profiler_args.gold_extractions_file, "rb") as f:
237 | gold_extractions_tmp = pickle.load(f)
238 | gold_extractions = {}
239 | for file, extractions in gold_extractions_tmp.items():
240 | gold_extractions[os.path.join(profiler_args.data_dir, file.split('/')[-1])] = extractions
241 | manifest_sessions[profiler_args.GOLD_KEY] = {}
242 | manifest_sessions[profiler_args.GOLD_KEY]['__name'] = 'gold_extraction_file'
243 | for attribute in gold_attributes:
244 | manifest_sessions[profiler_args.GOLD_KEY][attribute] = {}
245 | for file in profiler_args.file_groups:
246 | manifest_sessions[profiler_args.GOLD_KEY][attribute][file] = gold_extractions[file][attribute]
247 |
248 | sample_files = sample_scripts(
249 | profiler_args.file_groups,
250 | train_size=profiler_args.train_size,
251 | )
252 | data_dict = {
253 | "file2chunks": file2chunks,
254 | "file2contents": file2contents,
255 | "manifest_sessions": manifest_sessions,
256 | "gold_attributes": gold_attributes,
257 | "sample_files" : sample_files,
258 | "gold_extractions": gold_extractions
259 | }
260 | return data_dict
261 |
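prerun_profiler composes with identify_attributes and get_attribute_function (defined just below) into a programmatic pipeline: chunk the lake and load gold data, identify the schema, then synthesize and score extractors per attribute. A hedged sketch, assuming the dict-based set_profiler_args lives in evaporate.configs and using placeholder paths and key:

```python
from evaporate.configs import set_profiler_args  # assumed import location

args = set_profiler_args({
    "data_lake": "fda_510ks",
    "data_dir": "/data/evaporate/fda-ai-pmas/510k",     # placeholder
    "base_data_dir": "/data/evaporate/data/fda_510ks",  # placeholder
    "gold_extractions_file": "/data/evaporate/data/fda_510ks/table.json",
    "MODELS": ["gpt-4"],
    "EXTRACTION_MODELS": ["gpt-4"],
    "GOLD_KEY": "gpt-4",
    "KEYS": ["YOUR_API_KEY"],                           # placeholder
})

data_dict = prerun_profiler(args)
attributes, schema_secs, schema_toks, _ = identify_attributes(args, data_dict)
for attribute in attributes[:3]:
    fns, selected_keys, fn_secs, fn_toks = get_attribute_function(args, data_dict, attribute)
```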
262 | def identify_attributes(profiler_args, data_dict, evaluation = False):
263 | file2chunks = data_dict["file2chunks"]
264 | file2contents = data_dict["file2contents"]
265 | manifest_sessions = data_dict["manifest_sessions"]
266 | sample_files = sample_scripts(
267 | profiler_args.file_groups,
268 | train_size=profiler_args.train_size,
269 | )
270 | t0 = time.time()
271 | num_toks = identify_schema(
272 | profiler_args.run_string,
273 | profiler_args,
274 | file2chunks,
275 | file2contents,
276 | sample_files,
277 | manifest_sessions,
278 | profiler_args.data_lake,
279 | profiler_args
280 | )
281 | t1 = time.time()
282 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_identified_schema.json") as f:
283 | most_common_fields = json.load(f)
284 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_order_of_addition.json") as f:
285 | order_of_addition = json.load(f)
286 | order = {item: (len(order_of_addition) - i) for i, item in enumerate(order_of_addition)}
287 | ctr = Counter(most_common_fields)
288 | pred_metadata = sorted(
289 | ctr.most_common(profiler_args.num_attr_to_cascade),
290 | key=lambda x: (x[1], order[x[0]]),
291 | reverse=True
292 | )
293 | attributes = [item[0].lower() for item in pred_metadata]
294 | if evaluation :
295 | evaluation_result = evaluate_synthetic_main(
296 | profiler_args.run_string,
297 | profiler_args,
298 | profiler_args,
299 | profiler_args.data_lake,
300 | stage='schema_id'
301 | )
302 | else:
303 | evaluation_result = None
304 | return attributes, t1-t0, num_toks, evaluation_result
305 |
306 | def get_attribute_function(profiler_args, data_dict, attribute):
307 | sample_files = sample_scripts(
308 | profiler_args.file_groups,
309 | train_size=profiler_args.train_size,
310 | )
311 | t0 = time.time()
312 | num_toks, success = run_profiler(
313 | profiler_args.run_string,
314 | profiler_args,
315 | data_dict["file2chunks"],
316 | data_dict["file2contents"],
317 | sample_files,
318 | profiler_args.full_file_groups,
319 | data_dict["manifest_sessions"],
320 | attribute,
321 | profiler_args
322 | )
323 | t1 = time.time()
324 | try:
325 | file_attribute = get_file_attribute(attribute)
326 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_{file_attribute}_functions.json") as f:
327 | function_dictionary = json.load(f)
328 | with open(f"{profiler_args.generative_index_path}/{profiler_args.run_string}_{file_attribute}_top_k_keys.json") as f:
329 | selected_keys = json.load(f)
330 | except Exception:
331 | selected_keys = None
332 | function_dictionary = None
333 | return function_dictionary, selected_keys, t1-t0, num_toks
334 | def run_experiment(profiler_args):
335 | do_end_to_end = profiler_args.do_end_to_end
336 | num_attr_to_cascade = profiler_args.num_attr_to_cascade
337 | train_size = profiler_args.train_size
338 | data_lake = profiler_args.data_lake
339 |
340 | print(f"Data lake")
341 | today = datetime.datetime.today().strftime("%m%d%Y")
342 |
343 | _, _, _, _, args = get_structure(data_lake, profiler_args)
344 | file_groups, extractions_file, parser, full_file_groups = get_data_lake_info(args, data_lake)
345 | file2chunks, file2contents, manifest_sessions = prepare_data(
346 | profiler_args, full_file_groups, args, parser
347 | )
348 | manifest_sessions = {
349 | k: v for k, v in manifest_sessions.items() if k in profiler_args.MODELS
350 | }
351 | gold_attributes = get_gold_metadata(args)
352 |
353 | results_by_train_size = defaultdict(dict)
354 | total_time_dict = defaultdict(dict)
355 |
356 | if True:
357 | total_tokens_prompted = 0
358 |
359 | print(f"\n\nData-lake: {data_lake}, Train size: {train_size}")
360 | setattr(profiler_args, 'train_size', train_size)
361 |
362 | run_string = get_run_string(
363 | data_lake, today, full_file_groups, profiler_args,
364 | do_end_to_end, train_size,
365 | profiler_args.use_dynamic_backoff,
366 | profiler_args.EXTRACTION_MODELS,
367 | )
368 |
369 | sample_files = sample_scripts(
370 | file_groups,
371 | train_size=profiler_args.train_size,
372 | )
373 |
374 | # top-level schema identification
375 | if do_end_to_end:
376 | t0 = time.time()
377 | num_toks = identify_schema(
378 | run_string,
379 | args,
380 | file2chunks,
381 | file2contents,
382 | sample_files,
383 | manifest_sessions,
384 | data_lake,
385 | profiler_args
386 | )
387 | t1 = time.time()
388 | total_time = t1-t0
389 | total_tokens_prompted += num_toks
390 | total_time_dict[f'schemaId'][f'totalTime_trainSize{train_size}'] = int(total_time)
391 |
392 | results = evaluate_synthetic_main(
393 | run_string,
394 | args,
395 | profiler_args,
396 | data_lake,
397 | stage='schema_id'
398 | )
399 | results_by_train_size[train_size]['schema_id'] = results
400 |
401 | if True:
402 | if do_end_to_end:
403 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json") as f:
404 | most_common_fields = json.load(f)
405 | with open(f"{args.generative_index_path}/{run_string}_order_of_addition.json") as f:
406 | order_of_addition = json.load(f)
407 | order = {item: (len(order_of_addition) - i) for i, item in enumerate(order_of_addition)}
408 | ctr = Counter(most_common_fields)
409 | pred_metadata = sorted(
410 | ctr.most_common(num_attr_to_cascade),
411 |                     key=lambda x: (x[1], order.get(x[0], 0)),
412 | reverse=True
413 | )
414 | attributes = [item[0].lower() for item in pred_metadata]
415 | else:
416 | attributes = gold_attributes
417 |
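# At this point `attributes` is a ranked list: in the open-IE path, fields are
# ordered by how often schema identification proposed them (ties broken in
# favor of fields added earlier); in the closed-IE path, the gold schema is
# used as-is.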
418 | # top-level information extraction
419 | num_collected = 0
420 | for i, attribute in enumerate(attributes):
421 | print(f"\n\nExtracting {attribute} ({i+1} / {len(attributes)})")
422 | t0 = time.time()
423 | num_toks, success = run_profiler(
424 | run_string,
425 | args,
426 | file2chunks,
427 | file2contents,
428 | sample_files,
429 | full_file_groups,
430 | manifest_sessions,
431 | attribute,
432 | profiler_args
433 | )
434 | t1 = time.time()
435 | total_time = t1-t0
436 | total_tokens_prompted += num_toks
437 |                 total_time_dict['extract'][f'totalTime_trainSize{train_size}'] = int(total_time)
438 | if success:
439 | num_collected += 1
440 | if num_collected >= num_attr_to_cascade:
441 | break
442 |
443 | # run closed ie eval
444 | results = evaluate_synthetic_main(
445 | run_string,
446 | args,
447 | profiler_args,
448 | data_lake,
449 | gold_attributes=gold_attributes,
450 | stage='extract'
451 | )
452 | results_by_train_size[train_size]['extract'] = results
453 |
454 |             # Determine whether to remove any attributes based on the extractions.
455 |             # Optionally, attributes could be reranked by comparing metrics against the larger model.
456 | if do_end_to_end:
457 | attributes_to_remove, mappings_names, attributes = determine_attributes_to_remove(
458 | attributes,
459 | args,
460 | run_string,
461 | num_attr_to_cascade,
462 | )
463 | numextractions2results = measure_openie_results(
464 | attributes,
465 | args,
466 | profiler_args,
467 | run_string,
468 | gold_attributes,
469 | attributes_to_remove,
470 | full_file_groups,
471 | mappings_names
472 | )
473 |                 # store the open-IE metrics keyed by number of extracted attributes
474 |                 # (any previous entry for this train size is simply overwritten)
475 |                 results_by_train_size[train_size]['openie'] = numextractions2results
476 |
477 | results_by_train_size[train_size]['total_tokens_prompted'] = total_tokens_prompted
478 | results_by_train_size[train_size]['num_total_files'] = len(full_file_groups)
479 | results_by_train_size[train_size]['num_sample_files'] = len(sample_files)
480 | result_path_dir = os.path.join(profiler_args.base_data_dir, "results_dumps")
481 | if not os.path.exists(result_path_dir):
482 | os.mkdir(result_path_dir)
483 | print(run_string)
484 | with open(f"{result_path_dir}/{run_string}_results_by_train_size.pkl", "wb") as f:
485 | pickle.dump(results_by_train_size, f)
486 | print(f"Saved!")
487 |
488 | print(f"Total tokens prompted: {total_tokens_prompted}")
489 |
490 |
491 |
492 |
493 |
494 | def main():
495 |     # TODO: split the arguments into running_args and data_args, set by the user
496 | profiler_args = get_experiment_args()
497 |
498 | # model_dict = {
499 | # 'MODELS': ["text-davinci-003"],
500 | # 'EXTRACTION_MODELS': ["text-davinci-003"],
501 | # 'GOLD_KEY': "text-davinci-003",
502 | # }
503 | # Example of how to use a locally-hosted FM
504 |     # model_dict = {
505 |     #     'MODELS': ["EleutherAI/gpt-j-6B"],
506 |     #     'EXTRACTION_MODELS': ["EleutherAI/gpt-j-6B"],
507 |     #     'GOLD_KEY': "EleutherAI/gpt-j-6B",
508 |     #     'MODEL2URL': {
509 |     #         "EleutherAI/gpt-j-6B": "http://127.0.0.1:5000"
510 |     #     },
511 |     # }
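    #
    # To apply one of the example dicts above, the fields could (hypothetically)
    # be copied onto the parsed args before running, e.g.:
    #   for k, v in model_dict.items():
    #       setattr(profiler_args, k, v)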
512 |
513 | run_experiment(profiler_args)
514 |
515 |
516 | if __name__ == "__main__":
517 | main()
518 |
--------------------------------------------------------------------------------
/evaporate/evaluate_synthetic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import pickle
5 |
6 | import html
7 | from bs4 import BeautifulSoup
8 | from collections import Counter, defaultdict
9 | from evaporate.utils import get_file_attribute
10 | from evaporate.evaluate_synthetic_utils import text_f1
11 |
12 |
13 | # Compute recall from two sets
14 | def set_recall(pred, gt):
15 | return len(set(pred) & set(gt)) / len(set(gt))
16 |
17 |
18 | # Compute precision from two sets
19 | def set_precision(pred, gt):
20 | return len(set(pred) & set(gt)) / len(set(pred))
21 |
22 |
23 | # Compute F1 from precision and recall
24 | def compute_f1(precision, recall):
25 | if recall > 0. or precision > 0.:
26 | return 2. * (precision * recall) / (precision + recall)
27 | else:
28 | return 0.
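
# Worked example (illustrative values): with pred = {"a", "b"} and gt = {"b", "c"},
# set_recall and set_precision are both 0.5, and compute_f1(0.5, 0.5) == 0.5.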
29 |
30 |
31 | def evaluate_schema_identification(run_string, args, group_name, train_size=-1):
32 | with open(f"{args.generative_index_path}/{run_string}_identified_schema.json") as f:
33 | most_common_fields = json.load(f)
34 |
35 | try:
36 | with open(args.gold_extractions_file) as f:
37 | gold_file2extractions = json.load(f)
38 |     except (UnicodeDecodeError, json.JSONDecodeError):
39 | with open(args.gold_extractions_file, "rb") as f:
40 | gold_file2extractions = pickle.load(f)
41 |
42 | for file, dic in gold_file2extractions.items():
43 | gold_metadata = list(dic.keys())
44 | gold_metadata = [m for m in gold_metadata if m not in ['topic_entity_name']]
45 | break
46 |
47 | ctr = Counter(most_common_fields)
48 | results = {}
49 | for k in [len(gold_metadata), 1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100, len(most_common_fields)]:
50 | if not most_common_fields:
51 | results[k] = {
52 | "recall": 0,
53 | "precision": 0,
54 | "f1": 0,
55 | "num_gold_attributes": k,
56 | }
57 | continue
58 | gold_metadata = [item.lower() for item in gold_metadata]
59 | pred_metadata = ctr
60 |
61 | limit = k
62 | pred_metadata = sorted(pred_metadata.most_common(limit), key=lambda x: (x[1], x[0]), reverse=True)
63 | pred_metadata = [item[0].lower() for item in pred_metadata]
64 | cleaned_pred_metadata = set()
65 | for pred in pred_metadata:
66 | if not pred:
67 | continue
68 | cleaned_pred_metadata.add(pred)
69 | cleaned_gold_metadata = set()
70 | for gold in gold_metadata:
71 | cleaned_gold_metadata.add(gold)
72 |
73 | recall = [x for x in cleaned_gold_metadata if x in cleaned_pred_metadata]
74 | precision = [x for x in cleaned_pred_metadata if x in cleaned_gold_metadata]
75 | recall = len(recall) / len(cleaned_gold_metadata)
76 | precision = len(precision) / len(cleaned_pred_metadata)
77 | f1 = compute_f1(precision, recall)
78 |
79 | results[k] = {
80 | "recall": recall,
81 | "precision": precision,
82 | "f1": f1,
83 | "num_gold_attributes": limit,
84 | }
85 |
86 | print(f"@k = %d --- Recall: %.3f, Precision: %.3f, F1: %.3f" % (k, recall, precision, f1))
87 | print()
88 | return results
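
# The returned dict maps each cutoff k to {"recall", "precision", "f1",
# "num_gold_attributes"}, so callers can read e.g. results[10]["f1"].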
89 |
90 |
91 | def clean_comparison(extraction, attribute='', exact_match=False):
92 | # formatting transformations
93 |
94 | if type(extraction) == list:
95 | if extraction and type(extraction[0]) == list:
96 | full_answer = []
97 | for answer in extraction:
98 | if type(answer) == list:
99 | dedup_list = []
100 | for a in answer:
101 | if a not in dedup_list:
102 | dedup_list.append(a)
103 | answer = dedup_list
104 | answer = [str(a).strip().strip("\n") for a in answer]
105 | full_answer.append(", ".join(answer))
106 | else:
107 | full_answer.append(answer.strip().strip("\n"))
108 | full_answer = [a.strip() for a in full_answer]
109 | extraction = ", ".join(full_answer)
110 | else:
111 | dedup_list = []
112 | for a in extraction:
113 | if a not in dedup_list:
114 | dedup_list.append(a)
115 | extraction = dedup_list
116 | extraction = [(str(e)).strip().strip("\n") for e in extraction if e]
117 | extraction = ", ".join(extraction)
118 |     elif type(extraction) == str and (
119 | extraction.lower() == "none" or any(
120 | phrase in extraction.lower() for phrase in ["not reported", "none available", "n/a"]
121 | )
122 | ):
123 | extraction = ""
124 | if type(extraction) == float and math.isnan(extraction):
125 | extraction = ''
126 | if type(extraction) != str:
127 | extraction = str(extraction)
128 |
129 | if type(extraction) == str:
130 | if ("<" in extraction) and (">" in extraction):
131 |             extraction = BeautifulSoup(extraction, "html.parser").text
132 |
133 | extraction = extraction.strip().replace(" ", " ").lower()
134 | attribute_variations = [f"{attribute}(s)".lower(), f"{attribute.strip()}(s)".lower(), attribute.lower(), attribute]
135 | for a in attribute_variations:
136 | extraction = extraction.replace(a, "").strip()
137 | for char in ["'", '"', "(", ")", ",", "/", "]", "[", ":"]:
138 | extraction = extraction.replace(char, "").strip()
139 | extraction = html.unescape(extraction)
140 | for char in ['&', '&', "-", "_", "\n", "\t", "http:", "<", ">"]:
141 | extraction = extraction.replace(char, " ").strip()
142 | if exact_match:
143 | extraction = extraction.replace(" ", "")
144 | if extraction == " ":
145 | extraction = ""
146 | extraction = extraction.strip()
147 | return extraction
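
# Illustrative behavior of clean_comparison (hypothetical inputs):
#   clean_comparison(["USA", "USA", "Canada"])           -> "usa canada" (dedup, join, punctuation stripped)
#   clean_comparison("<b>Drama</b>", attribute="genre")  -> "drama" (HTML stripped)
#   clean_comparison("Not reported")                     -> ""  (treated as a null marker)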
148 |
149 |
150 | def evaluate_extraction_quality(run_string, args, gold_extractions_file, gold_extractions_file_dir, gold_attributes=None):
151 | all_attribute_f1 = 0
152 | all_attribute_total = 0
153 | total_runtime_overall_attributes = 0
154 | attribute2f1 = {}
155 | attribute2scripts = {}
156 |
157 | # load gold extractions file
158 | try:
159 | with open(gold_extractions_file) as f:
160 | gold_extractions_tmp = json.load(f)
161 |     except (UnicodeDecodeError, json.JSONDecodeError):
162 | with open(gold_extractions_file, "rb") as f:
163 | gold_extractions_tmp = pickle.load(f)
164 | gold_extractions = {}
165 | for file, extractions in gold_extractions_tmp.items():
166 | gold_extractions[os.path.join(gold_extractions_file_dir, file.split('/')[-1])] = extractions
167 |
168 | for attribute in gold_attributes:
169 | attribute = attribute.lower()
170 | # load predicted extractions
171 | fileattribute = get_file_attribute(attribute)
172 | if not os.path.exists(f"{args.generative_index_path}/{run_string}_{fileattribute}_file2metadata.json"):
173 | print(f"Missing file for {attribute}")
174 | continue
175 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_file2metadata.json") as f:
176 | file2metadata = json.load(f)
177 |
178 | try:
179 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_functions.json") as f:
180 | function_dictionary = json.load(f)
181 | with open(f"{args.generative_index_path}/{run_string}_{fileattribute}_top_k_keys.json") as f:
182 | selected_keys = json.load(f)
183 | attribute2scripts[attribute] = selected_keys
184 |         except (FileNotFoundError, json.JSONDecodeError):
185 | function_dictionary = {}
186 | selected_keys = []
187 | pass
188 | total_runtime = 0
189 | for key in selected_keys:
190 | if key in function_dictionary:
191 | runtime = function_dictionary[key]['runtime']
192 | total_runtime += runtime
193 |
194 | preds = []
195 | golds = []
196 | for file, gold_entry in gold_extractions.items():
197 | for attr, gold_value in gold_entry.items():
198 | attr = clean_comparison(attr)
199 | attribute = clean_comparison(attribute)
200 | if attr.lower() != attribute.lower():
201 | continue
202 | if file not in file2metadata:
203 | continue
204 | pred_value = file2metadata[file]
205 |
206 | value_check = ''
207 | pred_value_check = ''
208 |             if type(pred_value) == list and pred_value and type(pred_value[0]) == str:
209 |                 pred_value_check = sorted([p.strip() for p in pred_value])
210 |             elif type(pred_value) == str and "," in pred_value:
211 |                 pred_value_check = sorted([p.strip() for p in pred_value.split(",")])
212 |             if type(gold_value) == list and gold_value:
213 |                 value_check = gold_value[0]
214 |             elif type(gold_value) == str and "," in gold_value:
215 |                 value_check = sorted([p.strip() for p in gold_value.split(",")])
216 | if value_check and pred_value_check and value_check == pred_value_check:
217 | gold_value = pred_value
218 |
219 | # SWDE doesn't include the full passage in many cases (e.g. "IMDB synopsis")
220 | pred_value = clean_comparison(pred_value, attribute=attribute)
221 | gold_value = clean_comparison(gold_value, attribute=attribute)
222 | if pred_value.lower().strip(".").startswith(gold_value.lower().strip(".")):
223 | pred_value = " ".join(pred_value.split()[:len(gold_value.split())])
224 | preds.append(pred_value)
225 | golds.append(gold_value)
226 |
227 |         # an empty preds list falls through to the else-branch below,
228 |         # so no default F1 assignment is needed here
229 |         if golds and preds:
230 |             total_f1, total_f1_median = text_f1(preds, golds, attribute=attribute)
231 |         else:
232 |             print(f"Skipping eval of attribute: {attribute}")
233 |             continue
234 |
235 | if preds:
236 | all_attribute_f1 += (total_f1)
237 | all_attribute_total += 1
238 | attribute2f1[attribute] = total_f1
239 |
240 | total_runtime_overall_attributes += total_runtime
241 |
242 | num_function_scripts = 0
243 | for k, v in attribute2f1.items():
244 | scripts = []
245 | if k in attribute2scripts:
246 | scripts = attribute2scripts[k]
247 | if any(s for s in scripts if "function" in s):
248 | num_function_scripts += 1
249 | print(f"{k}, text-f1 = {v} --- {scripts}")
250 |
251 | try:
252 | overall_f1 = all_attribute_f1 / all_attribute_total
253 | print(f"\nOverall f1 across %d attributes: %.3f" % (all_attribute_total, overall_f1))
254 | print(f"Used functions for {num_function_scripts} out of {len(attribute2f1)} attributes")
255 | print(f"Average time: {total_runtime_overall_attributes/all_attribute_total} seconds, {all_attribute_total} fns.\n\n")
256 |
257 | results = {
258 | "f1": all_attribute_f1 / all_attribute_total,
259 | "total_attributes": all_attribute_total,
260 | "attribute2f1": attribute2f1,
261 | }
262 |     except ZeroDivisionError:  # no attributes were successfully evaluated
263 | results = {
264 | "f1": 0,
265 | "total_attributes": all_attribute_total,
266 | "attribute2f1": attribute2f1,
267 | }
268 |
269 | return results
270 |
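# The returned dict holds "f1" (mean text-F1 over the scored attributes),
# "total_attributes" (how many attributes were scored), and "attribute2f1"
# (the per-attribute scores printed above).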
271 |
272 | def determine_attribute_slices(gold_extractions, slice_results):
273 | num_occurences, num_characters = defaultdict(int), defaultdict(int)
274 | num_documents = len(gold_extractions)
275 | for file, extraction_dict in gold_extractions.items():
276 | for key, value in extraction_dict.items():
277 | if type(value) == str and value:
278 | num_occurences[key] += 1
279 | num_characters[key] += len(value)
280 |             elif type(value) == list and value and value[0]:
281 | num_occurences[key] += 1
282 | num_characters[key] += len(value[0])
283 |
284 | # calculate the average length of the attribute
285 | for attr, total_len in num_characters.items():
286 | num_characters[attr] = total_len / num_occurences[attr]
287 |
288 | # split into the "head", "tail", and "unstructured"
289 | attribute_slices = defaultdict(set)
290 | for attr, num_occur in num_occurences.items():
291 | attribute_slices["all"].add(attr)
292 |
293 | # skip the rest if not slicing results
294 | if not slice_results:
295 | continue
296 |
297 | num_char = num_characters[attr]
298 | if int(num_documents * 0.5) <= num_occur:
299 | attribute_slices["head"].add(attr)
300 | else:
301 | attribute_slices["tail"].add(attr)
302 |
303 | if num_char >= 20:
304 | attribute_slices["unstructured"].add(attr)
305 | else:
306 | attribute_slices["structured"].add(attr)
307 |
308 | return attribute_slices
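
# Illustrative slicing (hypothetical lake of 10 documents): an attribute found
# in at least 5 documents lands in "head", otherwise "tail"; attributes whose
# average gold value is >= 20 characters are "unstructured", else "structured".
# Every attribute also appears in "all", the only slice used when slice_results
# is False.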
309 |
310 |
311 | def evaluate_openie_quality(
312 | run_string,
313 | args,
314 | gold_extractions_file,
315 | sample_files=None,
316 | slice_results=False,
317 | mappings_names={}
318 | ):
319 | # load pred extractions file
320 | with open(f"{args.generative_index_path}/{run_string}_file2extractions.json") as f:
321 | pred_extractions = json.load(f)
322 |
323 | # alternate gold attribute naming
324 | if args.set_dicts:
325 | with open(args.set_dicts) as f:
326 | set_dicts = json.load(f)
327 | else:
328 | set_dicts = {}
329 |
330 | # load gold extractions file
331 | try:
332 | with open(gold_extractions_file) as f:
333 | gold_extractions = json.load(f)
334 |     except (UnicodeDecodeError, json.JSONDecodeError):
335 | with open(gold_extractions_file, "rb") as f:
336 | gold_extractions = pickle.load(f)
337 |
338 | pred_attributes = set()
339 | for file, extraction_dict in pred_extractions.items():
340 | for key, value in extraction_dict.items():
341 | pred_attributes.add(key)
342 |
343 | # split the attribute into slices -> "head", "tail", and "unstructured"
344 | attribute_slices = determine_attribute_slices(gold_extractions, slice_results)
345 |
346 | results = {}
347 | for attribute_slice, gold_attributes in attribute_slices.items():
348 |
349 | # lenient attribute scoring method: https://arxiv.org/pdf/2201.10608.pdf
350 | gold_attribute_mapping = {}
351 | for gold_attribute in gold_attributes:
352 | if gold_attribute in pred_attributes or not set_dicts:
353 | gold_attribute_mapping[gold_attribute] = gold_attribute
354 | continue
355 | if gold_attribute in set_dicts:
356 | alternate_golds = set_dicts[gold_attribute]
357 | else:
358 | alternate_golds = [gold_attribute]
359 | found = 0
360 | for alternate_gold in alternate_golds:
361 | if alternate_gold in pred_attributes:
362 | gold_attribute_mapping[gold_attribute] = alternate_gold
363 | found = 1
364 | if not found:
365 |             if gold_attribute.removesuffix('s') in pred_attributes:
366 |                 gold_attribute_mapping[gold_attribute] = gold_attribute.removesuffix('s')
367 |             elif gold_attribute+"s" in pred_attributes:
368 |                 gold_attribute_mapping[gold_attribute] = gold_attribute+"s"
369 |             elif gold_attribute.removesuffix('(s)') in pred_attributes:
370 |                 gold_attribute_mapping[gold_attribute] = gold_attribute.removesuffix('(s)')
371 |             elif gold_attribute+"(s)" in pred_attributes:
372 |                 gold_attribute_mapping[gold_attribute] = gold_attribute+"(s)"
373 | elif gold_attribute.replace(" ", "") in pred_attributes:
374 | gold_attribute_mapping[gold_attribute] = gold_attribute.replace(" ", "")
375 | elif any(pred_attribute.replace(" ", "") in gold_attributes for pred_attribute in pred_attributes):
376 | for pred_attribute in pred_attributes:
377 | if pred_attribute.replace(" ", "") in gold_attributes:
378 | gold_attribute_mapping[gold_attribute] = pred_attribute
379 | elif gold_attribute in mappings_names and mappings_names[gold_attribute] in pred_attributes:
380 | gold_attribute_mapping[gold_attribute] = mappings_names[gold_attribute]
381 | else:
382 | gold_attribute_mapping[gold_attribute] = gold_attribute
383 |
384 | pred_set = set()
385 | skipped = set()
386 | all_measurements = defaultdict(dict)
387 | for file, extraction_dict in pred_extractions.items():
388 | if sample_files and file not in sample_files:
389 | continue
390 | for key, value in extraction_dict.items():
391 | if key not in attribute_slices["all"]:
392 | if key.replace(" ", "") in attribute_slices["all"]:
393 | key = key.replace(" ", "")
394 |
395 | # skip predicted attributes that are in a different slice
396 | if key in attribute_slices["all"] and key not in gold_attributes:
397 | skipped.add(key)
398 | continue
399 |
400 | clean_key = clean_comparison(key, exact_match=True)
401 | clean_value = clean_comparison(value, attribute=key, exact_match=True)
402 | if clean_value:
403 | pred_set.add((file, clean_key, clean_value))
404 | if file not in all_measurements[clean_key]:
405 | all_measurements[clean_key][file] = {
406 | "pred": "",
407 | "gold": "",
408 | }
409 | all_measurements[clean_key][file]['pred'] = clean_value
410 |
411 | clean_pred_attributes = set([x[1] for x in pred_set])
412 |         # re-resolve the gold -> pred mapping against the cleaned predicted attribute names (supersedes the mapping built above)
413 |         gold_attribute_mapping = {}
414 | for gold_attribute in gold_attributes:
415 | if gold_attribute in clean_pred_attributes:
416 | gold_attribute_mapping[gold_attribute] = gold_attribute
417 | continue
418 | found = False
419 | if set_dicts and gold_attribute in set_dicts:
420 | alternate_golds = set_dicts[gold_attribute]
421 | for alternate_gold in alternate_golds:
422 | if alternate_gold in clean_pred_attributes:
423 | gold_attribute_mapping[gold_attribute] = alternate_gold
424 | found = True
425 | if not found:
426 |                 if gold_attribute.removesuffix('s') in clean_pred_attributes:
427 |                     gold_attribute_mapping[gold_attribute] = gold_attribute.removesuffix('s')
428 | elif gold_attribute+"s" in clean_pred_attributes:
429 | gold_attribute_mapping[gold_attribute] = gold_attribute+"s"
430 | else:
431 | gold_attribute_mapping[gold_attribute] = gold_attribute
432 |
433 | num_attributes = len(clean_pred_attributes)
434 |
435 | gold_set = set()
436 | for file, extraction_dict in gold_extractions.items():
437 | if sample_files and file not in sample_files:
438 | continue
439 | for key, value in extraction_dict.items():
440 | # ignore attributes in a different slice
441 | if key not in gold_attributes:
442 | continue
443 |
444 | if key == "topic_entity_name":
445 | if "name" in pred_attributes:
446 | gold_attribute_mapping[key] = "name"
447 |
448 | key = gold_attribute_mapping[key]
449 | if key not in pred_attributes:
450 | if key.replace(" ", "") in pred_attributes:
451 | key = key.replace(" ", "")
452 |
453 | # sort list-based attribute values for consistency.
454 | if file in pred_extractions and key in pred_extractions[file]:
455 | pred_value = pred_extractions[file][key]
456 | value_check = ''
457 | pred_value_check = ''
458 |                 if type(pred_value) == list and pred_value and type(pred_value[0]) == str:
459 |                     pred_value_check = sorted([p.strip() for p in pred_value])
460 |                 elif type(pred_value) == str and "," in pred_value:
461 |                     pred_value_check = sorted([p.strip() for p in pred_value.split(",")])
462 |                 if type(value) == list and value:
463 |                     value_check = value[0]
464 |                 elif type(value) == str and "," in value:
465 |                     value_check = sorted([p.strip() for p in value.split(",")])
466 | if value_check and pred_value_check and value_check == pred_value_check:
467 | value = pred_value
468 |
469 | clean_key = clean_comparison(key, exact_match=True)
470 | clean_value = clean_comparison(value, attribute=key, exact_match=True)
471 |
472 | if clean_value:
473 | gold_set.add((file, clean_key, clean_value))
474 | if file not in all_measurements[clean_key]:
475 | all_measurements[clean_key][file] = {
476 | "pred": "",
477 | "gold": "",
478 | }
479 | all_measurements[clean_key][file]['gold'] = clean_value
480 |
481 | if not pred_set or not gold_set:
482 | results[attribute_slice] = {
483 | "precision": 0,
484 | "recall": 0,
485 | "f1": 0,
486 | "num_files_evaluated": len(pred_extractions),
487 | }
488 | else:
489 | # exact match over all fields
490 | precision = set_precision(pred_set, gold_set)
491 | recall = set_recall(pred_set, gold_set)
492 | f1 = compute_f1(precision, recall)
493 |
494 | results[attribute_slice] = {
495 | "precision": precision,
496 | "recall": recall,
497 | "f1": f1,
498 | "num_files_evaluated": len(pred_extractions),
499 | }
500 | print(f"[%s] OpenIE Precision (%d attributes): Precision: %.3f Recall: %.3f F1: %.3f" % (attribute_slice, num_attributes, precision, recall, f1))
501 | return results if slice_results else results["all"]
502 |
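# With slice_results=True the caller gets one metrics dict per slice, e.g.
# {"all": {...}, "head": {...}, "tail": {...}, "structured": {...}, "unstructured": {...}};
# otherwise only the "all" metrics are returned.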
503 |
504 | def main(
505 | run_string,
506 | args,
507 | profiler_args,
508 | data_lake = "wiki_nba_players",
509 | sample_files=None,
510 | stage='',
511 | gold_attributes=[],
512 | mappings_names={}
513 | ):
514 | gold_extractions_file = args.gold_extractions_file
515 | train_size = profiler_args.train_size
516 |
517 | overall_results = {}
518 |
519 | if stage and stage != 'schema_id':
520 | pass
521 | else:
522 | schema_id_results = evaluate_schema_identification(
523 | run_string,
524 | args,
525 | data_lake,
526 | train_size=train_size,
527 | )
528 | overall_results["schema_id"] = schema_id_results
529 |
530 | if stage and stage != 'extract':
531 | pass
532 | else:
533 | extraction_results = evaluate_extraction_quality(
534 | run_string,
535 | args,
536 | gold_extractions_file,
537 | profiler_args.data_dir,
538 | gold_attributes=gold_attributes
539 | )
540 | overall_results["extraction"] = extraction_results
541 |
542 | if stage and stage != 'openie':
543 | pass
544 | else:
545 | openie_results = evaluate_openie_quality(
546 | run_string,
547 | args,
548 | gold_extractions_file,
549 | sample_files=sample_files,
550 |             slice_results=profiler_args.slice_results,
551 |             mappings_names=mappings_names
552 | )
553 | overall_results["openie"] = openie_results
554 |
555 | return overall_results
556 |
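# Example invocation from the profiler (mirrors the evaluate_synthetic_main(...)
# calls in the experiment driver above):
#   results = main(run_string, args, profiler_args, data_lake, gold_attributes=gold_attributes, stage='extract')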
557 |
558 | if __name__ == "__main__":
559 |     main()  # note: main() requires run_string, args, and profiler_args; invoke via run_profiler.py
--------------------------------------------------------------------------------