├── .gitignore
├── README.md
├── assets
├── bias_results.png
├── delphi_hybrid.png
├── examples.png
├── main_results.png
├── norm_bank_content.png
└── overall.png
├── data
└── datasheet.md
└── src
├── delphi
├── evaluate
│ ├── evaluate_dynahate_v9.py
│ ├── evaluate_dynahate_yesno.py
│ ├── evaluate_ethics_converted.py
│ ├── evaluate_ethics_v9.py
│ ├── evaluate_joint_st_wild_v9.py
│ ├── evaluate_latenthatred_v9.py
│ ├── evaluate_latenthatred_yesno.py
│ ├── evaluate_utils.py
│ └── select_check_point_wild_v9.py
└── train
│ ├── evaluate.py
│ ├── fine-tune.py
│ ├── mixtures.py
│ ├── predict.py
│ ├── rates.py
│ ├── tasks.py
│ └── util.py
├── delphi_hybrid
├── collective_reasoning
│ ├── ITW
│ │ ├── BaseRuler.py
│ │ ├── EnsembleRuler.py
│ │ └── Ruler.py
│ ├── ITW_NormBank
│ │ ├── ITWBaseRuler.py
│ │ ├── NormBankBaseRuler.py
│ │ └── Ruler.py
│ └── NormBank
│ │ ├── BaseRuler.py
│ │ ├── EnsembleRuler.py
│ │ └── Ruler.py
├── components
│ ├── COMETGenerator.py
│ ├── CacheHandler
│ │ ├── COMETCacheHandler.py
│ │ ├── CacheHandler.py
│ │ ├── DelphiCacheHandler.py
│ │ └── ParaphraseCacheHandler.py
│ ├── CompositionalityParser.py
│ ├── DelphiScorer.py
│ ├── GPT3Scorer.py
│ ├── LMScorer.py
│ ├── MoralSaliencyKeywordCounter.py
│ ├── MoralSaliencyKeywordIdentifier.py
│ ├── Paraphraser.py
│ ├── WANLIScorer.py
│ ├── bank.py
│ ├── constants.py
│ ├── main_utils.py
│ └── utils.py
└── prepare_data
│ ├── compile_data_v6.py
│ ├── compile_demo_data.py
│ ├── compile_gold_labels.py
│ ├── filter_paraphrases.py
│ └── prepare_data.py
├── delphi_plus
├── evalute
│ ├── evaluate_declare_only.py
│ ├── evaluate_distribution.py
│ ├── evaluate_utils.py
│ ├── select_declare_only.py
│ └── select_distribution.py
└── train
│ ├── evaluate.py
│ ├── fine-tune.py
│ ├── mixtures.py
│ ├── predict.py
│ ├── rates.py
│ ├── tasks.py
│ └── util.py
└── utils
└── text2class.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | data/*
10 | results/
11 | old/
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
114 | .pdm.toml
115 | .pdm-python
116 | .pdm-build/
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
167 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Empirical Investigation of Machines' Capabilities for Moral Judgment with the Delphi Experiment
2 |
3 |
4 |
5 |
6 |
7 |
8 | **Authors:**
9 | [Liwei Jiang](https://liweijiang.me),
10 | [Jena D. Hwang](https://jenahwang.github.io),
11 | [Chandra Bhagavatula](https://www.chandrab.page),
12 | [Ronan Le Bras](https://rlebras.github.io),
13 | [Jenny Liang](https://jennyliang.me),
14 | [Sydney Levine](https://sites.google.com/site/sydneymlevine/),
15 | [Jesse Dodge](https://jessedodge.github.io),
16 | [Keisuke Sakaguchi](https://keisuke-sakaguchi.github.io),
17 | [Maxwell Forbes](https://maxwellforbes.com),
18 | [Jack Hessel](https://jmhessel.com),
19 | [Jon Borchardt](https://www.linkedin.com/in/borchardt/),
20 | [Taylor Sorensen](https://tsor13.github.io),
21 | [Saadia Gabriel](https://saadiagabriel.com),
22 | [Yulia Tsvetkov](https://homes.cs.washington.edu/~yuliats/),
23 | [Oren Etzioni](https://homes.cs.washington.edu/~etzioni/site/),
24 | [Maarten Sap](http://maartensap.com),
25 | [Regina Rini](https://reginarini.net),
26 | [Yejin Choi](https://homes.cs.washington.edu/~yejin/),
27 |
28 |
29 | > As our society adopts increasingly powerful AI systems for pervasive use, there are growing concerns about machine morality---or lack thereof. Millions of users already rely upon the outputs of AI systems, such as chatbots, as decision aids. Meanwhile, AI researchers continue to grapple with the challenge of aligning these systems with human morality and values. In response to this challenge, we build and test Delphi, an open-source AI system trained to predict human moral judgments. The computational framework of Delphi is grounded in the philosophical moral framework proposed by the prominent moral philosopher John Rawls. Our results speak to the promises and limits of machine's capabilities to learn about human morality. On the one hand, Delphi demonstrates improved generalization capabilities over those exhibited by off-the-shelf neural language models. At the same time, Delphi's failures also underscore important challenges in this arena. For instance, Delphi has limited cultural awareness and is susceptible to pervasive biases. Despite these shortcomings, we demonstrate several compelling use cases of Delphi, including incorporating it as a component within an ensemble of AI systems. Finally, we computationally demonstrate the potential of Rawls' prospect of hybrid approaches for reliable moral reasoning, inspiring future research in computational morality.
30 |
31 | ## The Theoretical and Computational Frameworks of Delphi
32 |
33 |
39 |
40 |
41 |
42 | > (a) The theoretical framework of ethics proposed by the prominent moral philosopher John Rawls. In 1951, Rawls proposed a “decision procedure of ethics” that takes a bottom-up approach to capture patterns of human ethics via crowd- sourcing moral opinions of a wide variety of people. Later in 1971, Rawls complemented the theoretial procedure with top-down constraints in his most famous work, A Theory of Justice. Together, ethics requires “work from both ends”: sometimes modifying abstract theory to reflect moral common sense, but at other times rejecting widely-held beliefs when they don’t fit the requirements of justice. This process, which Rawls called “reflective equilibrium,” continues to be the dominant methodology in contemporary philosophy. (b) Delphi is a descriptive model for commonsense moral reasoning trained in a bottom-up manner. Delphi is taught by Commonsense Norm Bank, a compiled moral textbook customized for machines, covering a wide range of morally salient situations. Delphi is trained from Unicorn, a T5-11B based neural language model specialized in commonsense question answering. Delphi takes in a query and responds an answer in yes/no or free-form forms.
43 |
44 |
45 | Overview of Commonsense Norm Bank Content
46 |
47 | > Representative N-grams cover topics including people, relationships, actions, life & society, cognition, and others. The lemmatized and normalized 4-grams used for the topic analysis are bolded. Auxiliary words from the original form of data instances that are not used in the topics analysis are unbolded.
48 |
49 |
50 |
51 |
52 |
53 |
54 | Examples from Delphi
55 |
56 | > Delphi shows impressive ability to generalize to unseen situations beyond Commonsense Norm Bank, and is robust to adjust its judgment against changing contexts. Colors of labels indicate Delphi’s classification results (green: positive, gray: neutral, red: negative). Textual labels come from Delphi’s open-text responses.
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 | Main Results of Delphi
65 |
66 | > (a) Delphi achieves better performance on Norm Bank comparing to GPT-3 baselines. (b) Comparing the effect of the size of the base T5 model. (c) Ablation results showing the scale of training data improves Delphi’s learning. (d) Ablation results showing the compositionality of training instances improves Delphi’s learning. (e) Delphi, with minimal supervisions, outperforms baseline models on hate speech detection under both in-distribution and out-of-distribution settings. (g) Plugging Delphi into language generation models helps improve the prosocial implication scores of the generated stories, without sacrificing the language quality. (g) Delphi outperforms other baselines on transferring knowledge to specific theoretically motivated moral frameworks.
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | Social Bias Evaluation Results of Delphi
75 |
76 | > (a) Results for the Universal Declaration of Human Rights probing, including top identities that Delphi shows biases against and their level of biases, and the average % error for each identity group. (b) Delphi and Delphi+’s performance under current-world and ideal-world settings. Statistical significance test is performed between Delphi under the current-world compared to other models or settings.
77 |
78 |
79 |
80 |
81 |
82 |
83 | ## An Illustration of the Delphi-Hybrid Framework and an Example Output Moral Constraint Graph
84 |
85 |
86 |
92 |
93 |
94 |
95 |
96 | > (a) A hybrid system that incorporates an optional symbolically guided rea- soning mechanism to complement the neural language model based Delphi. (b) An example of the moral constraint graph produced by Delphihybrid for the event “Mass genocide for greater good.” Nodes denote judgments derived either from top-down moral principles or bottom-up Delphi. Edges denote logical violations (i.e., identity, entailment, and contradiction) between nodes. ❌ denotes inconsistent nodes identified by the constrained optimization step. Note that each top-down moral principle may result in multiple nodes depending on engineering details (e.g., the same rule “Do not kill” applied at the full event level or constituent level). The final judgment is negative.
97 |
98 |
99 | ## Codebase Structure
100 |
101 | This codebase contains the training and evaluation code for Delphi, Delphi+, and the Delphi-Hybrid system.
102 |
103 | ### Delphi:
104 |
105 | - `src/delphi/evaluate`: scripts for evaluating the Delphi models on the yes/no QA and freeform QA tasks, as well as downstream tasks like Hate Speech Detection.
106 |
107 | - `src/delphi/train`: the scripts for finetuning T5 for Delphi.
108 |
109 | ### Delphi+:
110 |
111 | - `src/delphi_plus/evaluate`: scripts for evaluating the Delphi+ models on the yes/no QA and freeform QA tasks, as well as downstream tasks like Hate Speech Detection.
112 |
113 | - `src/delphi/train`: the scripts for finetuning T5 for Delphi+.
114 |
115 | ### Delphi-Hybrid:
116 |
117 | - `src/delphi_plus/collective_reasoning`: codebase for the collective reasoning component of Delphi-Hybrid.
118 |
119 | - `src/delphi/components`: components of the Delphi-Hybrid system.
120 |
121 | - `src/delphi/prepare_data`: scripts for preparing test data for Delphi-Hybrid experiments.
122 |
123 | ### Datasheet:
124 |
125 | - `data/datasheet.md`: the datasheet for the Commonsense Norm Bank dataset.
126 |
127 | ## Data and Model Access
128 | You can access the Commonsense Norm Bank dataset by [filling out this form](https://forms.gle/VoAVuPUJFNChWhSj8).
129 |
130 | For accessing the Delphi model checkpoints and API calls please feel free to reach out to Liwei Jiang at [lwjiang@cs.washington.edu](lwjiang@cs.washington.edu).
131 |
132 | If you find our paper or data useful, please cite the paper:
133 | ```
134 | @article{jiang2025delphi,
135 | author = {Liwei Jiang and Jena D. Hwang and Chandra Bhagavatula and Ronan Le Bras and Jenny T. Liang and Sydney Levine and Jesse Dodge and Keisuke Sakaguchi and Maxwell Forbes and Jack Hessel and Jon Borchardt and Taylor Sorensen and Saadia Gabriel and Yulia Tsvetkov and Oren Etzioni and Maarten Sap and Regina Rini and Yejin Choi},
136 | title = {Investigating machine moral judgement through the Delphi experiment},
137 | journal = {Nature Machine Intelligence},
138 | volume = {7},
139 | pages = {145--160},
140 | year = {2025},
141 | doi = {10.1038/s42256-024-00700-y}
142 | }
143 |
144 | ```
145 |
146 |
--------------------------------------------------------------------------------
/assets/bias_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/bias_results.png
--------------------------------------------------------------------------------
/assets/delphi_hybrid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/delphi_hybrid.png
--------------------------------------------------------------------------------
/assets/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/examples.png
--------------------------------------------------------------------------------
/assets/main_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/main_results.png
--------------------------------------------------------------------------------
/assets/norm_bank_content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/norm_bank_content.png
--------------------------------------------------------------------------------
/assets/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/overall.png
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_dynahate_v9.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import f1_score
2 | import sys
3 | import numpy as np
4 | sys.path.append("script/evaluate")
5 |
6 | from evaluate_utils import *
7 |
8 | def get_gold_class(round_id, data_split):
9 | """
10 | Get gold inputs and targets class labels; format (1 )
11 | """
12 | if data_split == "validation":
13 | data_split = "dev"
14 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/dynahate/{round_id}/{data_split}.tsv"
15 | df_inputs = pd.read_csv(data_base_path, sep="\t")
16 |
17 | inputs_all = list(df_inputs["inputs"])
18 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
19 |
20 | targets_all = list(df_inputs["targets"])
21 | targets = [int(i.split(" ")
22 | [0].split("")[-1]) for i in targets_all]
23 | return inputs, targets
24 |
25 |
26 | def get_pred_class(bucket, base_path, task_name, check_point, base_model):
27 | """
28 | Get preds class labels
29 | """
30 | preds_blob = bucket.get_blob(
31 | base_path + f"{task_name}_{check_point}_predictions")
32 | preds_blob_list = preds_blob.download_as_string().decode(
33 | 'utf-8').split("\n")[1:]
34 |
35 | preds_class = []
36 | for i in preds_blob_list:
37 | try:
38 | if "v10" in base_model or "v11" in base_model:
39 | preds_class.append(
40 | int(i.split("[/class]")[0].split("[class]")[-1]))
41 | else:
42 | preds_class.append(
43 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1]))
44 | except:
45 | print("output form not identifiable:", i)
46 | preds_class.append(99)
47 | return preds_class
48 |
49 |
50 | def main_get_accuracy(base_path, data_split, training_data_type, round_ids=None, check_points=None,
51 | is_print=False, all_list=None, is_save_results=True):
52 | base_path += f"{data_split}_eval/"
53 | bucket_name = base_path.split("/")[0]
54 | result_prefix = "/".join(base_path.split("/")[1:])
55 | base_path = "/".join(base_path.split("/")[1:])
56 | base_model = base_path.split("/")[-5]
57 | training_data = base_path.split("/")[-4]
58 |
59 | client = storage.Client()
60 | bucket = client.get_bucket(bucket_name)
61 |
62 | if base_model == "v10-delphi":
63 | base_path = base_path.replace(training_data, "new_" + training_data)
64 |
65 | if check_points == None:
66 | check_points = get_check_points(
67 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
68 | # check_points.sort(reverse=True)
69 | for check_point in check_points:
70 | all_inputs = []
71 | all_targets = []
72 | all_preds = []
73 | all_f1s = []
74 | all_accs = []
75 | all_round_ids = []
76 | for round_id in round_ids:
77 | if training_data_type == "":
78 | task_name = f"dynahate_round_{round_id}"
79 | else:
80 | task_name = f"dynahate_round_{round_id}_{training_data_type}"
81 |
82 | if base_model == "v10-delphi":
83 | task_name = "new_" + task_name
84 |
85 | inputs, targets = get_gold_class(round_id, data_split)
86 | preds = get_pred_class(
87 | bucket, base_path, task_name, check_point, base_model)
88 | if base_model == "v10-delphi":
89 | preds_new = [t - 1 for t in preds]
90 | preds = preds_new
91 |
92 | all_round_ids += [round_id] * len(preds)
93 | all_inputs += inputs
94 | all_targets += targets
95 | all_preds += preds
96 |
97 | f1 = f1_score(targets, preds, average='macro')
98 | acc = get_accuracy(targets, preds, accuracy_type="exact")
99 |
100 | all_f1s.append(f1)
101 | all_accs.append(acc)
102 | if is_print:
103 | print(
104 | f"round ({round_id}) {check_point}: f1 -- {f1} | accuracy -- {acc}")
105 |
106 | f1 = f1_score(all_targets, all_preds, average='macro')
107 | acc = get_accuracy(all_targets, all_preds, accuracy_type="exact")
108 | print(f"round (all) {check_point}: f1 -- {f1} | accuracy -- {acc}")
109 |
110 | if is_save_results:
111 | df_data = pd.DataFrame()
112 | df_data["round_id"] = all_round_ids
113 | df_data["input"] = all_inputs
114 | df_data["target"] = all_targets
115 | df_data["pred"] = all_preds
116 |
117 | df_data.to_csv(
118 | f"{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t")
119 |
120 | return all_list
121 |
122 |
123 | def main_r1_st(training_data):
124 | base_model_2_ckpt = {
125 | "dynahate_round_1_st": {"v11-delphi-declare": [1290200],
126 | "v10-delphi": [1323500],
127 | "v9-delphi": [1341200],
128 | "v9-delphi-new": [1264700],
129 | "unicorn-pt": [1127000],
130 | "11B": [1102000]},
131 |
132 | "dynahate_round_1_st_100_shot": {"v11-delphi-declare": [1290200],
133 | "v10-delphi": [1349000],
134 | "v9-delphi": [1325900],
135 | "v9-delphi-new": [1285100],
136 | "unicorn-pt": [1106600],
137 | "11B": [1076500]},
138 | }
139 |
140 | print(training_data)
141 | # "v11-delphi-declare", "v10-delphi" "v10-delphi", "v9-delphi", "unicorn-pt", "11B"
142 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]:
143 | print("=" * 10, base_model)
144 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
145 |
146 | check_points = base_model_2_ckpt[training_data][base_model]
147 | main_get_accuracy(base_path, "validation", "st", [1], check_points)
148 | main_get_accuracy(base_path, "test", "st", [
149 | 1, 2, 3, 4], check_points, is_print=True)
150 | #
151 |
152 |
153 | def main_all_st(training_data):
154 | base_model_2_ckpt = {
155 | "dynahate_all_st": {"v11-delphi-declare": [1290200], # 1290200
156 | "v10-delphi": [1292900],
157 | "v9-delphi": [1402400],
158 | "v9-delphi-new": [1387100],
159 | "unicorn-pt": [1157600],
160 | "11B": [1132600]},
161 |
162 | "dynahate_all_st_100_shot": {"v11-delphi-declare": [1331000],
163 | "v10-delphi": [1354100],
164 | "v9-delphi": [1392200],
165 | "v9-delphi-new": [1407500],
166 | "unicorn-pt": [1147400],
167 | "11B": [1076500]},
168 | }
169 |
170 | all_list = []
171 |
172 | print(training_data)
173 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]:
174 | print("=" * 10, base_model)
175 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
176 | # check_points = None
177 | check_points = base_model_2_ckpt[training_data][base_model]
178 | main_get_accuracy(base_path, "validation", "st", [
179 | 1, 2, 3, 4], check_points, is_print=False)
180 | all_list = main_get_accuracy(base_path, "test", "st", [
181 | 1, 2, 3, 4], check_points, is_print=True, all_list=all_list)
182 |
183 | return all_list
184 |
185 |
186 | if __name__ == "__main__":
187 | main_r1_st("dynahate_round_1_st_100_shot")
188 | main_all_st("dynahate_all_st_100_shot")
189 |
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_dynahate_yesno.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import f1_score
2 | import sys
3 | import numpy as np
4 | sys.path.append("script/evaluate")
5 | from evaluate_utils import *
6 |
7 |
8 | def get_gold_class(round_id, data_split, training_data_type):
9 | """
10 | Get gold inputs and targets class labels; format (1 )
11 | """
12 | if data_split == "validation":
13 | data_split = "dev"
14 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/dynahate/{round_id}_{training_data_type}/{data_split}.tsv"
15 | df_inputs = pd.read_csv(data_base_path, sep="\t")
16 |
17 | inputs_all = list(df_inputs["inputs"])
18 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
19 |
20 | targets_all = list(df_inputs["targets"])
21 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1])
22 | for i in targets_all]
23 | return inputs, targets
24 |
25 |
26 | def get_pred_class(bucket, base_path, task_name, check_point, base_model):
27 | """
28 | Get preds class labels
29 | """
30 | preds_blob = bucket.get_blob(
31 | base_path + f"{task_name}_{check_point}_predictions")
32 | preds_blob_list = preds_blob.download_as_string().decode(
33 | 'utf-8').split("\n")[1:]
34 |
35 | preds_class = []
36 | for i in preds_blob_list:
37 | try:
38 | preds_class.append(
39 | int(i.split("[/class]")[0].split("[class]")[-1]))
40 | except:
41 | print("output form not identifiable:", i)
42 | preds_class.append(99)
43 | return preds_class
44 |
45 |
46 | def main_get_accuracy(base_path, data_split, training_data_type, round_ids=None, check_points=None, is_print=False):
47 | base_path += f"{data_split}_eval/"
48 | bucket_name = base_path.split("/")[0]
49 | result_prefix = "/".join(base_path.split("/")[1:])
50 | base_path = "/".join(base_path.split("/")[1:])
51 | base_model = base_path.split("/")[-5]
52 | training_data = base_path.split("/")[-4]
53 |
54 | client = storage.Client()
55 | bucket = client.get_bucket(bucket_name)
56 |
57 | if check_points == None:
58 | check_points = get_check_points(
59 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
60 |
61 | for check_point in check_points: # check_points
62 | all_inputs = []
63 | all_targets = []
64 | all_preds = []
65 | all_f1s = []
66 | all_accs = []
67 | for round_id in round_ids:
68 | task_name = f"dynahate_round_{round_id}_{training_data_type}"
69 |
70 | inputs, targets = get_gold_class(
71 | round_id, data_split, training_data_type)
72 | preds = get_pred_class(
73 | bucket, base_path, task_name, check_point, base_model)
74 | targets, preds = remove_unknown_elements(targets, preds)
75 | f1 = f1_score(targets, preds, average='macro')
76 | acc = get_accuracy(targets, preds, accuracy_type="exact")
77 |
78 | all_inputs += inputs
79 | all_targets += targets
80 | all_preds += preds
81 | all_f1s.append(f1)
82 | all_accs.append(acc)
83 | if is_print:
84 | print(
85 | f"round ({round_id}) {check_point}: f1 -- {f1} | accuracy -- {acc}")
86 |
87 | f1 = f1_score(all_targets, all_preds, average='macro')
88 | acc = get_accuracy(all_targets, all_preds, accuracy_type="exact")
89 | print(
90 | f"round {str(round_ids)} {check_point}: f1 -- {f1} | accuracy -- {acc}")
91 |
92 |
93 | def main_r1_yesno(training_data, training_data_type):
94 | base_model_2_ckpt = {
95 | "dynahate_round_1_yesno": {"v11-delphi-declare": [1280000],
96 | "v10-delphi": [],
97 | "v9-delphi": [],
98 | "unicorn-pt": [],
99 | "11B": []},
100 |
101 | "dynahate_round_1_yesno_100_shot": {"v11-delphi-declare": [1290200],
102 | "v10-delphi": [],
103 | "v9-delphi": [],
104 | "unicorn-pt": [],
105 | "11B": []},
106 |
107 | "dynahate_round_1_yesno_class_only": {"v11-delphi-declare": [1290200],
108 | "v10-delphi": [],
109 | "v9-delphi": [],
110 | "unicorn-pt": [1132100],
111 | "11B": []},
112 |
113 | "dynahate_round_1_yesno_class_only_100_shot": {"v11-delphi-declare": [1234100],
114 | "v10-delphi": [],
115 | "v9-delphi": [],
116 | "unicorn-pt": [1035200],
117 | "11B": []},
118 |
119 | "dynahate_round_1_discriminate": {"v11-delphi-declare": [1331000],
120 | "v10-delphi": [],
121 | "v9-delphi": [],
122 | "unicorn-pt": [],
123 | "11B": []},
124 |
125 | "dynahate_round_1_discriminate_100_shot": {"v11-delphi-declare": [1269800],
126 | "v10-delphi": [],
127 | "v9-delphi": [],
128 | "unicorn-pt": [1086200],
129 | "11B": []},
130 | }
131 |
132 | print(training_data)
133 | # "v11-delphi-declare", "v10-delphi" "v10-delphi", "v9-delphi", "unicorn-pt", "11B"
134 | for base_model in ["v11-delphi-declare", "unicorn-pt"]:
135 | print("=" * 10, base_model)
136 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
137 | check_points = None
138 | main_get_accuracy(base_path, "validation", training_data_type, round_ids=[
139 | 1], check_points=check_points, is_print=False)
140 |
141 |
142 | def main_all_yesno(training_data, training_data_type):
143 | base_model_2_ckpt = {
144 | "dynahate_all_yesno": {"v11-delphi-declare": [],
145 | "v10-delphi": [],
146 | "v9-delphi": [],
147 | "unicorn-pt": [],
148 | "11B": []},
149 |
150 | "dynahate_all_yesno_100_shot": {"v11-delphi-declare": [1254500],
151 | "v10-delphi": [],
152 | "v9-delphi": [],
153 | "unicorn-pt": [],
154 | "11B": []},
155 |
156 | "dynahate_all_yesno_class_only": {"v11-delphi-declare": [],
157 | "v10-delphi": [],
158 | "v9-delphi": [],
159 | "unicorn-pt": [],
160 | "11B": []},
161 |
162 | "dynahate_all_yesno_class_only_100_shot": {"v11-delphi-declare": [1336100],
163 | "v10-delphi": [],
164 | "v9-delphi": [],
165 | "unicorn-pt": [1101500],
166 | "11B": []},
167 |
168 | "dynahate_all_discriminate": {"v11-delphi-declare": [1351400],
169 | "v10-delphi": [],
170 | "v9-delphi": [],
171 | "unicorn-pt": [],
172 | "11B": []},
173 |
174 | "dynahate_all_discriminate_100_shot": {"v11-delphi-declare": [1239200],
175 | "v10-delphi": [],
176 | "v9-delphi": [],
177 | "unicorn-pt": [1070900],
178 | "11B": []},
179 | }
180 |
181 | print(training_data)
182 | for base_model in ["v11-delphi-declare",
183 | "unicorn-pt"]:
184 | print("=" * 10, base_model)
185 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
186 | check_points = None
187 | main_get_accuracy(base_path, "validation", training_data_type, round_ids=[
188 | 1, 2, 3, 4], check_points=check_points, is_print=False)
189 |
190 |
191 | if __name__ == "__main__":
192 | main_r1_yesno()
193 |
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_ethics_converted.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 | acc_type = {"ethics_deontology": 4,
6 | "ethics_justice": 4,
7 | "ethics_virtue": 5,
8 | "ethics_util": "exact",
9 | "ethics_cm": "exact"}
10 |
11 |
12 | def get_gold_class(training_data, data_split):
13 | """
14 | Get gold inputs and targets class labels; format ([class]1[/class] [text] [/text])
15 | """
16 | if data_split == "validation":
17 | data_split = "test"
18 | elif data_split == "test":
19 | data_split = "test_hard"
20 |
21 | task_name = training_data.split("_")[1]
22 |
23 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/ethics/ethics_converted/{task_name}/{data_split}.tsv"
24 | df_inputs = pd.read_csv(data_base_path, sep="\t")
25 | df_inputs["targets"] = df_inputs["targets"].astype(str)
26 |
27 | inputs_all = list(df_inputs["inputs"])
28 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
29 |
30 | targets_all = list(df_inputs["targets"])
31 | targets = [int(i.split("[/class]")[0].split("[class]")[-1])
32 | for i in targets_all]
33 | return inputs, targets
34 |
35 |
36 | def get_pred_class(bucket, base_path, training_data, check_point, base_model):
37 | """
38 | Get preds class labels
39 | """
40 | preds_blob = bucket.get_blob(
41 | base_path + f"{training_data}_{check_point}_predictions")
42 | preds_blob_list = preds_blob.download_as_string().decode(
43 | 'utf-8').split("\n")[1:]
44 |
45 | preds_class = []
46 | for i in preds_blob_list:
47 | try:
48 | if "v10" in base_model or "v11" in base_model or "unicorn-pt" in base_model:
49 | preds_class.append(
50 | int(i.split("[/class]")[0].split("[class]")[-1]))
51 | else:
52 | preds_class.append(
53 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1]))
54 | except:
55 | print("output form not identifiable:", i)
56 | preds_class.append(99)
57 | return preds_class
58 |
59 |
60 | def get_ethics_accuracy(targets, preds, accuracy_type="exact"):
61 | accuracies = []
62 | for i in range(len(targets)):
63 | t_c = targets[i]
64 | p_c = preds[i]
65 | if t_c != 99 and p_c != 99:
66 | if accuracy_type == "exact":
67 | accuracies.append(int(t_c == p_c))
68 | elif accuracy_type == "non conflict":
69 | accuracies.append(int(t_c * p_c >= 0))
70 | else:
71 | accuracies.append(
72 | not int((t_c == -1 or p_c == -1) and (t_c * p_c != 1)))
73 |
74 | if type(accuracy_type) == type(0):
75 | group_acc = []
76 | for i in range(0, len(accuracies), accuracy_type):
77 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3]) == 1:
78 | group_acc.append(1)
79 | else:
80 | group_acc.append(0)
81 | return sum(group_acc) / len(group_acc)
82 | else:
83 | return sum(accuracies) / len(accuracies)
84 |
85 |
86 | def main_get_accuracy(base_path, data_split, check_points=None):
87 | base_model = base_path.split("/")[-4]
88 | training_data = base_path.split("/")[-3]
89 | if base_model == "v10-delphi":
90 | base_path = base_path.replace(training_data, "new_" + training_data)
91 |
92 | base_path += f"{data_split}_eval/"
93 | bucket_name = base_path.split("/")[0]
94 | result_prefix = "/".join(base_path.split("/")[1:])
95 | base_path = "/".join(base_path.split("/")[1:])
96 | training_data = base_path.split("/")[-4]
97 | if "100_shot" in training_data:
98 | training_data = training_data.replace("_100_shot", "")
99 | base_model = base_path.split("/")[-5]
100 |
101 | client = storage.Client()
102 | bucket = client.get_bucket(bucket_name)
103 |
104 | if check_points == None:
105 | check_points = get_check_points(
106 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
107 |
108 | for check_point in check_points:
109 | inputs, targets = get_gold_class(training_data, data_split)
110 | preds = get_pred_class(
111 | bucket, base_path, training_data, check_point, base_model)
112 | if base_model == "v10-delphi" and "util" not in training_data:
113 | preds_new = [t - 1 for t in preds]
114 | preds = preds_new
115 |
116 | index_to_remove = []
117 | for i, p in enumerate(preds):
118 | if p not in [-1, 1]:
119 | index_to_remove.append(i)
120 | print("remove:", i, p)
121 | targets = np.delete(targets, index_to_remove).tolist()
122 | preds = np.delete(preds, index_to_remove).tolist()
123 |
124 | acc = get_ethics_accuracy(
125 | targets, preds, accuracy_type=acc_type["_".join(training_data.split("_")[:2])])
126 | print(f"{check_point}: accuracy -- {acc}")
127 |
128 |
129 | def main_all(training_data):
130 | print("-" * 30, f"{training_data}", "-" * 30)
131 | base_model_2_ckpt = {"ethics_cm": {"v10-delphi": [],
132 | "v9-delphi": [1448300],
133 | "unicorn-pt": [1213700],
134 | "11B": []},
135 |
136 | "ethics_deontology": {"v10-delphi": [],
137 | "v9-delphi": [],
138 | "unicorn-pt": [],
139 | "11B": []},
140 |
141 | "ethics_justice": {"v10-delphi": [],
142 | "v9-delphi": [1356500],
143 | "unicorn-pt": [1137200],
144 | "11B": []},
145 |
146 | "ethics_virtue": {"v10-delphi": [],
147 | "v9-delphi": [1356500],
148 | "unicorn-pt": [1050500],
149 | "11B": []},
150 |
151 | "ethics_util": {"v10-delphi": [],
152 | "v9-delphi": [1351400],
153 | "unicorn-pt": [1035200],
154 | "11B": []},
155 |
156 |
157 | "ethics_cm_100_shot": {"v10-delphi": [],
158 | "v9-delphi": [1315700],
159 | "unicorn-pt": [1055600],
160 | "11B": []},
161 |
162 | "ethics_deontology_100_shot": {"v10-delphi": [],
163 | "v9-delphi": [1274900],
164 | "unicorn-pt": [1030100],
165 | "11B": []},
166 |
167 | "ethics_justice_100_shot": {"v10-delphi": [],
168 | "v9-delphi": [1290200],
169 | "unicorn-pt": [1065800],
170 | "11B": []},
171 |
172 | "ethics_virtue_100_shot": {"v10-delphi": [],
173 | "v9-delphi": [1295300],
174 | "unicorn-pt": [1060700],
175 | "11B": []},
176 |
177 | "ethics_util_100_shot": {"v10-delphi": [],
178 | "v9-delphi": [1295300],
179 | "unicorn-pt": [1076000],
180 | "11B": []},
181 | }
182 |
183 | for base_model in ["v11-delphi-declare", "unicorn-pt"]:
184 | print("=" * 10, base_model)
185 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
186 | check_points = None
187 | main_get_accuracy(base_path, "validation", check_points)
188 |
189 |
190 | if __name__ == "__main__":
191 | main_all("ethics_cm_converted_class_only")
192 | main_all("ethics_deontotlogy_converted_class_only")
193 | main_all("ethics_justice_converted_class_only")
194 | main_all("ethics_virtue_converted_class_only")
195 |
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_ethics_v9.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 |
6 | acc_type = {"ethics_deontology": 4,
7 | "ethics_justice": 4,
8 | "ethics_virtue": 5,
9 | "ethics_util": "exact",
10 | "ethics_cm": "exact"}
11 |
12 |
13 | def get_gold_class(training_data, data_split):
14 | """
15 | Get gold inputs and targets class labels; format (1 )
16 | """
17 | if data_split == "validation":
18 | data_split = "test"
19 | elif data_split == "test":
20 | data_split = "test_hard"
21 |
22 | task_name = training_data.split("_")[-1]
23 |
24 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/ethics/ethics_st/{task_name}/{data_split}.tsv"
25 | df_inputs = pd.read_csv(data_base_path, sep="\t")
26 | df_inputs["targets"] = df_inputs["targets"].astype(str)
27 |
28 | inputs_all = list(df_inputs["inputs"])
29 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
30 |
31 | targets_all = list(df_inputs["targets"])
32 | targets = [int(i.split("")[0].split("")[-1])
33 | for i in targets_all]
34 | return inputs, targets
35 |
36 |
37 | def get_pred_class(bucket, base_path, training_data, check_point, base_model):
38 | """
39 | Get preds class labels
40 | """
41 | preds_blob = bucket.get_blob(
42 | base_path + f"{training_data}_{check_point}_predictions")
43 | preds_blob_list = preds_blob.download_as_string().decode(
44 | 'utf-8').split("\n")[1:]
45 |
46 | preds_class = []
47 | for i in preds_blob_list:
48 | try:
49 | if "v10" in base_model:
50 | preds_class.append(
51 | int(i.split("[/class]")[0].split("[class]")[-1]))
52 | else:
53 | preds_class.append(
54 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1]))
55 | except:
56 | print("output form not identifiable:", i)
57 | preds_class.append(99)
58 | return preds_class
59 |
60 |
61 | def get_ethics_accuracy(targets, preds, accuracy_type="exact"):
62 | accuracies = []
63 | for i in range(len(targets)):
64 | t_c = targets[i]
65 | p_c = preds[i]
66 | if t_c != 99 and p_c != 99:
67 | if accuracy_type == "exact":
68 | accuracies.append(int(t_c == p_c))
69 | elif accuracy_type == "non conflict":
70 | accuracies.append(int(t_c * p_c >= 0))
71 | else:
72 | accuracies.append(
73 | not int((t_c == -1 or p_c == -1) and (t_c * p_c != 1)))
74 |
75 | if accuracy_type == 4:
76 | group_acc = []
77 | for i in range(0, len(accuracies), accuracy_type):
78 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3]) == 1:
79 | group_acc += [1] * accuracy_type
80 | else:
81 | group_acc += [0] * accuracy_type
82 | return sum(group_acc) / len(group_acc)
83 | elif accuracy_type == 5:
84 | group_acc = []
85 | for i in range(0, len(accuracies), accuracy_type):
86 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3] * accuracies[i + 4]) == 1:
87 | group_acc += [1] * accuracy_type
88 | else:
89 | group_acc += [0] * accuracy_type
90 | return sum(group_acc) / len(group_acc)
91 | else:
92 | return sum(accuracies) / len(accuracies)
93 |
94 |
95 | def main_get_accuracy(base_path, data_split, check_points=None, is_save_results=True):
96 | base_model = base_path.split("/")[-4]
97 | training_data = base_path.split("/")[-3]
98 | if base_model == "v10-delphi":
99 | base_path = base_path.replace(training_data, "new_" + training_data)
100 |
101 | base_path += f"{data_split}_eval/"
102 | bucket_name = base_path.split("/")[0]
103 | result_prefix = "/".join(base_path.split("/")[1:])
104 | base_path = "/".join(base_path.split("/")[1:])
105 | training_data = base_path.split("/")[-4]
106 | if "100_shot" in training_data:
107 | training_data = training_data.replace("_100_shot", "")
108 | base_model = base_path.split("/")[-5]
109 |
110 | client = storage.Client()
111 | bucket = client.get_bucket(bucket_name)
112 |
113 | if check_points == None:
114 | check_points = get_check_points(
115 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
116 |
117 | for check_point in check_points:
118 | inputs, targets = get_gold_class(training_data, data_split)
119 | preds = get_pred_class(
120 | bucket, base_path, training_data, check_point, base_model)
121 | if base_model == "v10-delphi" and "util" not in training_data:
122 | preds_new = [t - 1 for t in preds]
123 | preds = preds_new
124 | index_to_remove = []
125 | for i, p in enumerate(preds):
126 | if p not in [-1, 1, 2]:
127 | print("bad:", i, p)
128 |
129 | acc = get_ethics_accuracy(
130 | targets, preds, accuracy_type=acc_type[training_data.replace("new_", "")])
131 | print(f"{check_point}: accuracy -- {acc}")
132 |
133 | if is_save_results:
134 | df_data = pd.DataFrame()
135 | df_data["input"] = inputs
136 | df_data["target"] = targets
137 | df_data["pred"] = preds
138 |
139 | df_data.to_csv(
140 | f"100_shot-{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t")
141 |
142 |
143 | def main_all(training_data):
144 | print("-" * 30, f"{training_data}", "-" * 30)
145 | base_model_2_ckpt = {"ethics_cm": {"v10-delphi": [],
146 | "v9-delphi": [1448300],
147 | "v9-delphi-new": [1290200],
148 | "unicorn-pt": [1213700],
149 | "11B": [1153000]},
150 |
151 | "ethics_deontology": {"v10-delphi": [],
152 | "v9-delphi": [1361600],
153 | "v9-delphi-new": [1315700],
154 | "unicorn-pt": [1157600],
155 | "11B": [1387600]},
156 |
157 | "ethics_justice": {"v10-delphi": [],
158 | "v9-delphi": [1356500],
159 | "v9-delphi-new": [1249400],
160 | "unicorn-pt": [1137200],
161 | "11B": [1112200]},
162 |
163 | "ethics_virtue": {"v10-delphi": [],
164 | "v9-delphi": [1356500],
165 | "v9-delphi-new": [1336100],
166 | "unicorn-pt": [1050500],
167 | "11B": [1040800]},
168 |
169 | "ethics_util": {"v10-delphi": [],
170 | "v9-delphi": [1351400],
171 | "v9-delphi-new": [1249400],
172 | "unicorn-pt": [1035200],
173 | "11B": [1015300]},
174 |
175 |
176 | "ethics_cm_100_shot": {"v10-delphi": [],
177 | "v9-delphi": [1315700],
178 | "v9-delphi-new": [1264700],
179 | "unicorn-pt": [1055600],
180 | "11B": [1112200]},
181 |
182 | "ethics_deontology_100_shot": {"v10-delphi": [],
183 | "v9-delphi": [1274900],
184 | "v9-delphi-new": [1315700],
185 | "unicorn-pt": [1030100],
186 | "11B": [1275400]},
187 |
188 | "ethics_justice_100_shot": {"v10-delphi": [],
189 | "v9-delphi": [1290200],
190 | "v9-delphi-new": [1331000],
191 | "unicorn-pt": [1065800],
192 | "11B": [1504900]},
193 |
194 | "ethics_virtue_100_shot": {"v10-delphi": [],
195 | "v9-delphi": [1295300],
196 | "v9-delphi-new": [1264700],
197 | "unicorn-pt": [1060700],
198 | "11B": [1020400]},
199 |
200 | "ethics_util_100_shot": {"v10-delphi": [],
201 | "v9-delphi": [1295300],
202 | "v9-delphi-new": [1341200],
203 | "unicorn-pt": [1076000],
204 | "11B": [1275400]},
205 | }
206 |
207 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]:
208 | print("=" * 10, base_model)
209 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
210 | check_points = base_model_2_ckpt[training_data][base_model]
211 | main_get_accuracy(base_path, "validation", check_points)
212 | main_get_accuracy(base_path, "test", check_points)
213 |
214 |
215 | if __name__ == "__main__":
216 | main_all("ethics_cm_100_shot")
217 | main_all("ethics_deontology_100_shot")
218 | main_all("ethics_justice_100_shot")
219 | main_all("ethics_virtue_100_shot")
220 | main_all("ethics_util_100_shot")
221 |
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_latenthatred_v9.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import precision_recall_fscore_support
2 | import sys
3 | import numpy as np
4 | sys.path.append("script/evaluate")
5 | from evaluate_utils import *
6 |
7 |
8 | def get_gold_class(data_split):
9 | """
10 | Get gold inputs and targets class labels; format (1 )
11 | """
12 | if data_split == "validation":
13 | data_split = "dev"
14 |
15 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/latenthatred/latenthatred_st/{data_split}.tsv"
16 | df_inputs = pd.read_csv(data_base_path, sep="\t")
17 | df_inputs["targets"] = df_inputs["targets"].astype(str)
18 |
19 | inputs_all = list(df_inputs["inputs"])
20 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
21 |
22 | targets_all = list(df_inputs["targets"])
23 | targets = [int(i.split("")[0].split("")[-1])
24 | for i in targets_all]
25 | return inputs, targets
26 |
27 |
28 | def get_pred_class(bucket, base_path, training_data, check_point, base_model):
29 | """
30 | Get preds class labels
31 | """
32 | preds_blob = bucket.get_blob(
33 | base_path + f"{training_data}_{check_point}_predictions")
34 | preds_blob_list = preds_blob.download_as_string().decode(
35 | 'utf-8').split("\n")[1:]
36 |
37 | preds_class = []
38 | for i in preds_blob_list:
39 | try:
40 | if "v10" in base_model or "v11" in base_model:
41 | preds_class.append(
42 | int(i.split("[/class]")[0].split("[class]")[-1]))
43 | else:
44 | preds_class.append(
45 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1]))
46 | except:
47 | print("output form not identifiable:", i)
48 | preds_class.append(99)
49 | return preds_class
50 |
51 |
52 | def main_get_accuracy(base_path, data_split, check_points=None, is_save_results=True):
53 | base_model = base_path.split("/")[-4]
54 | training_data = base_path.split("/")[-3]
55 | if base_model == "v10-delphi":
56 | base_path = base_path.replace(training_data, "new_" + training_data)
57 | task_name = "new_latenthatred"
58 | else:
59 | task_name = "latenthatred"
60 |
61 | base_path += f"{data_split}_eval/"
62 | bucket_name = base_path.split("/")[0]
63 | result_prefix = "/".join(base_path.split("/")[1:])
64 | base_path = "/".join(base_path.split("/")[1:])
65 | # training_data = base_path.split("/")[-4]
66 | base_model = base_path.split("/")[-5]
67 |
68 | client = storage.Client()
69 | bucket = client.get_bucket(bucket_name)
70 |
71 | if check_points == None:
72 | check_points = get_check_points(
73 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
74 |
75 | for check_point in check_points:
76 | inputs, targets = get_gold_class(data_split)
77 | preds = get_pred_class(
78 | bucket, base_path, task_name, check_point, base_model)
79 | if base_model == "v10-delphi":
80 | preds_new = [t - 1 for t in preds]
81 | preds = preds_new
82 |
83 | if is_save_results:
84 | df_data = pd.DataFrame()
85 | df_data["input"] = inputs
86 | df_data["target"] = targets
87 | df_data["pred"] = preds
88 |
89 | df_data.to_csv(
90 | f"zero_shot-{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t")
91 |
92 |
93 | def main_all_0001(training_data):
94 | print("-" * 30, f"{training_data}", "-" * 30)
95 | base_model_2_ckpt = {"latenthatred": {
96 | "v10-delphi": [],
97 | "v9-delphi": [],
98 | "unicorn-pt": [],
99 | "11B": [], },
100 |
101 | "latenthatred_100_shot": {
102 | "v10-delphi": [],
103 | "v9-delphi": [],
104 | "unicorn-pt": [],
105 | "11B": [], }
106 | }
107 |
108 | for base_model in ["v9-delphi", "unicorn-pt"]:
109 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0001_bs-16/"
110 | check_points = None
111 | main_get_accuracy(base_path, "validation", check_points)
112 |
113 |
114 | def main_zero_shot(training_data):
115 | print("-" * 30, f"{training_data}", "-" * 30)
116 | base_model_2_ckpt = {"dynahate_round_1_st": {"v10-delphi": [],
117 | "v9-delphi": [1341200],
118 | "v9-delphi-new": [1264700],
119 | "unicorn-pt": [1127000],
120 | "11B": [1102000]},
121 |
122 | "dynahate_all_st": {"v10-delphi": [],
123 | "v9-delphi": [1402400],
124 | "v9-delphi-new": [1387100],
125 | "unicorn-pt": [1157600],
126 | "11B": [1132600]},
127 |
128 | "dynahate_round_1_st_100_shot": {"v10-delphi": [],
129 | "v9-delphi": [1325900],
130 | "unicorn-pt": [1106600],
131 | "11B": [1076500]},
132 |
133 | "dynahate_all_st_100_shot": {"v10-delphi": [],
134 | "v9-delphi": [1392200],
135 | "unicorn-pt": [1147400],
136 | "11B": [1076500]}
137 | }
138 |
139 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]:
140 | print("=" * 10, base_model)
141 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
142 | check_points = base_model_2_ckpt[training_data][base_model]
143 | main_get_accuracy(base_path, "test", check_points)
144 |
145 |
146 | def main_all(training_data):
147 | print("-" * 30, f"{training_data}", "-" * 30)
148 | base_model_2_ckpt = {"latenthatred": {"v11-delphi-declare": [1356500],
149 | "v10-delphi": [1318400],
150 | "v9-delphi": [1397300],
151 | "v9-delphi-new": [1300400],
152 | "unicorn-pt": [1157600],
153 | "11B": [1071400], },
154 |
155 | "latenthatred_100_shot": {"v11-delphi-declare": [1274900],
156 | "v10-delphi": [1328600],
157 | "v9-delphi": [1285100],
158 | "v9-delphi-new": [1259600],
159 | "unicorn-pt": [1055600],
160 | "11B": [1102000], }
161 | }
162 |
163 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]:
164 | print("=" * 10, base_model)
165 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
166 | check_points = base_model_2_ckpt[training_data][base_model]
167 | main_get_accuracy(base_path, "test", check_points)
168 |
169 |
170 | if __name__ == "__main__":
171 | main_zero_shot("dynahate_round_1_st")
172 |
--------------------------------------------------------------------------------
/src/delphi/evaluate/evaluate_latenthatred_yesno.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import precision_recall_fscore_support
2 | import sys
3 | import numpy as np
4 | sys.path.append("script/evaluate")
5 | from evaluate_utils import *
6 |
7 |
8 | def get_gold_class(data_split):
9 | """
10 | Get gold inputs and targets class labels; format (1 )
11 | """
12 | if data_split == "validation":
13 | data_split = "dev"
14 |
15 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/latenthatred/latenthatred_yesno/{data_split}.tsv"
16 | df_inputs = pd.read_csv(data_base_path, sep="\t")
17 | df_inputs["targets"] = df_inputs["targets"].astype(str)
18 |
19 | inputs_all = list(df_inputs["inputs"])
20 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
21 |
22 | targets_all = list(df_inputs["targets"])
23 | targets = [int(i.split("[/class]")[0].split("[class]")[-1])
24 | for i in targets_all]
25 | return inputs, targets
26 |
27 |
28 | def get_pred_class(bucket, base_path, training_data, check_point, base_model):
29 | """
30 | Get preds class labels
31 | """
32 | preds_blob = bucket.get_blob(
33 | base_path + f"{training_data}_{check_point}_predictions")
34 | preds_blob_list = preds_blob.download_as_string().decode(
35 | 'utf-8').split("\n")[1:]
36 |
37 | preds_class = []
38 | for i in preds_blob_list:
39 | try:
40 | preds_class.append(
41 | int(i.split("[/class]")[0].split("[class]")[-1]))
42 | except:
43 | print("output form not identifiable:", i)
44 | preds_class.append(99)
45 | return preds_class
46 |
47 |
48 | def main_get_accuracy(base_path, data_split, check_points=None):
49 | base_model = base_path.split("/")[-4]
50 | training_data = base_path.split("/")[-3]
51 |
52 | base_path += f"{data_split}_eval/"
53 | bucket_name = base_path.split("/")[0]
54 | result_prefix = "/".join(base_path.split("/")[1:])
55 | base_path = "/".join(base_path.split("/")[1:])
56 | training_data = base_path.split("/")[-4]
57 | base_model = base_path.split("/")[-5]
58 |
59 | if "100_shot" in training_data:
60 | training_data = training_data.replace("_100_shot", "")
61 |
62 | client = storage.Client()
63 | bucket = client.get_bucket(bucket_name)
64 |
65 | if check_points == None:
66 | check_points = get_check_points(
67 | client, bucket_name, result_prefix, after_check_point=-1)[1:]
68 |
69 | for check_point in check_points:
70 | inputs, targets = get_gold_class(data_split)
71 | preds = get_pred_class(
72 | bucket, base_path, training_data, check_point, base_model)
73 | targets, preds = remove_unknown_elements(targets, preds)
74 | scores = precision_recall_fscore_support(
75 | targets, preds, average='binary')
76 | acc = get_accuracy(targets, preds, accuracy_type="exact")
77 |
78 | print(check_point, scores, acc)
79 |
80 |
81 | def main_all(training_data):
82 | print("-" * 30, f"{training_data}", "-" * 30)
83 | base_model_2_ckpt = {"latenthatred": {"v11-delphi-declare": [],
84 | "v10-delphi": [],
85 | "v9-delphi": [],
86 | "unicorn-pt": [],
87 | "11B": [], },
88 |
89 | "latenthatred_100_shot": {"v11-delphi-declare": [],
90 | "v10-delphi": [],
91 | "v9-delphi": [],
92 | "unicorn-pt": [],
93 | "11B": [], }
94 | }
95 |
96 | for base_model in ["v11-delphi-declare", "unicorn-pt"]:
97 | print("=" * 10, base_model)
98 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/"
99 | check_points = None
100 | main_get_accuracy(base_path, "validation", check_points)
101 |
102 |
103 | if __name__ == "__main__":
104 | main_all("latenthatred_yesno")
105 | main_all("latenthatred_yesno_class_only")
106 |
--------------------------------------------------------------------------------
/src/delphi/train/evaluate.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """
4 | Evaluate the model checkpoint
5 | """
6 |
7 | import mixtures
8 | import tasks
9 | import t5
10 | import os
11 | import sys
12 | import util
13 | import seqio
14 | import click
15 | import logging
16 | import tensorflow.compat.v1 as tf
17 |
18 | print("python", sys.version)
19 | print("t5", t5.__version__)
20 | print("tf", tf.__version__)
21 | print("seqio", seqio.__version__)
22 |
23 | tf.disable_v2_behavior()
24 |
25 | # N.B. We must import tasks and mixtures here so that they are registered and available for evaluation.
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | @click.command()
31 | @click.argument("mixture", type=str)
32 | @click.argument("results_dir", type=str)
33 | # The name of the TPU. Defaults to the TPU_NAME environment variable.
34 | @click.argument("tpu-name", type=str)
35 | # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable.
36 | @click.argument("tpu-topology", type=str)
37 | @click.argument("split", type=str)
38 | @click.argument("checkpoint", type=int)
39 | @click.option(
40 | "--model-parallelism",
41 | type=int,
42 | default=8,
43 | help="The degree of model parallelism to use. Defaults to 8.",
44 | )
45 | def evaluate(
46 | mixture: str,
47 | results_dir: str,
48 | split: str,
49 | checkpoint: int,
50 | model_parallelism: int,
51 | tpu_name: str,
52 | tpu_topology: str,
53 | ) -> None:
54 | """
55 | Evaluate the model located at RESULTS_DIR on MIXTURE.
56 | """
57 |
58 | print(tpu_name)
59 | print(tpu_topology)
60 |
61 | # Initialize arguments
62 | if tpu_topology == "v3-32":
63 | batch_size = 16
64 | model_parallelism = 8
65 | elif tpu_topology == "v3-8":
66 | batch_size = 8
67 | model_parallelism = 8
68 | else:
69 | print("ERROR: tpu_topology invalid")
70 | return
71 |
72 | # Validate arguments
73 | util.validate_path(results_dir)
74 |
75 | checkpoints = util.get_result_check_points(
76 | results_dir, split, "ethics_virtue")
77 | checkpoints = util.get_result_check_points(
78 | results_dir, split, "latenthatred")
79 |
80 | print("-" * 10, "checkpoints todo", "-" * 10)
81 |
82 | if checkpoint == 100:
83 | checkpoints_to_eval = None
84 | elif checkpoint == 0:
85 | checkpoints_to_eval = checkpoints
86 | else:
87 | checkpoints_to_eval = [checkpoint]
88 |
89 | print(checkpoints_to_eval)
90 |
91 | # Run evaluation
92 | model = t5.models.MtfModel(
93 | model_dir=results_dir,
94 | tpu=tpu_name,
95 | tpu_topology=tpu_topology,
96 | model_parallelism=model_parallelism,
97 | batch_size=batch_size,
98 | sequence_length={"inputs": 512, "targets": 128},
99 | learning_rate_schedule=None,
100 | save_checkpoints_steps=5000,
101 | keep_checkpoint_max=None,
102 | iterations_per_loop=100,
103 | )
104 |
105 | model.eval(
106 | mixture_or_task_name=mixture,
107 | checkpoint_steps=checkpoints_to_eval,
108 | split=split,
109 | )
110 |
111 |
112 | if __name__ == "__main__":
113 | evaluate()
114 |
--------------------------------------------------------------------------------
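In evaluate.py the CHECKPOINT argument doubles as a sentinel: 100 passes checkpoint_steps=None through to model.eval, 0 evaluates every checkpoint that still lacks predictions, and any other value names a single step. A minimal pure-Python sketch of that mapping (the function name is illustrative):

def select_checkpoints(checkpoint, pending_checkpoints):
    """Mirror the sentinel handling of evaluate.py's CHECKPOINT argument."""
    if checkpoint == 100:
        return None                    # passed through as checkpoint_steps=None
    if checkpoint == 0:
        return pending_checkpoints     # every checkpoint without predictions yet
    return [checkpoint]                # one explicit checkpoint step

print(select_checkpoints(0, [1239200, 1249400]))        # [1239200, 1249400]
print(select_checkpoints(1239200, [1239200, 1249400]))  # [1239200]

--------------------------------------------------------------------------------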
/src/delphi/train/fine-tune.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """Fine-tune T5 based models."""
4 |
5 | import mixtures
6 | import tasks
7 | import warnings
8 | import util
9 | import t5
10 | import sys
11 | import seqio
12 | import click
13 | import logging
14 | import tensorflow.compat.v1 as tf
15 |
16 | print("python", sys.version)
17 | print("t5", t5.__version__)
18 | print("tf", tf.__version__)
19 | print("seqio", seqio.__version__)
20 |
21 | # We must import tasks and mixtures here so that the tasks and mixtures are registered and available for training.
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | v = tf.compat.v1.logging.FATAL
26 | tf.compat.v1.logging.set_verbosity(v)
27 | tf.disable_v2_behavior()
28 |
29 | config = tf.ConfigProto()
30 | config.gpu_options.allow_growth = True
31 | session = tf.InteractiveSession(config=config)
32 |
33 | warnings.filterwarnings("ignore", category=DeprecationWarning)
34 |
35 | PRETRAINED_MODELS = {
36 | "small": ("gs://t5-data/pretrained_models/small/", -1),
37 | "base": ("gs://t5-data/pretrained_models/base/", -1),
38 | "large": ("gs://t5-data/pretrained_models/large/", -1),
39 | "3B": ("gs://t5-data/pretrained_models/3B/", -1),
40 | "11B": ("gs://t5-data/pretrained_models/11B/", -1),
41 | "unicorn-pt": ("gs://ai2-mosaic-public/projects/rainbow/v1.0/unicorns/lr-2e-3_batch-size-32/", -1),
42 | "v9-delphi": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/unicorn-pt/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-16/", 1264700),
43 | "v9-delphi-new": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/unicorn-pt/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-16/", 1239200),
44 | "v9-delphi-large": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/large/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-8/", 1643100),
45 | }
46 |
47 |
48 | @click.command()
49 | @click.argument("mixture", type=str)
50 | @click.argument("results_dir", type=str)
51 | # The name of the TPU to use.
52 | @click.argument("tpu-name", type=str)
53 | # The topology of the TPU to use (e.g. "v3-8" or "v3-32").
54 | @click.argument("tpu-topology", type=str)
55 | @click.argument("pretrained-model", type=str)
56 | @click.option(
57 | "--split",
58 | type=str,
59 | default="train",
60 | help="The split on which to train. Defaults to 'train'.",
61 | )
62 | @click.option(
63 | "--n-steps",
64 | type=int,
65 | default=600000,
66 | help="The number of gradient updates. Defaults to 25,000.",
67 | )
68 | @click.option(
69 | "--save-checkpoints-steps",
70 | type=int,
71 | default=5000,
72 | help=(
73 | "The number of steps to take before saving a checkpoint. Defaults to"
74 | " 5000."
75 | ),
76 | )
77 | @click.option(
78 | "--n-checkpoints-to-keep",
79 | type=int,
80 | default=300,
81 | help=(
82 | "The number of checkpoints to keep during fine-tuning. Defaults"
83 | " to 4."
84 | ),
85 | )
86 | @click.option(
87 | "--learning-rate",
88 | type=float,
89 | default=2e-4,
90 | help="The learning rate to use for training. Defaults to 3e-3.",
91 | )
92 | @click.option(
93 | "--continue_finetune",
94 | type=bool,
95 | default=False,
96 | help="Whether to continue training from an existing checkpoint.",
97 | )
98 | def fine_tune(
99 | mixture: str,
100 | results_dir: str,
101 | split: str,
102 | pretrained_model: str,
103 | n_steps: int,
104 | learning_rate: float,
105 | save_checkpoints_steps: int,
106 | n_checkpoints_to_keep: int,
107 | tpu_name: str,
108 | tpu_topology: str,
109 | continue_finetune: bool,
110 | ) -> None:
111 | """
112 | Fine-tune the model on MIXTURE, writing results to RESULTS_DIR.
113 | """
114 |
115 | # Initialize arguments
116 | if tpu_topology == "v3-32":
117 | batch_size = 16
118 | model_parallelism = 32
119 | elif tpu_topology == "v3-8":
120 | batch_size = 8
121 | model_parallelism = 8
122 | else:
123 | print("ERROR: tpu_topology invalid")
124 | return
125 |
126 | pretrained_checkpoint_step = -1
127 |
128 | # Get result path given arguments
129 | result_path = util.get_result_path(
130 | results_dir, pretrained_model, mixture, learning_rate, batch_size)
131 |
132 | # Validate path
133 | util.validate_path(results_dir, pretrained_model, PRETRAINED_MODELS)
134 |
135 | # Process arguments
136 | if pretrained_model in PRETRAINED_MODELS:
137 | pretrained_model, pretrained_checkpoint_step = PRETRAINED_MODELS[pretrained_model]
138 |
139 | # If the training stops before finishing and we want to continue from the last checkpoint
140 | if continue_finetune:
141 | pretrained_model = result_path
142 |
143 | # Print arguments
144 | util.print_arguments(result_path, results_dir, mixture, split, pretrained_model,
145 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism,
146 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate,
147 | tpu_name, tpu_topology, tasks, continue_finetune)
148 |
149 | # Run fine-tuning
150 | model = t5.models.MtfModel(
151 | model_dir=result_path,
152 | tpu=tpu_name,
153 | tpu_topology=tpu_topology,
154 | model_parallelism=model_parallelism,
155 | batch_size=batch_size,
156 | sequence_length={"inputs": 512, "targets": 128},
157 | learning_rate_schedule=learning_rate,
158 | save_checkpoints_steps=save_checkpoints_steps,
159 | keep_checkpoint_max=n_checkpoints_to_keep,
160 | iterations_per_loop=100,
161 | )
162 |
163 | model.finetune(
164 | mixture_or_task_name=mixture,
165 | pretrained_model_dir=pretrained_model,
166 | pretrained_checkpoint_step=pretrained_checkpoint_step,
167 | finetune_steps=n_steps,
168 | split=split,
169 | )
170 |
171 |
172 | if __name__ == "__main__":
173 | fine_tune()
174 |
--------------------------------------------------------------------------------
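fine-tune.py resolves a pretrained-model alias through PRETRAINED_MODELS before calling model.finetune, while a raw GCS path is passed through unchanged with checkpoint step -1. A minimal sketch of that lookup, keeping only the unicorn-pt entry from the table above (the custom path in the example is hypothetical):

PRETRAINED_MODELS = {
    "unicorn-pt": ("gs://ai2-mosaic-public/projects/rainbow/v1.0/unicorns/lr-2e-3_batch-size-32/", -1),
}

def resolve_pretrained_model(pretrained_model):
    pretrained_checkpoint_step = -1
    if pretrained_model in PRETRAINED_MODELS:
        pretrained_model, pretrained_checkpoint_step = PRETRAINED_MODELS[pretrained_model]
    return pretrained_model, pretrained_checkpoint_step

print(resolve_pretrained_model("unicorn-pt"))
print(resolve_pretrained_model("gs://my-bucket/custom-model/"))  # hypothetical: passed through as-is

--------------------------------------------------------------------------------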
/src/delphi/train/mixtures.py:
--------------------------------------------------------------------------------
1 | """
2 | Data mixtures
3 | """
4 |
5 | import os
6 | import t5
7 | import tasks
8 | import rates
9 | import seqio
10 | import functools
11 |
12 | import util
13 |
14 | # ################### register mixtures ###################
15 | # seqio.MixtureRegistry.add(
16 | # "commonsense_morality_joint_all_proportional",
17 | # ["moral_acceptability",
18 | # "moral_agreement",
19 | # "moral_comparison"],
20 | # default_rate=rates.MIXING_RATES["proportional"]
21 | # )
22 | # util.print_mixture_examples("commonsense_morality_joint_all_proportional")
23 | #
24 | # seqio.MixtureRegistry.add(
25 | # "commonsense_morality_separate_all_proportional",
26 | # ["moral_acceptability_class",
27 | # "moral_acceptability_text",
28 | # "moral_agreement_class",
29 | # "moral_agreement_text",
30 | # "moral_comparison_class"],
31 | # default_rate=rates.MIXING_RATES["proportional"]
32 | # )
33 | # util.print_mixture_examples("commonsense_morality_separate_all_proportional", num_ex=1)
34 |
35 |
36 | # seqio.MixtureRegistry.add(
37 | # "sbic_commonsense_morality_joint_comparison_double_all_proportional",
38 | # ["sbic_moral_acceptability",
39 | # "sbic_moral_agreement",
40 | # "sbic_moral_comparison_double"],
41 | # default_rate=rates.MIXING_RATES["proportional"]
42 | # )
43 | # util.print_mixture_examples("sbic_commonsense_morality_joint_comparison_double_all_proportional")
44 | #
45 | #
46 | # seqio.MixtureRegistry.add(
47 | # "sbic_commonsense_morality_separate_all_proportional",
48 | # ["sbic_moral_acceptability_class",
49 | # "sbic_moral_acceptability_text",
50 | # "sbic_moral_agreement_class",
51 | # "sbic_moral_agreement_text",
52 | # "sbic_moral_comparison_class"],
53 | # default_rate=rates.MIXING_RATES["proportional"]
54 | # )
55 | # util.print_mixture_examples("sbic_commonsense_morality_separate_all_proportional")
56 | #
57 | #
58 | # seqio.MixtureRegistry.add(
59 | # "commonsense_morality_separate_wo_agreement_class_all_proportional",
60 | # ["moral_acceptability_class",
61 | # "moral_acceptability_text",
62 | # "moral_agreement_text",
63 | # "moral_comparison_class"],
64 | # default_rate=rates.MIXING_RATES["proportional"]
65 | # )
66 | # util.print_mixture_examples("commonsense_morality_separate_wo_agreement_class_all_proportional")
67 | #
68 | #
69 | # seqio.MixtureRegistry.add(
70 | # "sbic_commonsense_morality_separate_wo_agreement_class_all_proportional",
71 | # ["sbic_moral_acceptability_class",
72 | # "sbic_moral_acceptability_text",
73 | # "sbic_moral_agreement_text",
74 | # "sbic_moral_comparison_class"],
75 | # default_rate=rates.MIXING_RATES["proportional"]
76 | # )
77 | # util.print_mixture_examples("sbic_commonsense_morality_separate_wo_agreement_class_all_proportional")
78 | #
79 | #
80 | # seqio.MixtureRegistry.add(
81 | # "commonsense_morality_separate_text_only_all_proportional",
82 | # ["moral_acceptability_text",
83 | # "moral_agreement_text"],
84 | # default_rate=rates.MIXING_RATES["proportional"]
85 | # )
86 | # util.print_mixture_examples("commonsense_morality_separate_text_only_all_proportional")
87 | #
88 | #
89 | # seqio.MixtureRegistry.add(
90 | # "sbic_commonsense_morality_separate_text_only_all_proportional",
91 | # ["sbic_moral_acceptability_text",
92 | # "sbic_moral_agreement_text"],
93 | # default_rate=rates.MIXING_RATES["proportional"]
94 | # )
95 | # util.print_mixture_examples("sbic_commonsense_morality_separate_text_only_all_proportional")
96 | # , 1, 10, 30, 60, 90
97 |
98 |
99 | seqio.MixtureRegistry.add(
100 | "sbic_commonsense_morality_joint_all_proportional",
101 | ["sbic_moral_acceptability",
102 | "sbic_moral_agreement",
103 | "sbic_moral_comparison"],
104 | default_rate=rates.MIXING_RATES["proportional"]
105 | )
106 | util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional")
107 |
108 | # ################## ethics values raw (for fine-tuning) ##################
109 | # seqio.MixtureRegistry.add(
110 | # "ethics_values_raw_with_cm_long_all_proportional",
111 | # ["ethics_cm_long_raw",
112 | # "ethics_deontology_raw",
113 | # "ethics_justice_raw",
114 | # "ethics_util_raw",
115 | # "ethics_virtue_raw"],
116 | # default_rate=rates.MIXING_RATES["proportional"]
117 | # )
118 | #
119 | # seqio.MixtureRegistry.add(
120 | # "ethics_values_raw_with_cm_overall_all_proportional",
121 | # ["ethics_cm_raw",
122 | # "ethics_deontology_raw",
123 | # "ethics_justice_raw",
124 | # "ethics_util_raw",
125 | # "ethics_virtue_raw"],
126 | # default_rate=rates.MIXING_RATES["proportional"]
127 | # )
128 | #
129 | # seqio.MixtureRegistry.add(
130 | # "ethics_values_raw_without_cm_all_proportional",
131 | # ["ethics_deontology_raw",
132 | # "ethics_justice_raw",
133 | # "ethics_util_raw",
134 | # "ethics_virtue_raw"],
135 | # default_rate=rates.MIXING_RATES["proportional"]
136 | # )
137 | #
138 | #
139 | # ################## ethics values (for pre-train on ethics) ##################
140 | # seqio.MixtureRegistry.add(
141 | # "ethics_values_with_cm_long_all_proportional",
142 | # ["ethics_cm_long",
143 | # "ethics_deontology",
144 | # "ethics_justice",
145 | # "ethics_util",
146 | # "ethics_virtue"],
147 | # default_rate=rates.MIXING_RATES["proportional"]
148 | # )
149 | #
150 | # seqio.MixtureRegistry.add(
151 | # "ethics_values_with_cm_overall_all_proportional",
152 | # ["ethics_cm",
153 | # "ethics_deontology",
154 | # "ethics_justice",
155 | # "ethics_util",
156 | # "ethics_virtue"],
157 | # default_rate=rates.MIXING_RATES["proportional"]
158 | # )
159 | #
160 | # seqio.MixtureRegistry.add(
161 | # "ethics_values_without_cm_all_proportional",
162 | # ["ethics_deontology",
163 | # "ethics_justice",
164 | # "ethics_util",
165 | # "ethics_virtue"],
166 | # default_rate=rates.MIXING_RATES["proportional"]
167 | # )
168 |
169 |
170 | ################## commonsense norm bank + ethics values ##################
171 |
172 | # seqio.MixtureRegistry.add(
173 | # "sbic_joint_norm_bank_ethics_all_proportional",
174 | # ["sbic_moral_acceptability",
175 | # "sbic_moral_agreement",
176 | # "sbic_moral_comparison",
177 | # "ethics_cm",
178 | # "ethics_deontology",
179 | # "ethics_justice",
180 | # "ethics_util",
181 | # "ethics_virtue"],
182 | # default_rate=rates.MIXING_RATES["proportional"]
183 | # )
184 | # # util.print_mixture_examples("sbic_joint_norm_bank_ethics_all_proportional")
185 |
186 | # seqio.MixtureRegistry.add(
187 | # "sbic_commonsense_morality_joint_all_proportional_demo_v4",
188 | # ["sbic_moral_acceptability",
189 | # "sbic_moral_agreement",
190 | # "sbic_moral_comparison",
191 | # "demo_v4"],
192 | # default_rate=rates.MIXING_RATES["proportional"]
193 | # )
194 | # util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional_demo_v4")
195 | #
196 |
197 |
198 | ################# commonsense norm bank + ablations ##################
199 | proportions = [0.01] # , 1, 10, 30, 60, 90, "base"
200 | for proportion in proportions:
201 | seqio.MixtureRegistry.add(
202 | f"sbic_commonsense_morality_joint_all_proportional_new_{proportion}",
203 | [f"sbic_moral_acceptability_{proportion}",
204 | f"sbic_moral_agreement_{proportion}",
205 | f"sbic_moral_comparison_{proportion}"],
206 | default_rate=rates.MIXING_RATES["proportional"]
207 | )
208 | util.print_mixture_examples(
209 | f"sbic_commonsense_morality_joint_all_proportional_new_{proportion}")
210 |
211 |
212 | ################## commonsense norm bank + wild ablations ##################
213 | # proportions = [10, 20, 40, 60, 80, 100]
214 | # for p in proportions:
215 | # seqio.MixtureRegistry.add(
216 | # f"sbic_commonsense_morality_joint_all_proportional_wild_{p}",
217 | # [ f"wild_train_{p}",
218 | # "sbic_moral_acceptability",
219 | # "sbic_moral_agreement",
220 | # "sbic_moral_comparison"],
221 | # default_rate=rates.MIXING_RATES["proportional"]
222 | # )
223 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_{p}")
224 |
225 | # seqio.MixtureRegistry.add(
226 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100",
227 | # [ f"wild_train_woz_100",
228 | # "sbic_moral_acceptability",
229 | # "sbic_moral_agreement",
230 | # "sbic_moral_comparison"],
231 | # default_rate=rates.MIXING_RATES["proportional"]
232 | # )
233 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100")
234 |
235 |
236 | # seqio.MixtureRegistry.add(
237 | # "wild_hard_test",
238 | # [ "race_test",
239 | # "gender_test"],
240 | # default_rate=rates.MIXING_RATES["proportional"]
241 | # )
242 | # util.print_mixture_examples(f"wild_hard_test")
243 |
244 |
245 | seqio.MixtureRegistry.add(
246 | "all",
247 | [f"wild_train_100",
248 | "sbic_moral_acceptability",
249 | "sbic_moral_agreement",
250 | "sbic_moral_comparison",
251 | "race_test",
252 | "gender_test"],
253 | default_rate=rates.MIXING_RATES["proportional"]
254 | )
255 | util.print_mixture_examples(f"all")
256 |
257 |
258 | seqio.MixtureRegistry.add(
259 | "hard_all",
260 | [f"wild_train_100",
261 | "race_test",
262 | "gender_test"],
263 | default_rate=rates.MIXING_RATES["proportional"]
264 | )
265 | util.print_mixture_examples(f"hard_all")
266 |
267 |
268 | seqio.MixtureRegistry.add(
269 | "race_gender",
270 | ["race_test",
271 | "gender_test"],
272 | default_rate=rates.MIXING_RATES["proportional"]
273 | )
274 |
275 | dynahate = False
276 | if dynahate:
277 | # seqio.MixtureRegistry.add(
278 | # "dynahate_all",
279 | # [f"dynahate_round_1",
280 | # f"dynahate_round_2",
281 | # f"dynahate_round_3",
282 | # f"dynahate_round_4", ],
283 | # default_rate=rates.MIXING_RATES["proportional"]
284 | # )
285 | #
286 | # seqio.MixtureRegistry.add(
287 | # "dynahate_all_100_shot",
288 | # [f"dynahate_round_1_100_shot",
289 | # f"dynahate_round_2_100_shot",
290 | # f"dynahate_round_3_100_shot",
291 | # f"dynahate_round_4_100_shot", ],
292 | # default_rate=rates.MIXING_RATES["proportional"]
293 | # )
294 |
295 | # seqio.MixtureRegistry.add(
296 | # "dynahate_all_nat",
297 | # [f"dynahate_round_1_nat",
298 | # f"dynahate_round_2_nat",
299 | # f"dynahate_round_3_nat",
300 | # f"dynahate_round_4_nat", ],
301 | # default_rate=rates.MIXING_RATES["proportional"]
302 | # )
303 | #
304 | # seqio.MixtureRegistry.add(
305 | # "dynahate_all_nat_100_shot",
306 | # [f"dynahate_round_1_nat_100_shot",
307 | # f"dynahate_round_2_nat_100_shot",
308 | # f"dynahate_round_3_nat_100_shot",
309 | # f"dynahate_round_4_nat_100_shot", ],
310 | # default_rate=rates.MIXING_RATES["proportional"]
311 | # )
312 |
313 | seqio.MixtureRegistry.add(
314 | "dynahate_all_st",
315 | [f"dynahate_round_1_st",
316 | f"dynahate_round_2_st",
317 | f"dynahate_round_3_st",
318 | f"dynahate_round_4_st",],
319 | default_rate=rates.MIXING_RATES["proportional"]
320 | )
321 |
322 | seqio.MixtureRegistry.add(
323 | "dynahate_all_st_100_shot",
324 | [f"dynahate_round_1_st_100_shot",
325 | f"dynahate_round_2_st_100_shot",
326 | f"dynahate_round_3_st_100_shot",
327 | f"dynahate_round_4_st_100_shot",],
328 | default_rate=rates.MIXING_RATES["proportional"]
329 | )
330 |
331 | # seqio.MixtureRegistry.add(
332 | # "dynahate_all_st_clean",
333 | # [ f"dynahate_round_1_st_clean",
334 | # f"dynahate_round_2_st_clean",
335 | # f"dynahate_round_3_st_clean",
336 | # f"dynahate_round_4_st_clean",],
337 | # default_rate=rates.MIXING_RATES["proportional"]
338 | # )
339 | #
340 | # seqio.MixtureRegistry.add(
341 | # "dynahate_all_st_clean_100_shot",
342 | # [ f"dynahate_round_1_st_clean_100_shot",
343 | # f"dynahate_round_2_st_clean_100_shot",
344 | # f"dynahate_round_3_st_clean_100_shot",
345 | # f"dynahate_round_4_st_clean_100_shot",],
346 | # default_rate=rates.MIXING_RATES["proportional"]
347 | # )
348 | #
349 | # seqio.MixtureRegistry.add(
350 | # "dynahate_all_bc",
351 | # [ f"dynahate_round_1_bc",
352 | # f"dynahate_round_2_bc",
353 | # f"dynahate_round_3_bc",
354 | # f"dynahate_round_4_bc",],
355 | # default_rate=rates.MIXING_RATES["proportional"]
356 | # )
357 | #
358 | # seqio.MixtureRegistry.add(
359 | # "dynahate_all_bc_100_shot",
360 | # [ f"dynahate_round_1_bc_100_shot",
361 | # f"dynahate_round_2_bc_100_shot",
362 | # f"dynahate_round_3_bc_100_shot",
363 | # f"dynahate_round_4_bc_100_shot",],
364 | # default_rate=rates.MIXING_RATES["proportional"]
365 | # )
366 |
367 | # seqio.MixtureRegistry.add(
368 | # "declare_only",
369 | # [f"freeform",
370 | # f"yesno"],
371 | # default_rate=rates.MIXING_RATES["proportional"]
372 | # )
373 |
--------------------------------------------------------------------------------
/src/delphi/train/predict.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """Evaluate the model on the rainbow datasets."""
4 |
5 | import t5
6 | import logging
7 | import click
8 | import tensorflow.compat.v1 as tf
9 |
10 | # Improve logging.
11 | from contextlib import contextmanager
12 |
13 |
14 | tf.disable_v2_behavior()
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def getSubstringBetweenMarkers(source_string, start_marker, end_marker):
20 | start = source_string.find(start_marker) + len(start_marker)
21 | end = source_string.find(end_marker)
22 | return source_string[start: end]
23 |
24 |
25 | @contextmanager
26 | def tf_verbosity_level(level):
27 | og_level = tf.logging.get_verbosity()
28 | tf.logging.set_verbosity(level)
29 | yield
30 | tf.logging.set_verbosity(og_level)
31 |
32 |
33 | @click.command()
34 | @click.option(
35 | "--batch-size",
36 | type=int,
37 | default=64,
38 | help=(
39 | "The batch size to use for prediction. For efficient prediction on the"
40 | " TPU, choose a multiple of either 8 or 128. Defaults to 64."
41 | ),
42 | )
43 | @click.option(
44 | "--model-parallelism",
45 | type=int,
46 | default=8,
47 | help="The degree of model parallelism to use. Defaults to 8.",
48 | )
49 | @click.option(
50 | "--tpu-name",
51 | type=str,
52 | default="de-tpu-5",
53 | required=True,
54 | help="The name of the TPU. Defaults to the TPU_NAME environment variable.",
55 | )
56 | @click.option(
57 | "--tpu-topology",
58 | type=str,
59 | default="v3-32",
60 | required=True,
61 | help=(
62 | "The topology of the TPU. Defaults to the TPU_TOPOLOGY environment"
63 | " variable."
64 | ),
65 | )
66 | def predict(
67 | batch_size: int,
68 | model_parallelism: int,
69 | tpu_name: str,
70 | tpu_topology: str,
71 | ) -> None:
72 | """Evaluate the model located at RESULTS_DIR on MIXTURE."""
73 | eval_data = "UNDHR.idty.0"
74 |
75 | data_version = "v11"
76 | model_type = "distribution"
77 | check_point = 1249400
78 |
79 | lr = 0.0001
80 | bs = 16
81 | bucket_name = "ai2-tpu-europe-west4"
82 | models_dir = f"gs://{bucket_name}/projects/liweij/mosaic-commonsense-morality/model/{data_version}/" \
83 | f"unicorn-pt/{model_type}/lr-{lr}_bs-{bs}"
84 | training_type = "joint"
85 |
86 | # Run evaluation.
87 | model = t5.models.MtfModel(
88 | model_dir=models_dir,
89 | tpu=tpu_name,
90 | tpu_topology=tpu_topology,
91 | model_parallelism=model_parallelism,
92 | batch_size=batch_size,
93 | sequence_length={"inputs": 512, "targets": 128},
94 | learning_rate_schedule=None,
95 | save_checkpoints_steps=5000,
96 | keep_checkpoint_max=None,
97 | iterations_per_loop=100,
98 | )
99 |
100 | predict_joint_inputs_paths = ["gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/"
101 | f"data/qualitative_eval/{training_type}/" + eval_data + "_qualitative_eval.tsv"]
102 | predict_joint_outputs_paths = [
103 | models_dir.replace("model", "preds") + "/raw/" + eval_data + "_qualitative_eval.tsv"]
104 |
105 | for i in range(len(predict_joint_inputs_paths)):
106 | predict_joint_inputs_path = predict_joint_inputs_paths[i]
107 | predict_joint_outputs_path = predict_joint_outputs_paths[i]
108 |
109 | # Ignore any logging so that we only see the model's answers to the questions.
110 | with tf_verbosity_level('ERROR'):
111 | # Min size for small model on v2-8 with parallelism 1.
112 | model.batch_size = 8
113 | model.predict(
114 | input_file=predict_joint_inputs_path,
115 | output_file=predict_joint_outputs_path,
116 | # Select the most probable output token at each step.
117 | temperature=0,
118 | checkpoint_steps=check_point,
119 | )
120 |
121 |
122 | if __name__ == "__main__":
123 | predict()
124 |
--------------------------------------------------------------------------------
/src/delphi/train/rates.py:
--------------------------------------------------------------------------------
1 | """
2 | Mixing rates
3 | """
4 |
5 | import seqio
6 |
7 | def equal_rate(task: seqio.Task):
8 | """Mix the datasets in equal amounts.
9 |
10 | Parameters
11 | ----------
12 |     task : seqio.Task
13 | The task.
14 |
15 | Returns
16 | -------
17 | float
18 | The constant: ``1.0``.
19 | """
20 | return 1.0
21 |
22 |
23 | def proportional_rate(task: seqio.Task):
24 | """Mix the datasets proportionally.
25 |
26 | Parameters
27 | ----------
28 |     task : seqio.Task
29 | The task.
30 |
31 | Returns
32 | -------
33 | float
34 | The number of examples in the task's training set.
35 | """
36 | return float(task.num_input_examples("train"))
37 |
38 |
39 | # constants
40 |
41 | MIXING_RATES = {
42 | "equal": equal_rate,
43 | "proportional": proportional_rate,
44 | }
45 | """A dictionary mapping mixing rates' names to their implementations."""
46 |
--------------------------------------------------------------------------------
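rates.py defines the two mixing rates referenced in mixtures.py: proportional weights each task by its number of training examples, while equal gives every task a rate of 1.0. A small sketch of the sampling proportions these rates imply, using hypothetical per-task training-set sizes:

train_sizes = {"moral_acceptability": 100000,   # hypothetical example counts
               "moral_agreement": 50000,
               "moral_comparison": 25000}

proportional_rates = {task: float(n) for task, n in train_sizes.items()}
equal_rates = {task: 1.0 for task in train_sizes}

def normalize(rates):
    total = sum(rates.values())
    return {task: rate / total for task, rate in rates.items()}

print(normalize(proportional_rates))  # ~{0.57, 0.29, 0.14}: larger tasks are sampled more often
print(normalize(equal_rates))         # {0.33, 0.33, 0.33}: every task is sampled equally

--------------------------------------------------------------------------------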
/src/delphi/train/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Util functions for fine-tuning and evaluating models
3 | """
4 | import seqio
5 | import pandas as pd
6 | from google.cloud import storage
7 | import tensorflow_datasets as tfds
8 | import tensorflow as tf
9 |
10 |
11 | def create_folder(client, bucket, destination_folder_name):
12 | """
13 | Create a folder in Google Cloud Storage if such folder doesn't exist already
14 | """
15 | if not storage.Blob(bucket=bucket, name=destination_folder_name).exists(client):
16 | blob = bucket.blob(destination_folder_name)
17 | blob.upload_from_string('')
18 | print('Created: {}'.format(destination_folder_name))
19 | else:
20 | print('Exists: {}'.format(destination_folder_name))
21 |
22 |
23 | def print_task_examples(task_name, split="validation", num_ex=1):
24 | """
25 | Print examples from tasks
26 | """
27 | print("#" * 20, task_name, "#" * 20)
28 | task = seqio.TaskRegistry.get(task_name)
29 | ds = task.get_dataset(split=split, sequence_length={
30 | "inputs": 512, "targets": 128})
31 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))):
32 | print(i, ex)
33 | print("test", task.num_input_examples("test"))
34 | print("train", task.num_input_examples("train"))
35 | print("validation", task.num_input_examples("validation"))
36 |
37 |
38 | def print_mixture_examples(mixture_name, split="validation", num_ex=1):
39 | """
40 | Print examples from mixtures
41 | """
42 | print("#" * 20, mixture_name, "#" * 20)
43 | mixture = seqio.MixtureRegistry.get(mixture_name)
44 | ds = mixture.get_dataset(split=split,
45 | sequence_length={"inputs": 512, "targets": 128})
46 |
47 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))):
48 | print(i, ex)
49 | print("test", mixture.num_input_examples("test"))
50 | print("train", mixture.num_input_examples("train"))
51 | print("validation", mixture.num_input_examples("validation"))
52 |
53 |
54 | def get_num_elements_csv(file_name):
55 | """
56 | Get the total number of elements in a given csv/tsv file
57 | """
58 | df = pd.read_csv(file_name, delimiter="\t")
59 | return df.shape[0]
60 |
61 |
62 | def get_num_elements_split(split_paths):
63 | """
64 | Get the number of elements in each split of a dataset
65 | """
66 | num_elements_split = {}
67 | for split, path in split_paths.items():
68 | num_elements_split[split] = get_num_elements_csv(path)
69 | return num_elements_split
70 |
71 |
72 | def get_result_check_points(result_prefix, split, eval_data_type, after_check_point=-1):
73 | """
74 | Get a list of model checkpoints that haven't generated on the designated data split yet
75 | """
76 | client = storage.Client()
77 | bucket_name = "ai2-tpu-europe-west4"
78 | result_prefix = result_prefix.split(bucket_name + "/")[-1] + "/"
79 |
80 | check_points = []
81 | done_check_points = []
82 | for blob in client.list_blobs(bucket_name, prefix=result_prefix):
83 | blob_name = str(blob).split("/")[-1]
84 | if ".meta" in blob_name:
85 | check_point = int(blob_name.split(".meta")[0].split("-")[-1])
86 | if check_point > after_check_point:
87 | check_points.append(check_point)
88 |
89 | print("-" * 10, "checkpoints all", "-" * 10)
90 | print(check_points)
91 |
92 | for blob in client.list_blobs(bucket_name, prefix=result_prefix + f"{split}_eval/"):
93 | blob_name = str(blob).split("/")[-1]
94 | if "_predictions" in blob_name and eval_data_type in blob_name and "_predictions_clean" not in blob_name:
95 | check_point_done = int(blob_name.split(
96 | "_predictions")[0].split("_")[-1])
97 | # check_point_done = int(blob_name.split("_")[0].split("_")[-1])
98 | if check_point_done in check_points:
99 | done_check_points.append(check_point_done)
100 | check_points.remove(check_point_done)
101 |
102 | print("-" * 10, "checkpoints done", "-" * 10)
103 | print(done_check_points)
104 | return check_points
105 |
106 |
107 | def validate_path(results_dir, pretrained_model=None, PRETRAINED_MODELS=None):
108 | """
109 | Validate result path
110 | """
111 | if PRETRAINED_MODELS != None:
112 | if not results_dir.startswith("gs://"):
113 | raise ValueError(
114 | f"RESULTS_DIR ({results_dir}) must be a GCS path.")
115 |
116 | if pretrained_model.startswith("gs://"):
117 | if not tf.io.gfile.exists(pretrained_model):
118 | raise IOError(
119 | f"--pretrained-model ({pretrained_model}) does not exist."
120 | )
121 | else:
122 | if pretrained_model not in PRETRAINED_MODELS:
123 | raise ValueError(
124 | f"--pretrained-model ({pretrained_model}) not recognized. It"
125 | f" must either be a GCS path or one of"
126 | f' {", ".join(PRETRAINED_MODELS.keys())}.')
127 | else:
128 | if not results_dir.startswith("gs://"):
129 | raise ValueError(
130 | f"RESULTS_DIR ({results_dir}) must be a GCS path.")
131 | elif not tf.io.gfile.exists(results_dir):
132 | raise IOError(f"RESULTS_DIR ({results_dir}) doesn't exist.")
133 |
134 |
135 | def print_arguments(result_path, results_dir, mixture, split, pretrained_model,
136 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism,
137 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate,
138 | tpu_name, tpu_topology, tasks, continue_finetune):
139 | print("=" * 10, "results_dir")
140 | print(results_dir)
141 |
142 | print("=" * 10, "mixture")
143 | print(mixture)
144 |
145 | print("=" * 10, "split")
146 | print(split)
147 |
148 | print("=" * 10, "pretrained_model")
149 | print(pretrained_model)
150 |
151 | print("=" * 10, "pretrained_checkpoint_step")
152 | print(pretrained_checkpoint_step)
153 |
154 | print("=" * 10, "n_steps")
155 | print(n_steps)
156 |
157 | print("=" * 10, "batch_size")
158 | print(batch_size)
159 |
160 | print("=" * 10, "model_parallelism")
161 | print(model_parallelism)
162 |
163 | print("=" * 10, "save_checkpoints_steps")
164 | print(save_checkpoints_steps)
165 |
166 | print("=" * 10, "n_checkpoints_to_keep")
167 | print(n_checkpoints_to_keep)
168 |
169 | print("=" * 10, "learning_rate")
170 | print(learning_rate)
171 |
172 | print("=" * 10, "tpu_name")
173 | print(tpu_name)
174 |
175 | print("=" * 10, "tpu_topology")
176 | print(tpu_topology)
177 |
178 | print("=" * 10, "result_path")
179 | print(result_path)
180 |
181 | print("=" * 10, "data_version")
182 | print(tasks.data_version)
183 |
184 | print("=" * 10, "continue_finetune")
185 | print(continue_finetune)
186 |
187 |
188 | def get_result_path(
189 | results_dir: str,
190 | pretrained_model: str,
191 | mixture: str,
192 | learning_rate: float,
193 | batch_size: int
194 | ) -> str:
195 | """
196 | Get a result path given arguments
197 | """
198 | result_path = results_dir + \
199 | "/" + pretrained_model + \
200 | "/" + mixture + \
201 | f"/lr-{learning_rate}_bs-{batch_size}"
202 | return result_path
203 |
--------------------------------------------------------------------------------
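util.get_result_path simply concatenates the results directory, the pretrained-model name, the mixture, and an lr/bs suffix. A quick sketch with hypothetical arguments, with the function restated inline so it runs without the GCS-related imports:

def get_result_path(results_dir, pretrained_model, mixture, learning_rate, batch_size):
    return (results_dir
            + "/" + pretrained_model
            + "/" + mixture
            + f"/lr-{learning_rate}_bs-{batch_size}")

print(get_result_path("gs://my-bucket/model/finetune",   # hypothetical results_dir
                      "unicorn-pt",
                      "latenthatred_yesno",
                      0.0002, 16))
# -> gs://my-bucket/model/finetune/unicorn-pt/latenthatred_yesno/lr-0.0002_bs-16

--------------------------------------------------------------------------------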
/src/delphi_hybrid/components/COMETGenerator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from treelib import Node, Tree
5 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast
6 |
7 | sys.path.append(os.getcwd())
8 | from src.delphi_hybrid.components.bank import *
9 |
10 |
11 | class COMETGenerator():
12 | def __init__(self, model_name="gpt2-xl-atomic2020", device_id=0, server="beaker"): # beaker_batch, local
13 | if model_name == "gpt2-xl-atomic2020":
14 | if server == "beaker_batch":
15 | base_path_atomic_2020 = "/model/atomic2020"
16 | else:
17 | base_path_atomic_2020 = "/net/nfs.cirrascale/mosaic/liweij/model/atomic2020"
18 |
19 | CUDA_DEVICE = f"cuda:{device_id}" if torch.cuda.is_available() else 'cpu'
20 | self.device = torch.device(CUDA_DEVICE)
21 | print(f"COMETGenerator device: {self.device}")
22 | self.model_name = model_name
23 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-xl")
24 | self.model = GPT2LMHeadModel.from_pretrained(base_path_atomic_2020,
25 | pad_token_id=self.tokenizer.eos_token_id).to(self.device)
26 | self.add_special_tokens()
27 |
28 | self.BOS_TOKEN = self.tokenizer.bos_token
29 | self.EOS_TOKEN = self.tokenizer.eos_token
30 | self.GEN_TOKEN = "[GEN]"
31 |
32 | def add_special_tokens(self):
33 | self.tokenizer.add_special_tokens({
34 | 'eos_token': '[EOS]',
35 | 'additional_special_tokens': [
36 | 'LocationOfAction',
37 | 'HinderedBy',
38 | 'HasFirstSubevent',
39 | 'NotHasProperty',
40 | 'NotHasA',
41 | 'HasA',
42 | 'AtLocation',
43 | 'NotCapableOf',
44 | 'CausesDesire',
45 | 'HasPainCharacter',
46 | 'NotDesires',
47 | 'MadeUpOf',
48 | 'InstanceOf',
49 | 'SymbolOf',
50 | 'xReason',
51 | 'isAfter',
52 | 'HasPrerequisite',
53 | 'UsedFor',
54 | 'MadeOf',
55 | 'MotivatedByGoal',
56 | 'Causes',
57 | 'oEffect',
58 | 'CreatedBy',
59 | 'ReceivesAction',
60 | 'NotMadeOf',
61 | 'xWant',
62 | 'PartOf',
63 | 'DesireOf',
64 | 'HasPainIntensity',
65 | 'xAttr',
66 | 'DefinedAs',
67 | 'oReact',
68 | 'xIntent',
69 | 'HasSubevent',
70 | 'oWant',
71 | 'HasProperty',
72 | 'IsA',
73 | 'HasSubEvent',
74 | 'LocatedNear',
75 | 'Desires',
76 | 'isFilledBy',
77 | 'isBefore',
78 | 'InheritsFrom',
79 | 'xNeed',
80 | 'xEffect',
81 | 'xReact',
82 | 'HasLastSubevent',
83 | 'RelatedTo',
84 | 'CapableOf',
85 | 'NotIsA',
86 | 'ObjectUse',
87 | '[GEN]'
88 | ]
89 | })
90 | self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
91 | self.model.resize_token_embeddings(len(self.tokenizer))
92 |
93 | def __name__(self):
94 | return self.model_name
95 |
96 | def generate(self, head_event, relation):
97 | input_string = head_event + " " + relation + " [GEN]"
98 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
99 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True)
100 |
101 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0])
102 | tail_event = decoded_sequence.split(self.GEN_TOKEN)[-1].split(self.EOS_TOKEN)[0]
103 | return tail_event
104 |
105 | def generate_beam(self, head_event, relation, num_beams=5, num_return_sequences=5, max_length=100):
106 | input_string = head_event + " " + relation + " [GEN]"
107 | tokenized_input = self.tokenizer(input_string,
108 | max_length=max_length,
109 | truncation=True,
110 | return_tensors='pt').to(self.device)
111 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
112 |
113 | outputs = self.model.generate(input_ids=tokenized_input.input_ids,
114 | attention_mask=tokenized_input.attention_mask,
115 | # output_scores=True,
116 | # return_dict_in_generate=True,
117 | num_beams=num_beams,
118 | max_length=max_length,
119 | num_return_sequences=num_return_sequences,)
120 |
121 | decoded_sequences = self.tokenizer.batch_decode(outputs)
122 |
123 | tail_events = [ds.split(self.GEN_TOKEN)[-1].split(self.EOS_TOKEN)[0] for ds in decoded_sequences]
124 |
125 | return tail_events
126 |
127 | def generate_all_relations(self, event):
128 | comet_inferences = {}
129 | for relation in comet_relations:
130 | tail_events = self.generate_beam(event, relation)
131 | comet_inferences[relation] = tail_events
132 | return comet_inferences
133 |
134 | if __name__ == "__main__":
135 | comet_generator = COMETGenerator(device_id=0)
136 |
137 | head_events = ["a bear",
138 | "being a stupid bear",
139 | "performing genocide",
140 | "a protected bear"]
141 |
142 | # for head_event in head_events:
143 | for relation in comet_relations:
144 |         tail_events = comet_generator.generate_beam(head_events[0], relation)
145 |         print(relation, tail_events)
--------------------------------------------------------------------------------
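COMETGenerator frames every query as "<head event> <relation> [GEN]" and reads the tail event out of the decoded sequence between [GEN] and [EOS]. A minimal sketch of that prompt/parse convention, with a hypothetical decoded string standing in for a real model call:

GEN_TOKEN, EOS_TOKEN = "[GEN]", "[EOS]"

def build_prompt(head_event, relation):
    return head_event + " " + relation + " " + GEN_TOKEN

def parse_tail(decoded_sequence):
    return decoded_sequence.split(GEN_TOKEN)[-1].split(EOS_TOKEN)[0].strip()

prompt = build_prompt("performing genocide", "xAttr")
decoded = prompt + " cruel " + EOS_TOKEN   # hypothetical model output
print(prompt)
print(parse_tail(decoded))                 # -> "cruel"

--------------------------------------------------------------------------------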
/src/delphi_hybrid/components/CacheHandler/COMETCacheHandler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | from tqdm import tqdm
5 |
6 | sys.path.append(os.getcwd())
7 | from src.delphi_hybrid.components.utils import *
8 | from src.delphi_hybrid.components.COMETGenerator import *
9 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
10 |
11 | class COMETCacheHandler(CacheHandler):
12 | def __init__(self, filename="comet_subset", cache_dir="cache", device_id=0):
13 | if filename != None and "comet" not in filename:
14 | print("ERROR: wrong cache file!")
15 | super().__init__("comet", cache_dir, filename)
16 | self.comet_generator = COMETGenerator(device_id=device_id)
17 |
18 | def _generate_instance(self, event):
19 | return self.comet_generator.generate_all_relations(event)
20 |
21 |
22 | if __name__ == "__main__":
23 | events = list(read_json(data_base_path + f"cache_norm_bank/all_sequences.json").keys())
24 | print(events)
25 |
26 | cache_handler = COMETCacheHandler(filename="norm_bank_comet", cache_dir="cache_norm_bank")
27 | for event in tqdm(events):
28 | comet_instance = cache_handler.save_instance(event)
29 |
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/CacheHandler/CacheHandler.py:
--------------------------------------------------------------------------------
1 | """
2 | File: CacheHandler.py
3 | Note: The parent class of all CacheHandlers
4 | """
5 |
6 | import os
7 | import sys
8 | import abc
9 | import json
10 | from tqdm import tqdm
11 |
12 | sys.path.append(os.getcwd())
13 | from src.delphi_hybrid.components.utils import *
14 |
15 | class CacheHandler():
16 | __metaclass__ = abc.ABCMeta
17 |
18 | def __init__(self, cache_type, cache_dir="cache", filename=None):
19 | self.cache_type = cache_type
20 | self.filename = filename
21 | self.cache_dir = cache_dir
22 | self._load_cache()
23 | print(f"{cache_type} cache size: {len(self.cache)}")
24 |
25 | def __name__(self):
26 | return self.cache_type + "CacheHandler"
27 |
28 | def _get_cache_path(self):
29 | if self.filename == None:
30 | path = data_base_path + f"{self.cache_dir}/{self.cache_type}.json"
31 | else:
32 | path = data_base_path + f"{self.cache_dir}/{self.filename}.json"
33 | return path
34 |
35 | def _load_cache(self):
36 | cache_file_path = self._get_cache_path()
37 | cache_dir_path = "/".join(cache_file_path.split("/")[:-1])
38 | if not os.path.exists(cache_file_path):
39 | ensure_dir(cache_dir_path)
40 | with open(cache_file_path, 'w') as f:
41 | json.dump({}, f)
42 | print(f"Loading {self.cache_type} cache from ...")
43 | print(cache_file_path)
44 | self.cache = read_json(cache_file_path)
45 | print(f"...done loading {self.cache_type} cache!")
46 |
47 | def _save_cache(self):
48 | cache_file_path = self._get_cache_path()
49 | with open(cache_file_path, 'w') as f:
50 | json.dump(self.cache, f)
51 | print("Saved cache to", cache_file_path)
52 |
53 | def _add_instance(self, cache_key, instance, is_save):
54 | self.cache[cache_key] = instance
55 | if is_save:
56 | self._save_cache()
57 |
58 | @abc.abstractmethod
59 | def _generate_instance(self, event):
60 | pass
61 |
62 | def fetch_instance(self, event):
63 | if event not in self.cache or (event in self.cache and self.cache[event] == None):
64 | return None
65 | else:
66 | return self.cache[event]
67 |
68 | def get_instance(self, event, is_save=False):
69 | if event not in self.cache or (event in self.cache and self.cache[event] == None):
70 | instance = self._generate_instance(event)
71 | self._add_instance(event, instance, is_save)
72 | else:
73 | # print("[Note] Instance exists in cache!")
74 | instance = self.cache[event]
75 | return instance
76 |
77 | def update_instance(self, event, is_save=True):
78 | """
79 | Regenerate instance (no matter if it's already in the cache or not) and update the cache
80 | """
81 | instance = self._generate_instance(event)
82 | self._add_instance(event, instance, is_save=is_save)
83 | return instance
84 |
85 | def save_instance(self, event):
86 | return self.get_instance(event, is_save=True)
87 |
88 | # def save_all_instance(self, events):
89 | # for event in tqdm(events):
90 | # self.save_instance(event)
91 |
92 | def save_all_instance(self, events, save_interval=1):
93 | for i, event in enumerate(tqdm(events)):
94 | if type(event) != type(""):
95 | continue
96 | self.get_instance(event)
97 | if i % save_interval == 0:
98 | self._save_cache()
99 | self._save_cache()
100 |
101 | def regenerate_all_instance(self, events, save_interval=1):
102 | for i, event in enumerate(tqdm(events)):
103 | if type(event) != type(""):
104 | continue
105 | self.update_instance(event, is_save=False)
106 | if i % save_interval == 0:
107 | self._save_cache()
108 | self._save_cache()
109 |
110 | if __name__ == "__main__":
111 | delphi_cache_handler = CacheHandler("delphi_scores")
112 | delphi_cache = delphi_cache_handler.cache
113 |
--------------------------------------------------------------------------------
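Concrete cache handlers only have to implement _generate_instance; CacheHandler takes care of loading the JSON cache, lookups, and periodic saves. A minimal sketch of a toy subclass, assuming the repo's utils module provides data_base_path, read_json, and ensure_dir as the base class expects, and that it is run from the repository root so the src import resolves:

from src.delphi_hybrid.components.CacheHandler.CacheHandler import CacheHandler

class ToyCacheHandler(CacheHandler):
    """Toy handler whose 'expensive' computation is just string inspection."""

    def __init__(self):
        super().__init__("toy", cache_dir="cache", filename="toy")

    def _generate_instance(self, event):
        return {"length": len(event), "reversed": event[::-1]}

if __name__ == "__main__":
    handler = ToyCacheHandler()
    print(handler.get_instance("killing a bear", is_save=True))  # computed, then cached
    print(handler.fetch_instance("killing a bear"))              # now served from the cache

--------------------------------------------------------------------------------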
/src/delphi_hybrid/components/CacheHandler/DelphiCacheHandler.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.DelphiScorer import *
7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
8 |
9 | class DelphiCacheHandler(CacheHandler):
10 | def __init__(self, filename=None, cache_dir="cache", model="t5-11b-1239200", device_id=0, server="local"): #"beaker_batch"
11 | if filename != None and "delphi" not in filename:
12 | print("ERROR: wrong cache file!")
13 | super().__init__("delphi_scores", filename=filename, cache_dir=cache_dir)
14 | self.delphi_generator = DelphiScorer(model=model, device_id=device_id, server=server)
15 |
16 | def _generate_instance(self, event):
17 | class_label, probs, text_label = self.delphi_generator.generate_with_score(event)
18 | return {"class_label": class_label,
19 | "prob_1": probs[0],
20 | "prob_0": probs[1],
21 | "prob_minus_1": probs[2],
22 | "text_label": text_label}
23 |
24 | if __name__ == "__main__":
25 | events = []
26 | for split in ["test", "validation"]:
27 | input_file = data_base_path + f"cache_norm_bank/events/clean_{split}.moral_acceptability.tsv"
28 | df_data = pd.read_csv(input_file, sep="\t")
29 | events += df_data["clean_event"].tolist()
30 |
31 | cache_handler = DelphiCacheHandler(filename="norm_bank_delphi", cache_dir="cache_norm_bank")
32 | for event in tqdm(events):
33 |     delphi_instance = cache_handler.save_instance(event)
34 |
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/CacheHandler/ParaphraseCacheHandler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.Paraphraser import *
7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
8 |
9 | class ParaphraseCacheHandler(CacheHandler):
10 | def __init__(self, filename=None, cache_dir="cache", num_paraphrases=8):
11 | super().__init__("paraphrases", cache_dir, filename)
12 | self.paraphraser = Paraphraser()
13 | self.num_paraphrases = num_paraphrases
14 |
15 | def _clean_paraphrases(self, event, paraphrases):
16 | paraphrases = list(set(paraphrases))
17 | if event in paraphrases:
18 | paraphrases.remove(event)
19 | 
20 |         # Build a new list instead of removing items from the list being
21 |         #   iterated over, which would silently skip elements.
22 |         qualified_paraphrases = [paraphrase for paraphrase in paraphrases
23 |                                  if self.paraphraser.qualify_paraphrase(event, paraphrase)]
24 |         return qualified_paraphrases
25 |
26 | def save_instance(self, event):
27 | instance = []
28 | if event in self.cache:
29 | instance = self.cache[event]
30 |
31 | if len(instance) < self.num_paraphrases:
32 | instance += self.paraphraser.generate_paraphrases(event, num_paraphrases=self.num_paraphrases)["paraphrases"]
33 | instance = list(set(instance))
34 | self._add_instance(event, instance, is_save=True)
35 | else:
36 | print("[Note] Enough paraphrases in cache!")
37 | print("Num paraphrases:", len(instance))
38 | return instance
39 |
40 | def update_instance(self, event, is_save=True):
41 | return None
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser(description="Generate answers with GPT-3.")
46 | parser.add_argument("--section_id", type=int, default="section_id")
47 | args = parser.parse_args()
48 |
49 | input_file = data_base_path + "cache_norm_bank/events/clean_test.moral_acceptability.tsv"
50 | df_data = pd.read_csv(input_file, sep="\t")
51 | events = df_data["clean_event"].tolist()
52 |
53 | section_id = 0
54 |
55 | cache_handler = ParaphraseCacheHandler(filename=f"paraphrases/norm_bank_paraphrases_{section_id}",
56 | cache_dir="cache_norm_bank",
57 | num_paraphrases=10)
58 | for e in tqdm(events):
59 | print(len(e))
60 | cache_handler.save_instance(e)
61 |
--------------------------------------------------------------------------------
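The paraphrase handler deduplicates candidate paraphrases, drops the original event, and keeps only candidates that pass the qualifier. A small self-contained sketch of that filtering step with a stub qualifier (the real check lives in Paraphraser.qualify_paraphrase):

def qualify_paraphrase(event, paraphrase):
    # Stand-in for Paraphraser.qualify_paraphrase: here, just require different wording.
    return paraphrase.lower() != event.lower()

def clean_paraphrases(event, paraphrases):
    unique = list(set(paraphrases))
    if event in unique:
        unique.remove(event)
    # Build a new list rather than removing from the list being iterated over.
    return [p for p in unique if qualify_paraphrase(event, p)]

print(clean_paraphrases("stealing a car",
                        ["stealing a car",
                         "Stealing a car",
                         "taking a car without permission",
                         "taking a car without permission"]))
# -> ["taking a car without permission"]

--------------------------------------------------------------------------------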
/src/delphi_hybrid/components/CompositionalityParser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 |
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.utils import *
7 |
8 |
9 | class CompositionalityParser():
10 | def __init__(self):
11 | self.to_exclude_list = ["able", "want", "refuse", "try", "due", "go",
12 | "supposed", "claim", "pretend",
13 | # https://grammar.collinsdictionary.com/us/easy-learning/which-verbs-are-followed-by-the-to-infinitive-in-english
14 | "agree", "arrange", "attempt", "choose", "decide", "fail", "hope", "learn", "manage", "offer", "plan", "seem", "come",
15 | "how", "when", "which", "what", "struggle", "remember"] # "lie",
16 |
17 | self.phrases_to_replace = {"in order to": "to", "so that": "so"}
18 |
19 |         self.relative_pronoun_list = [
20 | "that", "which", "who", "whom", "where"]
21 |
22 | self.for_exclude_list = ["pay"] # "lie",
23 |
24 | self.adj_exclude_list = ["nuclear"]
25 | self.conj_additional_list = [
26 | "otherwise", "to", "for", "so", "because", "by", "or"] # , "if", "when"
27 |
28 | def fix_event(self, event):
29 | for pr in self.phrases_to_replace:
30 | if pr in event:
31 | event = event.replace(pr, self.phrases_to_replace[pr])
32 | return event
33 |
34 | def organize_subevents(self, tokens, lemmatized_tokens, poss, deps, idxs, idxs_to_segment, original_tokens):
35 | all_segments = []
36 | last_idx = 0
37 | for current_idx in idxs_to_segment:
38 | if last_idx in idxs_to_segment:
39 | all_segments.append([tokens[last_idx],
40 | [tokens[last_idx]],
41 | [lemmatized_tokens[last_idx]],
42 | [poss[last_idx]],
43 | [deps[last_idx]],
44 | [idxs[last_idx]],
45 | [original_tokens[last_idx]],
46 | "conjunction"])
47 | last_idx += 1
48 | all_segments.append([" ".join(tokens[last_idx:current_idx]),
49 | tokens[last_idx:current_idx],
50 | lemmatized_tokens[last_idx:current_idx],
51 | poss[last_idx:current_idx],
52 | deps[last_idx:current_idx],
53 | idxs[last_idx:current_idx],
54 | original_tokens[last_idx:current_idx],
55 | "content"])
56 | last_idx = current_idx
57 |
58 | if last_idx in idxs_to_segment:
59 | all_segments.append([tokens[last_idx],
60 | [tokens[last_idx]],
61 | [lemmatized_tokens[last_idx]],
62 | [poss[last_idx]],
63 | [deps[last_idx]],
64 | [idxs[last_idx]],
65 | [original_tokens[last_idx]],
66 | "conjunction"])
67 | last_idx += 1
68 | all_segments.append([" ".join(tokens[last_idx:]),
69 | tokens[last_idx:],
70 | lemmatized_tokens[last_idx:],
71 | poss[last_idx:],
72 | deps[last_idx:],
73 | idxs[last_idx:],
74 | original_tokens[last_idx:],
75 | "content"])
76 | segments = [segment for segment in all_segments if segment[0] != ""]
77 | return segments
78 |
79 | def get_subevents(self, tokens, lemmatized_tokens, poss, deps):
80 | original_tokens = [None for _ in range(len(tokens))]
81 | conjs_idxs_global = [i for i, pos in enumerate(poss) if (
82 | (len(pos) > 3 and pos[-4:] == "CONJ") or lemmatized_tokens[i] in self.conj_additional_list)]
83 | conjs_idxs_global += [len(original_tokens) - 1]
84 |
85 | idxs = [i for i in range(len(tokens))]
86 | idxs_to_segment = []
87 | for i, token in enumerate(tokens):
88 | pos = poss[i]
89 |
90 | if (len(pos) > 3 and pos[-4:] == "CONJ") or token in self.conj_additional_list:
91 | if i != 0 and lemmatized_tokens[i - 1] not in self.to_exclude_list \
92 | and poss[i - 1] not in ["ADJ"]:
93 | # print(pos, token)
94 | conj_idx_local = conjs_idxs_global.index(i)
95 | # print(conj_idx_local)
96 |
97 | if conj_idx_local < (len(conjs_idxs_global) - 1):
98 | next_conj_idx_global = conjs_idxs_global[conj_idx_local + 1]
99 |
100 | # if there's no verb in the next sequence, then don't segment
101 | if "VERB" in poss[i: next_conj_idx_global + 1] \
102 | or "be" in lemmatized_tokens[i: next_conj_idx_global + 1]:
103 | # print(pos, token)
104 | idxs_to_segment.append(i)
105 |
106 | subevents = self.organize_subevents(tokens, lemmatized_tokens, poss, deps,
107 | idxs, idxs_to_segment, original_tokens)
108 | return subevents
109 |
110 | def _get_relative_clause(self, subevent):
111 | tokens = subevent[1]
112 | poss = subevent[3]
113 |
114 | for i, token in enumerate(tokens):
115 |             if token in self.relative_pronoun_list and i != 0 and poss[i - 1] in ["NOUN"] \
116 | or token in ["who", "whom"]:
117 | return [
118 | [" ".join(tokens[:i])] + [l[:i]
119 | for l in subevent[1:-1]] + ["content"],
120 | [" ".join(tokens[i:i+1])] + [l[i:i+1]
121 | for l in subevent[1:-1]] + ["relative pronoun"],
122 | [" ".join(tokens[i+1:])] + [l[i+1:]
123 | for l in subevent[1:-1]] + ["relative clause"]
124 | ]
125 |
126 | def _get_all_subevents(self, event):
127 | event = self.fix_event(event)
128 | parsed_event = parse_sequence(event, is_dependency_parse=True)
129 |
130 | tokens = parsed_event["tokens"]["tokens_list"]
131 | lemmatized_tokens = parsed_event["lemmatized_tokens"]["tokens_list"]
132 | poss = [_token[2] for _token in parsed_event["tokens"]["tokens_dict"]]
133 | deps = [_token[3] for _token in parsed_event["tokens"]["tokens_dict"]]
134 |
135 | return self.get_subevents(tokens, lemmatized_tokens, poss, deps)
136 |
137 | def glue_subevents(self, subevents):
138 | glued_subevent = subevents[0]
139 |
140 | for e in subevents[1:]:
141 | for i, c in enumerate(e[1:-1]):
142 | glued_subevent[i + 1] += c
143 |
144 | glued_subevent[0] = " ".join(glued_subevent[1])
145 | glued_subevent[-1] = "content"
146 |
147 | return glued_subevent
148 |
149 | def _get_parsed_event(self, event, is_simple=True):
150 | subevents = self._get_all_subevents(event)
151 | parsed_event = {}
152 |
153 | relative_clause = self._get_relative_clause(subevents[0])
154 | if relative_clause != None:
155 | parsed_event["main_action"] = {"main_clause": relative_clause[0],
156 | "relative_pronoun": relative_clause[1], "relative_clause": relative_clause[2]}
157 | else:
158 | parsed_event["main_action"] = {"main_clause": subevents[0]}
159 |
160 | if len(subevents) > 1:
161 | parsed_event["connector"] = subevents[1]
162 | parsed_event["other_action"] = self.glue_subevents(subevents[2:])
163 |
164 | if not is_simple:
165 | return parsed_event
166 | else:
167 | main_action = parsed_event["main_action"]["main_clause"][0]
168 | relative_pronoun = None
169 | relative_clause = None
170 | connector = None
171 | other_action = None
172 | if "relative_pronoun" in parsed_event["main_action"]:
173 | relative_pronoun = parsed_event["main_action"]["relative_pronoun"][0]
174 | relative_clause = parsed_event["main_action"]["relative_clause"][0]
175 |
176 | if "connector" in parsed_event:
177 | connector = parsed_event["connector"][0]
178 | other_action = parsed_event["other_action"][0]
179 |
180 | return {"main_action": main_action,
181 | "relative_pronoun": relative_pronoun,
182 | "relative_clause": relative_clause,
183 | "connector": connector,
184 | "other_action": other_action}
185 |
186 | def get_parsed_event(self, event):
187 | parsed_events = self._get_parsed_event(event, is_simple=True)
188 |
189 | main_event = parsed_events["main_action"]
190 | main_event_main_clause = parsed_events["main_action"]
191 | main_event_relative_pronoun = parsed_events["relative_pronoun"]
192 | main_event_relative_clause = parsed_events["relative_clause"]
193 | adjunct_event = parsed_events["other_action"]
194 |
195 | if main_event_main_clause != None and main_event_relative_clause != None:
196 | main_event = main_event_main_clause + " " + \
197 | main_event_relative_pronoun + " " + main_event_relative_clause
198 |
199 | return {"main_event": main_event,
200 | "main_event_main_clause": main_event_main_clause,
201 | "main_event_relative_clause": main_event_relative_clause,
202 | "adjunct_event": adjunct_event, }
203 |
204 |
205 | if __name__ == "__main__":
206 | compositionality_parser = CompositionalityParser()
207 |
208 | events = []
209 | for split in ["test", "validation"]:
210 | input_file = data_base_path + \
211 | f"cache_norm_bank/events/clean_{split}.moral_acceptability.tsv"
212 | df_data = pd.read_csv(input_file, sep="\t")
213 | events += df_data["clean_event"].tolist()
214 |
215 | data_to_save = {}
216 | for event in tqdm(events):
217 | parsed_event = compositionality_parser.get_parsed_event(event)
218 | data_to_save[event] = parsed_event
219 |
220 | save_json(data_base_path + f"cache_norm_bank/constituents.json", data_to_save)
221 |
--------------------------------------------------------------------------------
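CompositionalityParser segments an event into content spans around conjunction-like tokens ("to", "for", "because", coordinating conjunctions), and only segments when the span after the conjunction contains a verb. A small sketch on hand-built token/lemma/POS/dependency lists (hypothetical annotations that would normally come from parse_sequence), assuming it is run from the repository root so the src import resolves:

from src.delphi_hybrid.components.CompositionalityParser import CompositionalityParser

parser = CompositionalityParser()

tokens = ["killing", "a", "bear", "to", "save", "a", "child"]
lemmas = ["kill", "a", "bear", "to", "save", "a", "child"]
poss   = ["VERB", "DET", "NOUN", "PART", "VERB", "DET", "NOUN"]
deps   = ["ROOT", "det", "dobj", "aux", "advcl", "det", "dobj"]

for segment in parser.get_subevents(tokens, lemmas, poss, deps):
    print(segment[0], "->", segment[-1])
# Expected, roughly:
#   killing a bear -> content
#   to -> conjunction
#   save a child -> content

--------------------------------------------------------------------------------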
/src/delphi_hybrid/components/DelphiScorer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 |
4 | import torch
5 | from scipy.special import softmax
6 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
7 |
8 |
9 | class DelphiScorer:
10 | def __init__(self, device_id=1, model="t5-11b-1239200", server="beaker_batch"):
11 | CUDA_DEVICE = f"cuda:{device_id}" if torch.cuda.is_available() else 'cpu'
12 | self.device = torch.device(CUDA_DEVICE)
13 | print(f"DelphiScorer device: {self.device}", model)
14 |
15 | if model == "t5-large":
16 | MODEL_BASE = "t5-large"
17 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/large_commonsense_morality_hf"
18 | self.class_token_pos = 4
19 | self.sep_tokens = [" /class> text>", " class>", " /text>"]
20 |
21 | elif model == "t5-11b":
22 | MODEL_BASE = "t5-11b"
23 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11b_commonsense_morality_hf"
24 | self.class_token_pos = 4
25 | self.sep_tokens = [" /class> text>", " class>", " /text>"]
26 |
27 | elif model == "t5-11b-1239200":
28 | MODEL_BASE = "t5-11b"
29 | if server == "beaker_batch":
30 | MODEL_LOCATION = "/model/delphi11b"
31 | else:
32 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11b_commonsense_morality_1239200_hf"
33 | self.class_token_pos = 4
34 | self.sep_tokens = [" /class> text>", " class>", " /text>"]
35 |
36 | elif model == "v11_distribution":
37 | MODEL_BASE = "t5-11b"
38 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/v11_distribution_hf"
39 | self.class_token_pos = 3
40 | self.sep_tokens = ["[/class] [text]", "[class]", "[/text]"]
41 |
42 | elif model == "11B_001":
43 | MODEL_BASE = "t5-11b"
44 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11B_001_hf"
45 | self.class_token_pos = 4
46 | self.sep_tokens = [" /class> text>", " class>", " /text>"]
47 |
48 | else:
49 | print("ERROR: model doesn't exist")
50 | return
51 |
52 | self.model = T5ForConditionalGeneration.from_pretrained(MODEL_LOCATION)
53 | self.model.to(self.device)
54 | self.tokenizer = T5Tokenizer.from_pretrained(MODEL_BASE, model_max_length=512)
55 |
56 | def score(self, input_string, normalize=None):
57 | input_string = f"[moral_single]: {input_string}"
58 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
59 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True)
60 |
61 | probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][self.class_token_pos][0].softmax(0))]
62 |
63 | class1_prob = sum([v[1].item() for v in probs if v[0] == "1"])
64 | class0_prob = sum([v[1].item() for v in probs if v[0] == "0"])
65 | classminus1_prob = sum([v[1].item() for v in probs if v[0] == "-1"])
66 |
67 | probs = [class1_prob, class0_prob, classminus1_prob]
68 | probs_sum = sum(probs)
69 |
70 | if normalize == "regular":
71 | probs = [p / probs_sum for p in probs]
72 | elif normalize == "softmax":
73 | probs = softmax(probs)
74 |
75 | return probs
76 |
77 | def generate(self, input_string):
78 | input_string = f"[moral_single]: {input_string}"
79 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
80 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True)
81 |
82 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0])
83 | class_label = int(decoded_sequence.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1])
84 | text_label = decoded_sequence.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0]
85 |
86 | return class_label, text_label
87 |
88 |
89 | def generate_beam(self,
90 | input_string,
91 | num_beams=5,
92 | max_length=50,
93 | num_return_sequences=5,):
94 | input_string = f"[moral_single]: {input_string}"
95 | input_ids = self.tokenizer(input_string,
96 | max_length=16,
97 | truncation=True,
98 | return_tensors='pt').to(self.device).input_ids
99 | outputs = self.model.generate(input_ids,
100 | num_beams=num_beams,
101 | max_length=max_length,
102 | num_return_sequences=num_return_sequences,)
103 |
104 | decoded_sequences = self.tokenizer.batch_decode(outputs)
105 |
106 | class_labels = [ds.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1] for ds in decoded_sequences]
107 | text_labels = [ds.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0] for ds in decoded_sequences]
108 |
109 | return class_labels, text_labels
110 |
111 |
112 | def generate_with_score(self, input_string):
113 | input_string = f"[moral_single]: {input_string}"
114 | input_ids = self.tokenizer(input_string, max_length=512, return_tensors='pt').to(self.device).input_ids
115 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True)
116 |
117 | probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][self.class_token_pos][0].softmax(0))]
118 |
119 | class1_prob = sum([v[1].item() for v in probs if v[0] == "1"])
120 | class0_prob = sum([v[1].item() for v in probs if v[0] == "0"])
121 | classminus1_prob = sum([v[1].item() for v in probs if v[0] == "-1"])
122 |
123 | probs = [class1_prob, class0_prob, classminus1_prob]
124 | # probs_sum = sum(probs)
125 |
126 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0])
127 | class_label = int(decoded_sequence.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1])
128 | text_label = decoded_sequence.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0]
129 |
130 | return class_label, probs, text_label
131 |
132 |
133 | def generate_with_score_comparison(self, action1, action2):
134 | input_string = f"[moral_pair]: {action1} {action2}"
135 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids
136 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True)
137 |
138 | probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][0][0].softmax(0))]
139 |
140 | class1_prob = sum([v[1].item() for v in probs if v[0] == "1"])
141 | class2_prob = sum([v[1].item() for v in probs if v[0] == "2"])
142 |
143 | probs = [class1_prob, class2_prob]
144 | # probs_sum = sum(probs)
145 |
146 | class_label = int(self.tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))
147 | # print(class1_prob, class2_prob, class_label)
148 |
149 | return class_label, probs
150 |
151 |
152 | if __name__ == '__main__':
153 | # delphi_scorer = DelphiScorer(device_id=0, model="t5-large")
154 | delphi_scorer = DelphiScorer(device_id=0, model="t5-11b-1239200")
155 | # print(delphi_scorer.generate_with_score("killing a cat"))
156 | # e_list = ["killing a cat",
157 | # "killing a dog",
158 | # "killing a child"]
159 | #
160 | # for e in e_list:
161 | # print(delphi_scorer.generate_with_score(e))
162 |
163 | print(delphi_scorer.generate_with_score_comparison("killing a cat", "killing an evil cat"))
164 |
165 |
166 | # for i in range(5):
167 | # print(delphi_scorer.generate_beam("Gay parents adopting a child"))\
168 |
169 | # for seq in ["Gay parents adopting a child", "quentin tarantino", "Can I kill a terrorist?", "Protecting China"]:
170 | # seq =
171 | # print(delphi_scorer.generate(seq))
172 |
173 |
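A minimal usage sketch (illustrative only, not part of the file above): score() returns the class-token probabilities as [p("1"), p("0"), p("-1")]; assuming the usual Delphi convention that 1/0/-1 mean good/discretionary/bad, the triple can be read like this:

def verdict_from_probs(probs):
    # probs = [p_good, p_discretionary, p_bad], as returned by DelphiScorer.score()
    labels = ["good (1)", "discretionary (0)", "bad (-1)"]
    best = max(range(len(probs)), key=lambda i: probs[i])
    return labels[best], probs[best]

print(verdict_from_probs([0.08, 0.12, 0.80]))  # -> ('bad (-1)', 0.8)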
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/GPT3Scorer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import openai
4 | 
5 | from src.delphi_hybrid.components.LMScorer import *
6 |
7 |
8 | class GPT3Scorer(LMScorer):
9 | def __init__(self, model_name="text-davinci-003"): # "text-davinci-002"
10 | super().__init__(model_name)
11 |
12 | self.MODEL_NAME = model_name
13 | openai.api_key = os.environ["OPENAI_API_KEY"]
14 | self.CONDITIONED_GEN_TOKEN = "<|endoftext|>"
15 |
16 | def correct_grammatical_error(self, input_sequence):
17 | response = openai.Completion.create(
18 | model=self.MODEL_NAME,
19 | prompt="Correct this to standard English:\n\n" + input_sequence,
20 | temperature=0,
21 | max_tokens=60,
22 | top_p=1.0,
23 | frequency_penalty=0.0,
24 | presence_penalty=0.0
25 | )
26 | corrected_sequence = response["choices"][0]["text"].split("\n\n")[-1]
27 | return corrected_sequence
28 |
29 | def get_input_perplexity(self, input_sequence):
30 | """
31 | Helper function to get the perplexity of a sequence from GPT3
32 | """
33 | # formatted_input_sequence = CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:]
34 | # formatted_input_sequence = CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:] + "."
35 | # formatted_input_sequence = CONDITIONED_GEN_TOKEN + "'" + input_sequence[0].upper() + input_sequence[1:] + ".'"
36 | formatted_input_sequence = self.CONDITIONED_GEN_TOKEN + input_sequence
37 | return self._get_input_perplexity(formatted_input_sequence)
38 |
39 | def _get_input_perplexity(self, formatted_input_sequence):
40 | """
41 | Helper function to get the perplexity of a sequence from GPT3
42 | """
43 | response = openai.Completion.create(
44 | model=self.MODEL_NAME,
45 | prompt=formatted_input_sequence,
46 | # prompt=formatted_input_sequence
47 | max_tokens=0,
48 | logprobs=1,
49 | echo=True
50 | )
51 | # print(formatted_input_sequence)
52 |
53 | tokens_logprobs = response["choices"][0]["logprobs"]["token_logprobs"][1:]
54 | num_tokens = len(tokens_logprobs)
55 | sequence_cross_entropy = -sum(tokens_logprobs) / num_tokens
56 | perplexity = math.exp(sequence_cross_entropy)
57 | return perplexity
58 |
59 | def get_input_perplexity_combo(self, input_sequence, return_all_ppl=False):
60 | """
61 | Get a combined perplexity averaged across multiple input sequence formats.
62 | """
63 | formatted_input_sequences = [
64 | self.CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:],
65 | self.CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:] + ".",
66 | self.CONDITIONED_GEN_TOKEN + '"' + input_sequence[0].upper() + input_sequence[1:] + '"',
67 | ]
68 |
69 | return self._get_input_perplexity_combo(formatted_input_sequences, return_all_ppl)
70 |
71 | def get_input_logprob_gpt3(self, input_sequence):
72 | """
73 |         Helper function to get the total log probability of a sequence from GPT3
74 | """
75 | formatted_input_sequence = self.CONDITIONED_GEN_TOKEN + '"' + input_sequence[0].upper() + input_sequence[1:] + '"'
76 | response = openai.Completion.create(
77 | model=self.MODEL_NAME,
78 | prompt=formatted_input_sequence,
79 | max_tokens=0,
80 | logprobs=1,
81 | echo=True
82 | )
83 |
84 | tokens_logprobs = response["choices"][0]["logprobs"]["token_logprobs"][1:]
85 | return sum(tokens_logprobs)
86 |
87 | if __name__ == "__main__":
88 | gpt3_scorer = GPT3Scorer()
89 |
90 | # premise = "killing a bear to save your child."
91 | # print(gpt3_scorer.get_input_perplexity_combo(premise))
92 |
93 | # premise = "killing a stupid bear to save your child"
94 | # print(gpt3_scorer.get_paraphrases(premise))
95 |
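For reference, a self-contained sketch (assumed, not the repository's code) of the perplexity arithmetic in _get_input_perplexity: the method drops the first token's logprob (reported as null by the Completions API), averages the negative logprobs, and exponentiates:

import math

def perplexity_from_logprobs(token_logprobs):
    # token_logprobs: per-token log probabilities, first (null) entry already removed
    cross_entropy = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(cross_entropy)

print(perplexity_from_logprobs([-0.5, -1.0, -1.5]))  # exp(1.0) ≈ 2.718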
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/LMScorer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast
4 |
5 |
6 | class LMScorer:
7 | def __init__(self, model_name):
8 | print("Initialize model:",model_name)
9 |
10 | def get_input_perplexity(self, input_sequence):
11 | return
12 |
13 | def _get_input_perplexity(self, formatted_input_sequence):
14 | return
15 |
16 | def _get_input_perplexity_combo(self, formatted_input_sequences, return_all_ppl=False):
17 | all_ppl = []
18 | for formatted_input_sequence in formatted_input_sequences:
19 | ppl = self._get_input_perplexity(formatted_input_sequence)
20 | all_ppl.append(ppl)
21 |
22 | if return_all_ppl:
23 | return formatted_input_sequences, (all_ppl + [(sum(all_ppl) / len(all_ppl))])
24 | else:
25 | return (sum(all_ppl) / len(all_ppl))
26 |
27 | def get_input_perplexity_combo(self, input_sequence, return_all_ppl=False):
28 | return
29 |
30 |
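A toy subclass sketch (not in the repository) showing the intended contract: a concrete scorer only needs to implement _get_input_perplexity, and _get_input_perplexity_combo averages it over the formatted variants:

class LengthScorer(LMScorer):
    # Stand-in scorer that treats sequence length as its "perplexity".
    def _get_input_perplexity(self, formatted_input_sequence):
        return float(len(formatted_input_sequence))

scorer = LengthScorer("toy")
print(scorer._get_input_perplexity_combo(["abc", "abcde"]))  # (3 + 5) / 2 = 4.0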
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/MoralSaliencyKeywordCounter.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.DelphiScorer import *
7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
8 |
9 | class MoralSaliencyKeywordCounter():
10 | def __init__(self, filename=data_base_path + f"cache/keywords.json"):
11 | self.keywords_cache = read_json(filename)
12 |
13 | def get_events_keyword_counts(self, e, keyword):
14 | e_keywords = self.keywords_cache[e]
15 |
16 | all_keywords_selected = all_keywords_categorize[keyword]
17 | keywords_counts = {k: 0 for k in all_keywords_selected}
18 |
19 | for k in all_keywords_selected:
20 | if k in e_keywords:
21 | keywords_counts[k] = e_keywords[k]
22 | return keywords_counts
23 |
24 | def get_all_top_level_keyword_count(self, keywords_map):
25 | return sum([keywords_map[k] for k in keywords_map])
26 |
27 | def get_event_top_level_keyword_count(self, keywords_counts):
28 | return sum([keywords_counts[k] for k in keywords_counts])
29 |
30 | def get_top_level_keywords_counts_map(self, top_level_keyword):
31 | top_level_keywords_counts_map = {}
32 | for e in self.keywords_cache:
33 | keywords_counts = self.get_events_keyword_counts(e, top_level_keyword)
 34 |             top_level_keywords_counts_map[e] = self.get_all_top_level_keyword_count(keywords_counts)
 35 |         return top_level_keywords_counts_map
36 |
37 | if __name__ == "__main__":
38 | keyword_counter = MoralSaliencyKeywordCounter()
39 | # instance = cache_handler.get_instance("Killing a bear")
40 |
41 |
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/Paraphraser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 | import openai
5 | import pandas as pd
6 | import re
7 | import random
8 | from src.delphi_hybrid.components.utils import *
9 | from src.delphi_hybrid.components.WANLIScorer import WANLIScorer
10 |
11 | # S: Putting my mother in a retirement home.
12 | # A: Placing my mother in a retirement home.
13 | # A: Killing someone because you are being kind.
14 |
15 | PROMPT_SOURCE = """
16 | Paraphrase the given sentence.
17 | S: Torturing a cat if other people cannot see it.
18 | A: {}
19 | S: Not caring for Christmas.
20 | A: Being indifferent towards Christmas.
21 | S: Killing someone with kindness.
22 | A: Killing someone for showing your kindness.
23 | S: {}
24 | A:
25 | """
26 |
27 | PARAPHRASE_OPTIONS = [
28 | "Torturing a cat assuming that other people cannot see it.",
29 | "Torturing a cat on the assumption that other people cannot see it.",
30 | "Torturing a cat in case other people cannot see it.",
31 | "Torturing a cat, given that other people cannot see it.",
32 | "Torturing a cat if nobody can see it."
33 | ]
34 |
35 |
36 | class Paraphraser():
37 | def __init__(self, model_name="text-davinci-003"): # "text-davinci-003", "text-curie-001"
38 | self.MODEL_NAME = model_name
39 | openai.api_key = os.environ["OPENAI_API_KEY"]
40 | self.CONDITIONED_GEN_TOKEN = "<|endoftext|>"
41 | self.nli_scorer = WANLIScorer()
42 |
43 | def fix_sentence(self, input_str):
44 | if input_str == "": # Don't change empty strings.
45 | return input_str
46 | if input_str[-1] in ["?", ".", "!"]: # Don't change if already okay.
47 | return input_str
48 | if input_str[-1] == ",": # Change trailing ',' to '.'.
49 | return input_str[:-1] + "."
50 | return input_str + "."
51 |
52 | def _qualify_paraphrase(self, action_doc, paraphrased_action_doc):
53 | action = action_doc.text.lower()
54 | paraphrased_action = paraphrased_action_doc.text.lower()
55 |
56 | num_tokens_action = len(action_doc)
57 | num_tokens_paraphrased_action = len(paraphrased_action_doc)
58 |
59 | action_lemma = " ".join([token.lemma_ for token in action_doc]).lower()
60 | paraphrased_action_lemma = " ".join([token.lemma_ for token in paraphrased_action_doc]).lower()
61 |
62 | if action_lemma in paraphrased_action_lemma and action_lemma != paraphrased_action_lemma:
63 | return False
64 |
65 | if any(paraphrased_action.startswith(prefix) for prefix in ["it's ", "it is "]) \
66 | and not any(action.startswith(prefix) for prefix in ["it's ", "it is "]):
67 | return False
68 |
69 | if ":" not in action and ":" in paraphrased_action:
70 | return False
71 |
72 | if not re.search('[a-zA-Z]', paraphrased_action):
73 | return False
74 |
75 | if paraphrased_action.lower() in [action.lower(), "n/a"]:
76 | return False
77 |
78 | return abs(num_tokens_action - num_tokens_paraphrased_action) / num_tokens_action < 1
79 |
80 | def qualify_paraphrase(self, action, paraphrased_action):
81 | action_doc = nlp(action)
82 | paraphrased_action_doc = nlp(paraphrased_action)
83 | return self._qualify_paraphrase(action_doc, paraphrased_action_doc)
84 |
85 | def generate_paraphrases(self, action, num_paraphrases=1, max_iteration_count=20):
86 | fixed_action = self.fix_sentence(action.capitalize())
87 | action_doc = nlp(action)
88 |
89 | paraphrases = []
90 | iteration_count = 0
91 | while len(paraphrases) < num_paraphrases and iteration_count < max_iteration_count:
92 | _po = random.sample(PARAPHRASE_OPTIONS, 1)[0]
93 | _p = PROMPT_SOURCE.format(_po, fixed_action)
94 |
95 | # try:
96 | # response = gpt3_completion(_p, model_name="text-davinci-003", max_tokens=100, temperature=0.9,
97 | # logprobs=0, echo=False, num_outputs=10, top_p=1.0, best_of=1)
98 | # iteration_count += 1
99 | # except Exception as e:
100 | # continue
101 |
102 | response = gen_gpt3(_p, n=5, model=self.MODEL_NAME, temperature=0.8, max_tokens=400)
103 | iteration_count += 1
104 | paraphrased_action = [c['text'].strip() for c in response['choices']]
105 |
106 | for paraphrase in paraphrased_action:
107 | paraphrase_doc = nlp(paraphrase)
108 | if self._qualify_paraphrase(action_doc, paraphrase_doc):
109 | paraphrases.append(paraphrase)
110 | paraphrases = list(set(paraphrases))
111 | print(len(paraphrases))
112 |
113 | return {
114 | "action": action,
115 | "fixed_action": fixed_action,
116 | "paraphrases": paraphrases
117 | }
118 |
119 | if __name__ == "__main__":
120 | # paraphraser = Paraphraser()
121 |
122 | # action = "genocide if it creates job"
123 | # action = "Having sex with your best friend"
124 | # action = "Putting my mother in a retirement home"
125 | # action = "Killing a baby because you really want to"
126 | # action = "Stabbing someone who opened your mail without permission"
127 | # action = "Killing someone with kindness"
128 | # action = "Robbing a bank because it is my job"
129 | # action = "Let people play games and allow them to win lots of money"
130 | # action = "My husband doesn't know that our daughter is not his"
131 | # action = "Making money instead of saving someone"
132 | # print(paraphraser.generate_paraphrases(action, num_paraphrases=6))
133 |
134 | # premise = "killing a bear to save your child."
135 | # hypothesis = "killing a bear."
136 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
137 | # print(premise, "|", hypothesis, ":", prediction)
138 | openai.api_key = os.environ["OPENAI_API_KEY"]
139 | print(gen_gpt3("_p", n=20, model="text-davinci-003", temperature=0.7, max_tokens=400))
140 |
141 |
142 |
143 |
144 |
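For clarity, a sketch of how the few-shot prompt is assembled inside generate_paraphrases: the first slot of PROMPT_SOURCE is filled with a randomly sampled exemplar completion from PARAPHRASE_OPTIONS, and the second with the capitalized, punctuation-fixed query action (the action string below is only an example):

example_prompt = PROMPT_SOURCE.format(
    PARAPHRASE_OPTIONS[0],                   # exemplar answer for the first S/A pair
    "Robbing a bank because it is my job.",  # fixed_action to paraphrase
)
print(example_prompt)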
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/WANLIScorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import RobertaTokenizer, RobertaForSequenceClassification
3 |
4 |
5 | class WANLIScorer:
6 |
7 | def __init__(self):
8 | self.model = RobertaForSequenceClassification.from_pretrained('alisawuffles/roberta-large-wanli')
9 | self.tokenizer = RobertaTokenizer.from_pretrained('alisawuffles/roberta-large-wanli')
10 |
11 | def get_scores(self, premise, hypothesis, is_return_score=True):
12 |         x = self.tokenizer(premise, hypothesis, return_tensors='pt', max_length=128, truncation=True)
13 | logits = self.model(**x).logits
14 | probs = logits.softmax(dim=1).squeeze(0)
15 | label_id = torch.argmax(probs).item()
16 | prediction = self.model.config.id2label[label_id]
17 |
18 | if is_return_score:
19 | scores = {"contradiction": probs[0].item(),
20 | "entailment": probs[1].item(),
21 | "neutral": probs[2].item()}
22 | return scores
23 | return prediction
24 |
25 |
26 | if __name__ == "__main__":
27 | wanli_scorer = WANLIScorer()
28 | # premise = "killing a bear to save your child."
29 | # hypothesis = "killing a bear."
30 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
31 | # print(premise, "|", hypothesis, ":", prediction)
32 | #
33 | # hypothesis = "killing a child."
34 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
35 | # print(premise, "|", hypothesis, ":", prediction)
36 | #
37 | # hypothesis = "save your child."
38 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
39 | # print(premise, "|", hypothesis, ":", prediction)
40 | #
41 | # hypothesis = "bear your child."
42 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
43 | # print(premise, "|", hypothesis, ":", prediction)
44 | #
45 | # hypothesis = "bear."
46 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
47 | # print(premise, "|", hypothesis, ":", prediction)
48 | #
49 | # hypothesis = "save."
50 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
51 | # print(premise, "|", hypothesis, ":", prediction)
52 | #
53 | # hypothesis = "killing."
54 | # prediction = wanli_scorer.get_scores(premise, hypothesis)
55 | # print(premise, "|", hypothesis, ":", prediction)
56 |
57 |
58 | premise = "Killing a bear if I attack it."
59 | hypothesis = "Killing a bear that attacks you"
60 | prediction = wanli_scorer.get_scores(premise, hypothesis)
61 | print(premise, "|", hypothesis, ":", prediction)
62 |
63 | hypothesis = "Killing a bear if I attack it."
64 | premise = "Killing a bear that attacks you"
65 | prediction = wanli_scorer.get_scores(premise, hypothesis)
66 | print(premise, "|", hypothesis, ":", prediction)
67 |
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/constants.py:
--------------------------------------------------------------------------------
1 | base_path = "/net/nfs.cirrascale/mosaic/liweij/delphi_algo/"
2 | data_base_path = base_path + "data/"
3 | results_base_path = base_path + "results/"
4 |
--------------------------------------------------------------------------------
/src/delphi_hybrid/prepare_data/compile_demo_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import pandas as pd
5 |
6 | # sys.path.append(os.getcwd())
7 |
8 | sys.path.append("/Users/liweijiang/Desktop/delphi_algo/scripts/utils")
9 | # print(os.getcwd())
10 |
11 | from main_utils import *
12 |
13 | random.seed(10)
14 |
15 |
16 | def get_event_len(e):
17 | return len(e)
18 |
19 | def filter_by_keywords_text_label(e):
20 | kw_to_exclude = ["yes,", "no,"]
21 | return not any(kw in e.lower() for kw in kw_to_exclude)
22 |
23 | def filter_by_keywords_event(e):
24 | kw_to_exclude = ["?", "\"", "yeri", "izone", "aaaaaaaaaaaaaaaaaaaaaaaaaaaa",
25 | "delphi", "is a ", "hhhhhooooooo", "poiuytpoiuytrpoiuytrezapoiuytrezaqwxcvbnracist",
26 | "rhum", "..", "ben"]
27 | return not any(kw in e.lower() for kw in kw_to_exclude)
28 |
29 | def filter_by_startswith_keywords_event(e):
30 | kw_to_exclude = ["can", "is ", "should", "how", "what", "when",
31 | "are ", "why"]
32 | return not any(e.lower().startswith(kw) for kw in kw_to_exclude)
33 |
34 | def is_in_english(e):
35 | try:
36 | e.encode(encoding='utf-8').decode('ascii')
37 | except UnicodeDecodeError:
38 | return False
39 | else:
40 | return True
41 |
42 | def get_upper_letter_count(e):
43 | return sum(1 for c in (e[0].lower() + e[1:]) if c.isupper())
44 |
45 | def main():
46 | df_data_done = pd.read_csv("/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",")
47 | event_done = df_data_done["event"].tolist()
48 |
49 | df_data = pd.read_csv("/data/demo/demo_examples_102721_delphi.csv", sep="\t")
50 | df_data["action1_len"] = df_data["action1"].apply(get_event_len)
51 | df_data = df_data.dropna(subset=["action1", "text_label"])
52 |
53 | df_data = df_data[~df_data["action1"].isin(event_done)]
54 | df_data = df_data[df_data["action1_len"] > 25]
55 | df_data = df_data[df_data["action1_len"] < 200]
56 | df_data = df_data[df_data["text_label"].apply(filter_by_keywords_text_label)]
57 | df_data = df_data[df_data["action1"].apply(filter_by_keywords_event)]
58 | df_data = df_data[df_data["action1"].apply(filter_by_startswith_keywords_event)]
59 | df_data = df_data[df_data["action1"].apply(is_in_english)]
60 |
61 | df_data["clean_event"] = df_data["action1"].apply(normalize_event)
62 |
63 | # print(df_data) any(kw in e.lower() for kw in remove_keywords_list)
64 |
65 | events_to_annotate = df_data["clean_event"].value_counts()[:10000].keys().tolist()
66 | # print(events_to_annotate)
67 |
68 |
69 | df_data["question1"] = df_data["clean_event"]
70 |
71 | df_data_selected = df_data[df_data["clean_event"].isin(events_to_annotate)]
72 | df_data_selected = df_data_selected[["action1", "class_label", "text_label", "clean_event", "question1"]]
73 |
74 | df_data_selected = df_data_selected.drop_duplicates(subset=["clean_event"])
75 | df_data_selected.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", index=False)
76 |
77 | # df_data = pd.read_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", sep=",")
78 | # df_data = df_data.rename(columns={"action1": "raw_event"})
79 | # df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", index=False)
80 |
81 | def compile_annotated_data():
82 | df_data = pd.read_csv("/data/demo/mturk/result/demo_events_to_annotate_v4.csv",
83 | sep=",")
84 |
85 | df_data_done = pd.read_csv(
86 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",")
87 | event_done = df_data_done["event"].tolist()
88 |
89 | headers = ["Input.action1", "Input.class_label",
90 | "Input.text_label", "Input.clean_event",
91 | "Answer.controversial_1", "Answer.lewd_1",
92 | "Answer.feedback", "Answer.is_action_1",
93 | "Answer.make_sense_1", "Answer.privacy_1"]
94 |
95 | rename_map = {"Input.action1": "event",
96 | "Input.class_label": "class_label",
97 | "Input.text_label": "text_label",
98 | "Input.clean_event": "clean_event",
99 | "Answer.controversial_1": "is_controversial",
100 | "Answer.lewd_1": "is_lewd",
101 | "Answer.feedback": "feedback",
102 | "Answer.is_action_1": "is_action",
103 | "Answer.make_sense_1": "is_make_sense",
104 | "Answer.privacy_1": "is_privacy"}
105 |
106 | df_data = df_data[headers]
107 | df_data = df_data.rename(columns=rename_map)
108 | df_data = df_data[~df_data["event"].isin(event_done)]
109 | df_data = df_data[df_data["feedback"] == "{}"]
110 | df_data = df_data[df_data["is_make_sense"] == 1.0]
111 | df_data = df_data[df_data["is_lewd"] == -1.0]
112 | df_data = df_data[df_data["is_action"] == 1.0]
113 | df_data = df_data[df_data["is_privacy"] == -1.0]
114 | df_data = df_data[df_data["event"].apply(filter_by_keywords_event)]
115 |
116 | df_data["event_len"] = df_data["event"].apply(get_event_len)
117 | df_data = df_data[df_data["event_len"] < 120]
118 | df_data = df_data[df_data["text_label"].apply(filter_by_keywords_text_label)]
119 |
120 | df_data["upper_letter_count"] = df_data["event"].apply(get_upper_letter_count)
121 | df_data = df_data[df_data["upper_letter_count"] < 2]
122 |
123 | df_data = df_data.sort_values(by=["event_len"])
124 | print(df_data["event_len"].mean())
125 |
126 | df_data_non_controversial = df_data[df_data["is_controversial"] == -1.0]
127 | df_data_non_controversial = df_data_non_controversial.sample(n=2500, random_state=0)
128 | df_data_controversial = df_data[df_data["is_controversial"] == 1.0]
129 | df_data_controversial = df_data_controversial.sample(n=4000 - df_data_non_controversial.shape[0], random_state=0)
130 |
131 | df_data_selected = pd.concat([df_data_non_controversial, df_data_controversial], ignore_index=True, sort=False)
132 | df_data_selected.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", index=False)
133 | input_events = df_data_selected["event"].tolist()
134 |
135 | df_data_input = pd.DataFrame()
136 | for i in range(4):
137 | df_data_input[f"action{i+1}"] = input_events[i * 1000: (i+1) * 1000]
138 | df_data_input.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4_selected_input.csv",
139 | index=False)
140 |
141 | # print(len(event_done), len(set(event_done)))
142 | # print(len(input_events), len(set(input_events)))
143 | # print(len(set(event_done + input_events)))
144 |
145 |
146 | def compile_v234_data():
147 | df_data_v4 = pd.read_csv(
148 | "/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", sep=",")
149 | event_v4 = df_data_v4["event"].tolist()
150 |
151 | df_data_v2_v3 = pd.read_csv(
152 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",")
153 | event_v2_v3 = df_data_v2_v3["event"].tolist()
154 | # print(len(event_v4))
155 | # print(len(event_v2_v3))
156 |
157 | all_events = event_v2_v3 + event_v4
158 | random.shuffle(all_events)
159 |
160 | # train_event = all_events[0: int(0.4 * len(all_events))]
161 | # dev_event = all_events[int(0.4 * len(all_events)): int(0.7 * len(all_events))]
162 | # test_event = all_events[int(0.7 * len(all_events)):]
163 |
164 | train_split_labels = ["train" for _ in range(int(0.5 * len(all_events)) + 2)] \
165 | + ["dev" for _ in range(int(0.25 * len(all_events)))] \
166 | + ["test" for _ in range(int(0.25 * len(all_events)))]
167 |
168 | df_data = pd.DataFrame()
169 | df_data["event"] = all_events
170 | df_data["split"] = train_split_labels
171 |
172 | df_data["source"] = "v23"
173 | df_data.loc[df_data["event"].isin(event_v4), "source"] = "v4"
174 |
175 | # df_data["source"] = df_data["source"]
176 | df_data["clean_event"] = df_data["event"].apply(normalize_event)
177 |
178 | # print(df_data["source"].value_counts())
179 | df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/split/event_only_v234.csv", index=False, sep="\t")
180 | # print(len(train_event))
181 | # print(len(dev_event))
182 | # print(len(test_event))
183 |
184 |
185 | def compile_v5_data():
186 | """
187 |     Remove the controversial events from v4 and combine the rest with the v23 data to form v5.
188 | """
189 | df_data_v4 = pd.read_csv(
190 | "/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", sep=",")
191 | df_data_v4 = df_data_v4[df_data_v4["is_controversial"] == -1.0]
192 | event_v4 = df_data_v4["event"].tolist()
193 |
194 | df_data_v2_v3 = pd.read_csv(
195 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",")
196 | event_v2_v3 = df_data_v2_v3["event"].tolist()
197 | # print(len(event_v4))
198 | # print(len(event_v2_v3))
199 |
200 | all_events = event_v2_v3 + event_v4
201 | random.shuffle(all_events)
202 |
203 | # train_event = all_events[0: int(0.4 * len(all_events))]
204 | # dev_event = all_events[int(0.4 * len(all_events)): int(0.7 * len(all_events))]
205 | # test_event = all_events[int(0.7 * len(all_events)):]
206 |
207 | train_split_labels = ["train" for _ in range(int(0.5 * len(all_events)) + 2)] \
208 | + ["dev" for _ in range(int(0.25 * len(all_events)))] \
209 | + ["test" for _ in range(int(0.25 * len(all_events)))]
210 |
211 | df_data = pd.DataFrame()
212 | df_data["event"] = all_events
213 | df_data["split"] = train_split_labels
214 |
215 | df_data["source"] = "v23"
216 | df_data.loc[df_data["event"].isin(event_v4), "source"] = "v4"
217 |
218 | df_data["clean_event"] = df_data["event"].apply(normalize_event)
219 |
220 | # df_data["source"] = df_data["source"]
221 |
222 | df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/split/event_only_v5.csv", index=False, sep="\t")
223 |
224 |
225 | if __name__ == "__main__":
226 | # compile_v234_data()
227 | compile_v5_data()
228 |
229 |
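Note that the hard-coded "+ 2" in the split-label construction above only matches len(all_events) for particular dataset sizes; a length-safe variant (an illustrative sketch, not the repository's code) would be:

def make_split_labels(n, train_frac=0.5, dev_frac=0.25):
    # Build train/dev/test labels whose total length always equals n.
    n_train = int(train_frac * n)
    n_dev = int(dev_frac * n)
    return ["train"] * n_train + ["dev"] * n_dev + ["test"] * (n - n_train - n_dev)

assert len(make_split_labels(4001)) == 4001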
--------------------------------------------------------------------------------
/src/delphi_hybrid/prepare_data/compile_gold_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from collections import Counter
4 | import pandas as pd
5 | 
6 | sys.path.append(os.getcwd())
7 | from src.delphi_hybrid.components.utils import *
8 |
9 | def get_maj_vote_class_label(raw_class_labels):
10 | class_label_counts = Counter(raw_class_labels)
11 |
12 | if len(set(raw_class_labels)) == 1:
13 | return raw_class_labels[0]
14 | elif len(set(raw_class_labels)) == 2:
15 | if -1 not in raw_class_labels:
16 | count_0 = class_label_counts[0]
17 | count_1 = class_label_counts[1]
18 |
19 | if count_0 > count_1:
20 | return 0
21 | else:
22 | return 1
23 | else:
24 | count_minus_1 = class_label_counts[-1]
25 | if 0 in class_label_counts:
26 | count_POS = class_label_counts[0]
27 | else:
28 | count_POS = class_label_counts[1]
29 |
30 | if count_minus_1 > count_POS:
31 | return -1
32 | else:
33 | if 0 in class_label_counts:
34 | return 0
35 | else:
36 | return 1
37 | else:
38 | count_minus_1 = class_label_counts[-1]
39 | count_0 = class_label_counts[0]
40 | count_1 = class_label_counts[1]
41 |
42 | if count_minus_1 == 3:
43 | return -1
44 | elif count_minus_1 == 2:
45 | if count_0 == 2:
46 | return 0
47 | elif count_1 == 2:
48 | return 1
49 | else:
50 | if count_0 == 3:
51 | return 0
52 | elif count_1 == 3:
53 | return 1
54 | else:
55 | return 0
56 |
57 |
58 | def get_agreement_rate(raw_class_labels):
59 | class_label_counts = Counter(raw_class_labels)
60 |
61 | if -1 in class_label_counts:
62 | count_minus_1 = class_label_counts[-1]
63 | if count_minus_1 in [2, 3]:
64 | return 3/5
65 | elif count_minus_1 in [5]:
66 | return 1
67 | else:
68 | return 4/5
69 | else:
70 | return 1
71 |
72 |
73 | def compile_gold_labels_v5():
74 | data_path = "data/demo/mturk/result/v5.csv"
75 | df_data = pd.read_csv(data_path)
76 |
77 | # print(df_data)
78 |
79 | event_to_raw_labels = {}
80 |     all_events = []
81 | for i, row in df_data.iterrows():
82 | for j in range(4):
83 | event = row[f"Input.action{j+1}"]
84 | label = row[f"Answer.judgment_{j+1}"]
85 |
86 |             all_events.append(event)
87 |
88 | # if label not in [0, 1, -1]:
89 | # print(label)
90 | if event not in event_to_raw_labels:
91 | event_to_raw_labels[event] = [label]
92 | else:
93 | event_to_raw_labels[event].append(label)
94 |
95 |     all_events = list(set(all_events))
96 | print(len(event_to_raw_labels))
97 |
98 |
99 | all_class_labels = []
100 | all_agreement_rates = []
101 | all_raw_class_labels = []
102 |     for event in all_events:
103 | class_label = get_maj_vote_class_label(event_to_raw_labels[event])
104 | agreement_rate = get_agreement_rate(event_to_raw_labels[event])
105 |
106 | all_class_labels.append(class_label)
107 | all_agreement_rates.append(agreement_rate)
108 | all_raw_class_labels.append(event_to_raw_labels[event])
109 |
110 |
111 | df_data_to_save = pd.DataFrame()
112 | df_data_to_save["event"] = all_evants
113 | df_data_to_save["raw_class_labels"] = all_raw_class_labels
114 | df_data_to_save["agreement_rate"] = all_agreement_rates
115 | df_data_to_save["class_label"] = all_class_labels
116 |
117 | print(df_data_to_save)
118 |
119 | df_data_to_save.to_csv("data/demo/mturk/result/v5_clean.csv", index=False)
120 |
121 |
122 | if __name__ == "__main__":
123 | compile_gold_labels_v5()
124 |
125 |
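Illustrative calls showing how the majority-vote and agreement helpers above behave on five-annotator label lists (expected outputs follow from the branching logic as written):

print(get_maj_vote_class_label([1, 1, 1, 1, 1]))    # 1    (unanimous)
print(get_maj_vote_class_label([-1, -1, 1, 1, 1]))  # 1    (two -1 votes lose to three 1 votes)
print(get_maj_vote_class_label([1, 1, 0, 0, -1]))   # 0    (three-way split falls back to 0)
print(get_agreement_rate([-1, -1, 1, 1, 1]))        # 0.6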
--------------------------------------------------------------------------------
/src/delphi_hybrid/prepare_data/filter_paraphrases.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 | import argparse
5 | import pandas as pd
6 |
7 | sys.path.append(os.getcwd())
8 | from src.delphi_hybrid.components.utils import *
9 | from src.delphi_hybrid.components.main_utils import *
10 | from src.delphi_hybrid.components.WANLIScorer import *
11 |
12 | def generate_nli():
13 | wanli_scorer = WANLIScorer()
14 | paraphrases_cache = read_json(data_base_path + "cache/paraphrases.json")
15 |
16 | all_events = paraphrases_cache.keys()
17 |
18 | nli_to_save = read_json(data_base_path + "cache/nli.json")
19 | for premise in tqdm(all_events):
20 | for hypothesis in paraphrases_cache[premise]:
21 | if premise + " | " + hypothesis not in nli_to_save:
22 | prediction = wanli_scorer.get_scores(premise, hypothesis)
23 | print(premise, "|", hypothesis, ":", prediction)
24 | prediction["type"] = "event_paraphrase"
25 | nli_to_save[premise + " | " + hypothesis] = prediction
26 |
27 | if hypothesis + " | " + premise not in nli_to_save:
28 | prediction = wanli_scorer.get_scores(hypothesis, premise)
29 | print(hypothesis, "|", premise, ":", prediction)
30 | prediction["type"] = "paraphrase_event"
31 | nli_to_save[hypothesis + " | " + premise] = prediction
32 |
33 | save_json(data_base_path + "cache/nli.json", nli_to_save)
34 |
35 |
36 | if __name__ == "__main__":
37 | parser = argparse.ArgumentParser(description="")
38 |
39 | parser.add_argument('--input_file', type=str, help="location of data file",
40 | default="data/demo/mturk/split/event_only_v5.csv")
41 | parser.add_argument('--device_id', type=int, help="device id", default=0)
42 | parser.add_argument('--total_num_device', type=int,
43 | default=8, help="total number device")
44 | args = parser.parse_args()
45 |
46 | df_data = pd.read_csv(args.input_file, sep="\t")
47 | all_events = df_data["clean_event"].tolist()
48 |
49 | generate_nli()
50 |
--------------------------------------------------------------------------------
/src/delphi_plus/evalute/evaluate_declare_only.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 |
6 | # ######################## moral acceptability/agreement class ########################
7 | def get_gold_single_input_task_class(bucket, bucket_name, base_path, data_version, task_name, data_split):
8 | """
9 | Get gold inputs and targets class labels
10 | """
11 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
12 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_declare_only/{task_name}/{data_split}.tsv", sep="\t")
13 |
14 | inputs_all = list(df_inputs["inputs"])
15 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
16 |
17 | targets_all = list(df_inputs["targets"])
18 |
19 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all]
20 | return inputs, targets
21 |
22 |
23 | def get_pred_single_input_task_class(bucket, base_path, task_name, check_point):
24 | """
25 | Get preds class labels
26 | """
27 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
28 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
29 |
30 | preds_class = []
31 | for i in preds_blob_list:
32 | try:
33 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1]))
34 | except:
35 | print("output form not identifiable:", i)
36 | preds_class.append(1)
37 |
38 | return preds_class
39 |
40 |
41 |
42 | ######################## moral acceptability/agreement text ########################
43 | def get_gold_single_input_task_text(bucket, bucket_name, base_path, data_version, task_name, data_split):
44 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
45 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_declare_only/{task_name}/{data_split}.tsv", sep="\t")
46 | inputs_all = list(df_inputs["inputs"])
47 | inputs = [s.split("[moral_single]: ")[-1] for s in inputs_all]
48 |
49 | targets_all = list(df_inputs["targets"])
50 | targets = [i.split("[/class] [text]")[1].split("[/text]")[0] for i in targets_all]
51 | return inputs, targets
52 |
53 |
54 | def get_pred_single_input_task_text(bucket, base_path, task_name, check_point):
55 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
56 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
57 |
58 | preds_text = []
59 | for i in preds_blob_list:
60 | try:
61 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0])
62 | except:
63 | print("output form not identifiable:", i)
64 | preds_text.append("")
65 | return preds_text
66 |
67 |
68 | ######################## main ########################
69 | def main_get_accuracy(base_path, data_split, check_points=None,
70 | is_include_accept_class=True,
71 | is_include_accept_text=True,
72 | is_include_agree_class=True,
73 | is_include_agree_text=True,
74 | is_include_compare=True,):
75 | base_path += f"{data_split}_eval/"
76 |
77 | print("=" * 20)
78 | bucket_name = base_path.split("/")[0]
79 | result_prefix = "/".join(base_path.split("/")[1:])
80 | data_version = base_path.split("/")[5]
81 | base_path = "/".join(base_path.split("/")[1:])
82 |
83 | client = storage.Client()
84 | bucket = client.get_bucket(bucket_name)
85 |
86 | if check_points == None:
87 | check_points = get_check_points(client, bucket_name, result_prefix, after_check_point=-1)[1:]
88 | # check_points.sort(reverse=True)
89 |
90 | for check_point in check_points:
91 | print("=" * 40, check_point, "=" * 40)
92 |
93 | main_moral_acceptability(client, bucket, bucket_name, base_path, check_point,
94 | data_version, data_split,
95 | get_gold_single_input_task_class,
96 | get_pred_single_input_task_class,
97 | get_gold_single_input_task_text,
98 | get_pred_single_input_task_text,
99 | is_include_accept_class,
100 | is_include_accept_text)
101 |
102 | main_moral_agreement(client, bucket, bucket_name, base_path, check_point,
103 | data_version, data_split,
104 | get_gold_single_input_task_class,
105 | get_pred_single_input_task_class,
106 | get_gold_single_input_task_text,
107 | get_pred_single_input_task_text,
108 | is_include_agree_class,
109 | is_include_agree_text)
110 |
111 |
112 | if __name__ == "__main__":
113 | base_path = "ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/declare_only/lr-0.0001_bs-16/"
114 | # check_points = [1266200]
115 | check_points = None
116 | main_get_accuracy(base_path, "validation", check_points)
117 | # main_get_accuracy(base_path, "test", check_points)
118 |
119 | # base_path = "ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/11B/declare_only/lr-0.0001_bs-16/"
120 | # # check_points = [1266200]
121 | # check_points = None
122 | # main_get_accuracy(base_path, "validation", check_points)
123 | # # main_get_accuracy(base_path, "test", check_points)
124 |
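For reference, a sketch of the target string format the class/text parsers above assume (the example string is illustrative; real targets come from the TSVs on GCS):

target = "[class]-1[/class] [text]It's bad.[/text]"
class_label = int(target.split("[/class] [text]")[0].split("[class]")[-1])  # -1
text_label = target.split("[/class] [text]")[1].split("[/text]")[0]         # "It's bad."
print(class_label, text_label)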
--------------------------------------------------------------------------------
/src/delphi_plus/evalute/evaluate_distribution.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 |
6 | # ######################## moral acceptability/agreement class ########################
7 | def get_gold_single_input_task_class(bucket, bucket_name, base_path, data_version, task_name, data_split):
8 | """
9 | Get gold inputs and targets class labels
10 | """
11 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
12 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/{task_name}/"
13 | f"{data_split}.{task_name}.tsv", sep="\t")
14 |
15 | inputs_all = list(df_inputs["inputs"])
16 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
17 |
18 | targets_all = list(df_inputs["targets"])
19 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all]
20 | return inputs, targets
21 |
22 |
23 | def get_pred_single_input_task_class(bucket, base_path, task_name, check_point):
24 | """
25 | Get preds class labels
26 | """
27 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
28 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
29 |
30 | preds_class = []
31 | for i in preds_blob_list:
32 | try:
33 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1]))
34 | except:
35 | print("output form not identifiable:", i)
36 | preds_class.append(1)
37 | return preds_class
38 |
39 |
40 | def get_gold_single_input_task_class_wild_v11(bucket, bucket_name, base_path, data_version, task_name, data_split):
41 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
42 |
43 | if data_split == "validation":
44 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/"
45 | f"dev.tsv", sep="\t")
46 | # elif data_split == "test" and task_name == "general":
47 | elif data_split == "test":
48 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/"
49 | f"{task_name}.tsv", sep="\t")
50 | # elif data_split == "test" and task_name == "race":
51 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/"
52 | # f"race_test.tsv", sep="\t")
53 | # elif data_split == "test" and task_name == "gender":
54 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/"
55 | # f"gender_test.tsv", sep="\t")
56 | else:
57 | print("ERROR: not validation split")
58 |
59 | inputs_all = list(df_inputs["inputs"])
60 | targets_all = list(df_inputs["targets"])
61 |
62 | # inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all]
63 | inputs = []
64 | for _, i in enumerate(inputs_all):
65 | if type(i) != type(""):
66 | print("gold class input error:", _, i, targets_all[_])
67 | inputs.append("")
68 | else:
69 | inputs.append(i.split("[moral_single]: ")[-1])
70 |
71 |
72 | # targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all]
73 | targets = []
74 | for _, i in enumerate(targets_all):
75 | if type(i) != type(""):
76 | print("gold class output error:", _, i, inputs_all[_])
77 | targets.append(0)
78 | else:
79 | targets.append(int(i.split("[/class] [text]")[0].split("[class]")[-1]))
80 |
81 | return inputs, targets
82 |
83 |
84 | def get_pred_single_input_task_class_wild_v11(bucket, base_path, task_name, check_point):
85 | if task_name == "general_test":
86 | preds_blob = bucket.get_blob(base_path + f"wild_{check_point}_predictions")
87 | else:
88 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
89 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
90 |
91 | preds_class = []
92 | for i in preds_blob_list:
93 | try:
94 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1]))
95 | except:
96 | print("output form not identifiable:", i)
97 | preds_class.append(1)
98 | return preds_class
99 |
100 |
101 | ######################## moral acceptability/agreement text ########################
102 | def get_gold_single_input_task_text(bucket, bucket_name, base_path, data_version, task_name, data_split):
103 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
104 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/{task_name}/"
105 | f"{data_split}.{task_name}.tsv", sep="\t")
106 | inputs_all = list(df_inputs["inputs"])
107 | inputs = [s.split("[moral_single]: ")[-1] for s in inputs_all]
108 |
109 | targets_all = list(df_inputs["targets"])
110 | targets = [i.split("[/class] [text]")[1].split("[/text]")[0] for i in targets_all]
111 | return inputs, targets
112 |
113 |
114 | def get_pred_single_input_task_text(bucket, base_path, task_name, check_point):
115 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
116 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
117 |
118 | preds_text = []
119 | for i in preds_blob_list:
120 | try:
121 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0])
122 | except:
123 | print("output form not identifiable:", i)
124 | preds_text.append("")
125 | return preds_text
126 |
127 |
128 | def get_gold_single_input_task_text_wild_v11(bucket, bucket_name, base_path, data_version, task_name, data_split):
129 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
130 | if data_split == "validation":
131 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/dev.tsv", sep="\t")
132 | # elif data_split == "test" and task_name == "wild":
133 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/general_test.tsv", sep="\t")
134 | # elif data_split == "test" and task_name == "race_test":
135 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/race_test.tsv", sep="\t")
136 | # elif data_split == "test" and task_name == "gender_test":
137 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/gender_test.tsv", sep="\t")
138 | elif data_split == "test":
139 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/{task_name}.tsv", sep="\t")
140 | else:
141 | print("ERROR: not validation split")
142 |
143 | inputs_all = list(df_inputs["inputs"])
144 | targets_all = list(df_inputs["targets"])
145 |
146 | inputs = []
147 | for _, i in enumerate(inputs_all):
148 | if type(i) != type(""):
149 | print("gold text input error:", _, i, targets_all[_])
150 | inputs.append("")
151 | else:
152 | inputs.append(i.split("[moral_single]: ")[-1])
153 |
154 |
155 | # targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all]
156 | targets = []
157 | for _, i in enumerate(targets_all):
158 | if type(i) != type(""):
159 | print("gold text output error:", _, i, inputs_all[_])
160 | targets.append(0)
161 | else:
162 | targets.append(i.split("[/class] [text]")[1].split("[/text]")[0])
163 |
164 | return inputs, targets
165 |
166 |
167 | def get_pred_single_input_task_text_wild_v11(bucket, base_path, task_name, check_point):
168 | if task_name == "general_test":
169 | preds_blob = bucket.get_blob(base_path + f"wild_{check_point}_predictions")
170 | else:
171 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions")
172 |
173 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
174 |
175 | preds_text = []
176 | for i in preds_blob_list:
177 | try:
178 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0])
179 | except:
180 | print("output form not identifiable:", i)
181 | preds_text.append("")
182 | return preds_text
183 |
184 |
185 | ######################## moral comparison class ########################
186 | def get_gold_moral_comparison_class(bucket, bucket_name, base_path, data_version, data_split):
187 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data"
188 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/moral_comparison/"
189 | f"{data_split}.moral_comparison.tsv", sep="\t")
190 | inputs_all = list(df_inputs["inputs"])
191 | inputs = [s.split("[moral_pair]: ")[-1] for s in inputs_all]
192 |
193 | targets_all = list(df_inputs["targets"])
194 | targets = [int(t) for t in targets_all]
195 |
196 | return inputs, targets
197 |
198 |
199 | def get_pred_moral_comparison_class(bucket, base_path, check_point):
200 | preds_blob = bucket.get_blob(base_path + f"moral_comparison_{check_point}_predictions")
201 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:]
202 | return [int(i) for i in preds_blob_list]
203 |
204 |
205 | ######################## main ########################
206 | def main_get_accuracy(base_path, data_split, check_points=None,
207 | is_include_accept_class=True,
208 | is_include_accept_text=True,
209 | is_include_agree_class=True,
210 | is_include_agree_text=True,
211 | is_include_compare=True,):
212 | base_path += f"{data_split}_eval/"
213 |
214 | bucket_name = base_path.split("/")[0]
215 | result_prefix = "/".join(base_path.split("/")[1:])
216 | data_version = base_path.split("/")[5]
217 | model_type = base_path.split("/")[7]
218 | # eval_data = data_version + "_sbic_" + model_type.split("_")[-3]
219 | base_path = "/".join(base_path.split("/")[1:])
220 |
221 | client = storage.Client()
222 | bucket = client.get_bucket(bucket_name)
223 |
224 | if check_points == None:
225 | check_points = get_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:]
226 | # check_points.sort(reverse=True)
227 | for check_point in check_points:
228 | print("=" * 40, check_point, "=" * 40)
229 |
230 | # main_moral_acceptability(client, bucket, bucket_name, base_path, check_point,
231 | # data_version, data_split,
232 | # get_gold_single_input_task_class,
233 | # get_pred_single_input_task_class,
234 | # get_gold_single_input_task_text,
235 | # get_pred_single_input_task_text,
236 | # is_include_accept_class,
237 | # is_include_accept_text,
238 | # task_name="moral_acceptability")
239 | #
240 | # main_moral_agreement(client, bucket, bucket_name, base_path, check_point,
241 | # data_version, data_split,
242 | # get_gold_single_input_task_class,
243 | # get_pred_single_input_task_class,
244 | # get_gold_single_input_task_text,
245 | # get_pred_single_input_task_text,
246 | # is_include_agree_class,
247 | # is_include_agree_text,
248 | # task_name="moral_agreement")
249 | #
250 | # if is_include_compare:
251 | # main_moral_comparison(client, bucket, bucket_name, base_path, check_point,
252 | # data_version, data_split,
253 | # get_gold_moral_comparison_class,
254 | # get_pred_moral_comparison_class)
255 |
256 | main_wild_v11(client, bucket, bucket_name, base_path, check_point,
257 | data_version, data_split,
258 | get_gold_single_input_task_class_wild_v11,
259 | get_pred_single_input_task_class_wild_v11,
260 | get_gold_single_input_task_text_wild_v11,
261 | get_pred_single_input_task_text_wild_v11, "general_test")
262 |
263 | if data_split == "test":
264 | main_wild_v11(client, bucket, bucket_name, base_path, check_point,
265 | data_version, data_split,
266 | get_gold_single_input_task_class_wild_v11,
267 | get_pred_single_input_task_class_wild_v11,
268 | get_gold_single_input_task_text_wild_v11,
269 | get_pred_single_input_task_text_wild_v11, "race_test")
270 |
271 | main_wild_v11(client, bucket, bucket_name, base_path, check_point,
272 | data_version, data_split,
273 | get_gold_single_input_task_class_wild_v11,
274 | get_pred_single_input_task_class_wild_v11,
275 | get_gold_single_input_task_text_wild_v11,
276 | get_pred_single_input_task_text_wild_v11, "gender_test")
277 |
278 |
279 | if __name__ == "__main__":
280 | base_path = "ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/distribution/lr-0.0001_bs-16/"
281 | check_points = None
282 | # main_get_accuracy(base_path, "validation", check_points)
283 | main_get_accuracy(base_path, "test", [1249400])
284 |
285 |
--------------------------------------------------------------------------------
/src/delphi_plus/evalute/select_declare_only.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 | def eval_accept(row_accuracies, df_results):
6 | class_targets = df_results["freeform_class_targets"].tolist()
7 | class_preds = df_results["freeform_class_preds"].tolist()
8 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact"))
9 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary"))
10 |
11 | text_class_targets = df_results["moral_acceptability_text_2_class_targets"].tolist()
12 | text_class_preds = df_results["moral_acceptability_text_2_class_preds"].tolist()
13 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary"))
14 |
15 | text_targets = df_results["moral_acceptability_text_targets"].tolist()
16 | text_preds = df_results["moral_acceptability_text_preds"].tolist()
17 | exact_match_accuracy = get_moral_acceptability_text_exact_match_accuracy(text_targets, text_preds)
18 | return row_accuracies
19 |
20 |
21 | def eval_agree(row_accuracies, df_results):
22 | class_targets = df_results["yesno_class_targets"].tolist()
23 | class_preds = df_results["yesno_class_preds"].tolist()
24 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary"))
25 |
26 | text_targets = df_results["moral_agreement_text_targets"].tolist()
27 | text_preds = df_results["moral_agreement_text_preds"].tolist()
28 | exact_match_accuracy, polarity_align_accuracy = get_moral_agreement_text_accuracy(text_targets, text_preds)
29 | row_accuracies.append(polarity_align_accuracy)
30 | return row_accuracies
31 |
32 |
33 |
34 | def select_check_point(data_split, model_type, pt_model, bs, check_points):
35 | data_version = "v11"
36 | bucket_name = "ai2-tpu-europe-west4"
37 | lr = 0.0001
38 |
39 | client = storage.Client()
40 | bucket = client.get_bucket(bucket_name)
41 |
42 | print("model_type:", model_type)
43 | print("lr:", lr)
44 | print("bs:", bs)
45 |
46 | result_prefix = f"projects/liweij/mosaic-commonsense-morality/results/{data_version}/" \
47 | f"{pt_model}/{model_type}/lr-{lr}_bs-{bs}/" \
48 | f"freeform/{data_split}/"
49 |
50 | if check_points == None:
51 | check_points = get_result_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:]
52 |
53 | accuracies = []
54 | for check_point in check_points:
55 | row_accuracies = [check_point]
56 |
57 | ##################### accept #####################
58 | df_results = read_result_file(bucket_name, data_version, model_type,
59 | check_point, data_split, "freeform", lr, bs, pt_model)
60 | row_accuracies = eval_accept(row_accuracies, df_results)
61 |
62 | ##################### agree #####################
63 | df_results = read_result_file(bucket_name, data_version, model_type,
64 | check_point, data_split, "yesno", lr, bs, pt_model)
65 | row_accuracies = eval_agree(row_accuracies, df_results)
66 |
67 | accuracies.append(row_accuracies)
68 | print("-- check point:", check_point, row_accuracies)
69 |
70 | df_to_save = pd.DataFrame(accuracies)
71 | df_to_save.to_csv("temp_result_file_2.csv", index=False)
72 |
73 |
74 | if __name__ == "__main__":
75 | model_type = "declare_only"
76 | pt_model = "unicorn-pt"
77 | bs = 16
78 | check_points = None
79 | select_check_point("validation", model_type, pt_model, bs, check_points)
80 |
--------------------------------------------------------------------------------
/src/delphi_plus/evalute/select_distribution.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("script/evaluate")
3 | from evaluate_utils import *
4 |
5 | def eval_accept(row_accuracies, df_results):
6 | class_targets = df_results["moral_acceptability_class_targets"].tolist()
7 | class_preds = df_results["moral_acceptability_class_preds"].tolist()
8 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact"))
9 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary"))
10 |
11 | text_class_targets = df_results["moral_acceptability_text_2_class_targets"].tolist()
12 | text_class_preds = df_results["moral_acceptability_text_2_class_preds"].tolist()
13 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary"))
14 |
15 | text_targets = df_results["moral_acceptability_text_targets"].tolist()
16 | text_preds = df_results["moral_acceptability_text_preds"].tolist()
17 | exact_match_accuracy = get_moral_acceptability_text_exact_match_accuracy(text_targets, text_preds)
18 | return row_accuracies
19 |
20 |
21 | def eval_agree(row_accuracies, df_results):
22 | class_targets = df_results["moral_agreement_class_targets"].tolist()
23 | class_preds = df_results["moral_agreement_class_preds"].tolist()
24 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary"))
25 |
26 | text_targets = df_results["moral_agreement_text_targets"].tolist()
27 | text_preds = df_results["moral_agreement_text_preds"].tolist()
28 | exact_match_accuracy, polarity_align_accuracy = get_moral_agreement_text_accuracy(text_targets, text_preds)
29 | row_accuracies.append(polarity_align_accuracy)
30 | return row_accuracies
31 |
32 |
33 | def eval_compare(row_accuracies, df_results):
34 | class_targets = df_results["moral_comparison_class_targets"].tolist()
35 | class_preds = df_results["moral_comparison_class_preds"].tolist()
36 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact"))
37 | return row_accuracies
38 |
39 |
40 | def eval_wild(row_accuracies, df_results):
41 | df_results["wild_class_targets"] = df_results["wild_class_targets"]
42 | df_results["wild_class_preds"] = df_results["wild_class_preds"]
43 | class_targets = df_results["wild_class_targets"].tolist()
44 | class_preds = df_results["wild_class_preds"].tolist()
45 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact"))
46 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary"))
47 |
48 | text_class_targets = df_results["wild_v11_text_2_class_targets"].tolist()
49 | text_class_preds = df_results["wild_v11_text_2_class_preds"].tolist()
50 |
51 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary"))
52 | return row_accuracies
53 |
54 |
55 | def select_check_point(data_split, model_type, pt_model, bs, check_points):
56 | data_version = "v11"
57 | bucket_name = "ai2-tpu-europe-west4"
58 | lr = 0.0001
59 |
60 | client = storage.Client()
61 | bucket = client.get_bucket(bucket_name)
62 |
63 | print("model_type:", model_type)
64 | print("lr:", lr)
65 | print("bs:", bs)
66 |
67 | result_prefix = f"projects/liweij/mosaic-commonsense-morality/results/{data_version}/" \
68 | f"{pt_model}/{model_type}/lr-{lr}_bs-{bs}/" \
69 | f"moral_acceptability/{data_split}/"
70 |
71 | if check_points is None:
72 | check_points = get_result_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:]
73 |
74 | accuracies = []
75 | for check_point in check_points:
76 | row_accuracies = [check_point]
77 |
78 | ##################### accept #####################
79 | df_results = read_result_file(bucket_name, data_version, model_type,
80 | check_point, data_split, "moral_acceptability", lr, bs, pt_model)
81 | row_accuracies = eval_accept(row_accuracies, df_results)
82 |
83 | ##################### agree #####################
84 | df_results = read_result_file(bucket_name, data_version, model_type,
85 | check_point, data_split, "moral_agreement", lr, bs, pt_model)
86 | row_accuracies = eval_agree(row_accuracies, df_results)
87 |
88 | ##################### compare #####################
89 | df_results = read_result_file(bucket_name, data_version, model_type,
90 | check_point, data_split, "moral_comparison", lr, bs, pt_model)
91 | row_accuracies = eval_compare(row_accuracies, df_results)
92 |
93 | ##################### wild general #####################
94 | df_results = read_result_file(bucket_name, data_version, model_type,
95 | check_point, data_split, "wild", lr, bs, pt_model)
96 | df_results = df_results[df_results["wild_class_targets"] < 90]
97 | row_accuracies = eval_wild(row_accuracies, df_results)
98 |
99 |
100 | accuracies.append(row_accuracies)
101 | print("-- check point:", check_point, row_accuracies)
102 |
103 | df_to_save = pd.DataFrame(accuracies)
104 | df_to_save.to_csv("temp_result_file_2.csv", index=False)
105 |
106 |
107 | if __name__ == "__main__":
108 | model_type = "distribution"
109 | pt_model = "unicorn-pt"
110 | bs = 16
111 | check_points = None
112 | select_check_point("validation", model_type, pt_model, bs, check_points)
113 |
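114 | # Illustrative note: given the evaluation order above, each row written to
115 | # temp_result_file_2.csv has the layout
116 | # [checkpoint, accept_exact, accept_binary, accept_text2class_binary,
117 | #  agree_binary, agree_polarity_align, compare_exact,
118 | #  wild_exact, wild_binary, wild_text2class_binary].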
--------------------------------------------------------------------------------
/src/delphi_plus/train/evaluate.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """
4 | Evaluate the model checkpoint
5 | """
6 |
7 | import t5
8 | import os
9 | import sys
10 | import util
11 | import seqio
12 | import click
13 | import logging
14 | import tensorflow.compat.v1 as tf
15 |
16 | print("python", sys.version)
17 | print("t5", t5.__version__)
18 | print("tf", tf.__version__)
19 | print("seqio", seqio.__version__)
20 |
21 | tf.disable_v2_behavior()
22 |
23 | import tasks, mixtures
24 | # N.B. We must import tasks and mixtures here so that they are registered and available for evaluation.
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | @click.command()
29 | @click.argument("mixture", type=str)
30 | @click.argument("results_dir", type=str)
31 | @click.argument("tpu-name", type=str) # The name of the TPU. Defaults to the TPU_NAME environment variable.
32 | @click.argument("tpu-topology", type=str) # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable.
33 | @click.argument("split", type=str)
34 | @click.argument("checkpoint", type=int)
35 | @click.option(
36 | "--model-parallelism",
37 | type=int,
38 | default=8,
39 | help="The degree of model parallelism to use. Defaults to 8.",
40 | )
41 |
42 | def evaluate(
43 | mixture: str,
44 | results_dir: str,
45 | split: str,
46 | checkpoint: int,
47 | model_parallelism: int,
48 | tpu_name: str,
49 | tpu_topology: str,
50 | ) -> None:
51 | """
52 | Evaluate the model located at RESULTS_DIR on MIXTURE.
53 | """
54 |
55 | print(tpu_name)
56 | print(tpu_topology)
57 | # print(mixture)
58 |
59 | # Initialize arguments
60 | if tpu_topology == "v3-32":
61 | batch_size = 16
62 | model_parallelism = 8
63 | elif tpu_topology == "v3-8":
64 | batch_size = 8
65 | model_parallelism = 8
66 | else:
67 | print("ERROR: tpu_topology invalid")
68 | return
69 |
70 | # Validate arguments
71 | util.validate_path(results_dir)
72 |
73 | checkpoints = util.get_result_check_points(results_dir, split, "ethics_cm_converted_class_only")
74 |
75 | print("-" * 10, "checkpoints todo", "-" * 10)
76 |
77 | if checkpoint == 100:  # sentinel: defer checkpoint selection to MtfModel.eval
78 | checkpoints_to_eval = None
79 | elif checkpoint == 0:  # sentinel: evaluate every checkpoint without predictions yet
80 | checkpoints_to_eval = checkpoints
81 | else:  # evaluate only the requested checkpoint
82 | checkpoints_to_eval = [checkpoint]
83 | print(checkpoints_to_eval)
84 |
85 | # Run evaluation
86 | model = t5.models.MtfModel(
87 | model_dir=results_dir,
88 | tpu=tpu_name,
89 | tpu_topology=tpu_topology,
90 | model_parallelism=model_parallelism,
91 | batch_size=batch_size,
92 | sequence_length={"inputs": 512, "targets": 128},
93 | learning_rate_schedule=None,
94 | save_checkpoints_steps=5000,
95 | keep_checkpoint_max=None,
96 | iterations_per_loop=100,
97 | )
98 |
99 | model.eval(
100 | mixture_or_task_name=mixture,
101 | checkpoint_steps=checkpoints_to_eval,
102 | split=split,
103 | )
104 |
105 |
106 | if __name__ == "__main__":
107 | evaluate()
108 |
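109 | # Example invocation (illustrative; the bucket path, TPU name, and topology are
110 | # placeholders, not values from this repository):
111 | #   python evaluate.py norm_bank gs://<bucket>/<results-dir> <tpu-name> v3-8 validation 0
112 | # A CHECKPOINT of 0 evaluates every checkpoint that does not yet have predictions;
113 | # 100 defers checkpoint selection to MtfModel.eval.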
--------------------------------------------------------------------------------
/src/delphi_plus/train/fine-tune.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """Fine-tune T5 based models."""
4 |
5 | import t5
6 | import sys
7 | import seqio
8 | import click
9 | import logging
10 | import tensorflow.compat.v1 as tf
11 |
12 | print("python", sys.version)
13 | print("t5", t5.__version__)
14 | print("tf", tf.__version__)
15 | print("seqio", seqio.__version__)
16 |
17 | import util
18 | import warnings
19 | import tasks, mixtures
20 | # We must import tasks and mixtures here so that the tasks and mixtures are registered and available for training.
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | # Silence all but fatal TensorFlow log messages.
25 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)
26 | tf.disable_v2_behavior()
27 |
28 | config = tf.ConfigProto()
29 | config.gpu_options.allow_growth = True
30 | session = tf.InteractiveSession(config=config)
31 |
32 | warnings.filterwarnings("ignore", category=DeprecationWarning)
33 |
34 | PRETRAINED_MODELS = {
35 | "small": ("gs://t5-data/pretrained_models/small/", -1),
36 | "base": ("gs://t5-data/pretrained_models/base/", -1),
37 | "large": ("gs://t5-data/pretrained_models/large/", -1),
38 | "3B": ("gs://t5-data/pretrained_models/3B/", -1),
39 | "11B": ("gs://t5-data/pretrained_models/11B/", -1),
40 | "unicorn-pt": ("gs://ai2-mosaic-public/projects/rainbow/v1.0/unicorns/lr-2e-3_batch-size-32/", -1),
41 | "v11-delphi-declare": (
42 | "gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/declare_only/lr-0.0001_bs-16/",
43 | 1218800),
44 | }
45 |
46 | @click.command()
47 | @click.argument("mixture", type=str)
48 | @click.argument("results_dir", type=str)
49 | @click.argument("tpu-name", type=str) # The name of the TPU. Defaults to the TPU_NAME environment variable.
50 | @click.argument("tpu-topology", type=str) # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable.
51 | @click.argument("pretrained-model", type=str)
52 | @click.option(
53 | "--split",
54 | type=str,
55 | default="train",
56 | help="The split on which to train. Defaults to 'train'.",
57 | )
58 | @click.option(
59 | "--n-steps",
60 | type=int,
61 | default=600000,
62 | help="The number of gradient updates. Defaults to 25,000.",
63 | )
64 | @click.option(
65 | "--save-checkpoints-steps",
66 | type=int,
67 | default=5000,
68 | help=(
69 | "The number of steps to take before saving a checkpoint. Defaults to"
70 | " 5000."
71 | ),
72 | )
73 | @click.option(
74 | "--n-checkpoints-to-keep",
75 | type=int,
76 | default=300,
77 | help=(
78 | "The number of checkpoints to keep during fine-tuning. Defaults"
79 | " to 4."
80 | ),
81 | )
82 | @click.option(
83 | "--learning-rate",
84 | type=float,
85 | default=2e-4,
86 | help="The learning rate to use for training. Defaults to 3e-3.",
87 | )
88 | @click.option(
89 | "--continue_finetune",
90 | type=bool,
91 | default=True,
92 | help="Whether to continue training from an existing checkpoint.",
93 | )
94 |
95 | def fine_tune(
96 | mixture: str,
97 | results_dir: str,
98 | split: str,
99 | pretrained_model: str,
100 | n_steps: int,
101 | learning_rate: float,
102 | save_checkpoints_steps: int,
103 | n_checkpoints_to_keep: int,
104 | tpu_name: str,
105 | tpu_topology: str,
106 | continue_finetune: bool,
107 | ) -> None:
108 | """
109 | Fine-tune the model on MIXTURE, writing results to RESULTS_DIR.
110 | """
111 |
112 | # Initialize arguments
113 | if tpu_topology == "v3-32":
114 | batch_size = 16
115 | model_parallelism = 32
116 | elif tpu_topology == "v3-8":
117 | batch_size = 8
118 | model_parallelism = 8
119 | else:
120 | print("ERROR: tpu_topology invalid")
121 | return
122 |
123 | pretrained_checkpoint_step = -1
124 |
125 | # Get result path given arguments
126 | result_path = util.get_result_path(results_dir, pretrained_model, mixture, learning_rate, batch_size)
127 |
128 | # Validate path
129 | util.validate_path(results_dir, pretrained_model, PRETRAINED_MODELS)
130 |
131 | # Process arguments
132 | if pretrained_model in PRETRAINED_MODELS:
133 | pretrained_model, pretrained_checkpoint_step = PRETRAINED_MODELS[pretrained_model]
134 |
135 | # If training stopped before finishing, resume from the last checkpoint saved under result_path.
136 | if continue_finetune:
137 | pretrained_model = result_path
138 |
139 | # Print arguments
140 | util.print_arguments(result_path, results_dir, mixture, split, pretrained_model,
141 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism,
142 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate,
143 | tpu_name, tpu_topology, tasks, continue_finetune)
144 |
145 | # Run fine-tuning
146 | model = t5.models.MtfModel(
147 | model_dir=result_path,
148 | tpu=tpu_name,
149 | tpu_topology=tpu_topology,
150 | model_parallelism=model_parallelism,
151 | batch_size=batch_size,
152 | sequence_length={"inputs": 512, "targets": 128},
153 | learning_rate_schedule=learning_rate,
154 | save_checkpoints_steps=save_checkpoints_steps,
155 | keep_checkpoint_max=n_checkpoints_to_keep,
156 | iterations_per_loop=100,
157 | )
158 |
159 | model.finetune(
160 | mixture_or_task_name=mixture,
161 | pretrained_model_dir=pretrained_model,
162 | pretrained_checkpoint_step=pretrained_checkpoint_step,
163 | finetune_steps=n_steps,
164 | split=split,
165 | )
166 |
167 |
168 | if __name__ == "__main__":
169 | fine_tune()
170 |
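171 | # Example invocation (illustrative; the bucket path and TPU name are placeholders):
172 | #   python fine-tune.py norm_bank gs://<bucket>/<results-dir> <tpu-name> v3-32 unicorn-pt \
173 | #       --learning-rate 0.0001 --n-steps 600000 --continue_finetune False
174 | # With these values, results are written under
175 | # <results-dir>/unicorn-pt/norm_bank/lr-0.0001_bs-16 (see util.get_result_path).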
--------------------------------------------------------------------------------
/src/delphi_plus/train/mixtures.py:
--------------------------------------------------------------------------------
1 | import os
2 | import t5
3 | import tasks
4 | import rates
5 | import seqio
6 | import functools
7 |
8 | import util
9 |
10 | # seqio.MixtureRegistry.add(
11 | # "sbic_commonsense_morality_joint_all_proportional",
12 | # ["sbic_moral_acceptability",
13 | # "sbic_moral_agreement",
14 | # "sbic_moral_comparison"],
15 | # default_rate=rates.MIXING_RATES["proportional"]
16 | # )
17 | # util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional")
18 | #
19 | # ################## commonsense norm bank + wild ablations ##################
20 | # proportions = [10, 20, 40, 60, 80, 100]
21 | # for p in proportions:
22 | # seqio.MixtureRegistry.add(
23 | # f"sbic_commonsense_morality_joint_all_proportional_wild_{p}",
24 | # [ f"wild_train_{p}",
25 | # "sbic_moral_acceptability",
26 | # "sbic_moral_agreement",
27 | # "sbic_moral_comparison"],
28 | # default_rate=rates.MIXING_RATES["proportional"]
29 | # )
30 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_{p}")
31 | #
32 | #
33 | # seqio.MixtureRegistry.add(
34 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100",
35 | # [ f"wild_train_woz_100",
36 | # "sbic_moral_acceptability",
37 | # "sbic_moral_agreement",
38 | # "sbic_moral_comparison"],
39 | # default_rate=rates.MIXING_RATES["proportional"]
40 | # )
41 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100")
42 | #
43 | #
44 | # seqio.MixtureRegistry.add(
45 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100_v1",
46 | # [ f"wild_train_woz_100",
47 | # "sbic_moral_acceptability",
48 | # "sbic_moral_agreement",
49 | # "sbic_moral_comparison"],
50 | # default_rate=rates.MIXING_RATES["proportional"]
51 | # )
52 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100_v1")
53 | #
54 | #
55 | # seqio.MixtureRegistry.add(
56 | # "wild_hard_test",
57 | # [ "race_test",
58 | # "gender_test"],
59 | # default_rate=rates.MIXING_RATES["proportional"]
60 | # )
61 | # util.print_mixture_examples(f"wild_hard_test")
62 | #
63 | # seqio.MixtureRegistry.add(
64 | # "all",
65 | # [ f"wild_train_100",
66 | # "sbic_moral_acceptability",
67 | # "sbic_moral_agreement",
68 | # "sbic_moral_comparison",
69 | # "race_test",
70 | # "gender_test"],
71 | # default_rate=rates.MIXING_RATES["proportional"]
72 | # )
73 | # util.print_mixture_examples(f"all")
74 |
75 |
76 | # seqio.MixtureRegistry.add(
77 | # "sbic_single_class_only",
78 | # [ "sbic_moral_acceptability_single_class_only",
79 | # "sbic_moral_agreement_single_class_only"],
80 | # default_rate=rates.MIXING_RATES["proportional"]
81 | # )
82 | # util.print_mixture_examples(f"sbic_single_class_only")
83 | #
84 |
85 | # dynahate = True
86 | if tasks.dynahate:
87 | for t in tasks.dynahate_task_formats:
88 | seqio.MixtureRegistry.add(
89 | f"dynahate_all_{t}",
90 | [ f"dynahate_round_1_{t}",
91 | f"dynahate_round_2_{t}",
92 | f"dynahate_round_3_{t}",
93 | f"dynahate_round_4_{t}",],
94 | default_rate=rates.MIXING_RATES["proportional"]
95 | )
96 |
97 | seqio.MixtureRegistry.add(
98 | f"dynahate_all_{t}_100_shot",
99 | [ f"dynahate_round_1_{t}_100_shot",
100 | f"dynahate_round_2_{t}_100_shot",
101 | f"dynahate_round_3_{t}_100_shot",
102 | f"dynahate_round_4_{t}_100_shot",],
103 | default_rate=rates.MIXING_RATES["proportional"]
104 | )
105 |
106 |
107 | seqio.MixtureRegistry.add(
108 | "declare_only",
109 | [f"freeform",
110 | f"yesno"],
111 | default_rate=rates.MIXING_RATES["proportional"]
112 | )
113 |
114 | seqio.MixtureRegistry.add(
115 | "norm_bank",
116 | ["moral_acceptability",
117 | "moral_agreement",
118 | "moral_comparison",
119 | ],
120 | default_rate=rates.MIXING_RATES["proportional"]
121 | )
122 |
123 |
124 | # seqio.MixtureRegistry.add(
125 | # "distribution",
126 | # ["moral_acceptability",
127 | # "moral_agreement",
128 | # "moral_comparison",
129 | # "wild",
130 | # ],
131 | # default_rate=rates.MIXING_RATES["proportional"]
132 | # )
133 | #
134 | # seqio.MixtureRegistry.add(
135 | # "all_distribution_wild",
136 | # ["wild",
137 | # "race_test",
138 | # "gender_test",
139 | # ],
140 | # default_rate=rates.MIXING_RATES["proportional"]
141 | # )
142 |
143 |
144 | seqio.MixtureRegistry.add(
145 | "maj_vote",
146 | ["moral_acceptability",
147 | "moral_agreement",
148 | "moral_comparison",
149 | "wild",
150 | ],
151 | default_rate=rates.MIXING_RATES["proportional"]
152 | )
153 |
154 | seqio.MixtureRegistry.add(
155 | "all_maj_vote_wild",
156 | ["wild",
157 | "race_test",
158 | "gender_test",
159 | ],
160 | default_rate=rates.MIXING_RATES["proportional"]
161 | )
162 |
--------------------------------------------------------------------------------
/src/delphi_plus/train/predict.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | """Evaluate the model on the rainbow datasets."""
4 |
5 | import t5
6 | import os
7 | import sys
8 | import seqio
9 | import logging
10 | import click
11 | import util
12 | import pandas as pd
13 | import tensorflow.compat.v1 as tf
14 |
15 | # Used to temporarily adjust TensorFlow logging verbosity (see tf_verbosity_level).
16 | from contextlib import contextmanager
17 |
18 | # print("python", sys.version)
19 | # print("t5", t5.__version__)
20 | # print("tf", tf.__version__)
21 | # print("seqio", seqio.__version__)
22 |
23 | tf.disable_v2_behavior()
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | def getSubstringBetweenMarkers(source_string, start_marker, end_marker):
28 | start = source_string.find(start_marker) + len(start_marker)
29 | end = source_string.find(end_marker)
30 | return source_string[start: end]
31 |
32 |
33 | @contextmanager
34 | def tf_verbosity_level(level):
35 | og_level = tf.logging.get_verbosity()
36 | tf.logging.set_verbosity(level)
37 | yield
38 | tf.logging.set_verbosity(og_level)
39 |
40 |
41 | @click.command()
42 | @click.option(
43 | "--batch-size",
44 | type=int,
45 | default=64,
46 | help=(
47 | "The batch size to use for prediction. For efficient prediction on the"
48 | " TPU, choose a multiple of either 8 or 128. Defaults to 64."
49 | ),
50 | )
51 | @click.option(
52 | "--model-parallelism",
53 | type=int,
54 | default=8,
55 | help="The degree of model parallelism to use. Defaults to 8.",
56 | )
57 | @click.option(
58 | "--tpu-name",
59 | type=str,
60 | default="de-tpu1",
61 | required=True,
62 | help="The name of the TPU. Defaults to the TPU_NAME environment variable.",
63 | )
64 | @click.option(
65 | "--tpu-topology",
66 | type=str,
67 | default="v3-32",
68 | required=True,
69 | help=(
70 | "The topology of the TPU. Defaults to the TPU_TOPOLOGY environment"
71 | " variable."
72 | ),
73 | )
74 | def predict(
75 | batch_size: int,
76 | model_parallelism: int,
77 | tpu_name: str,
78 | tpu_topology: str,
79 | ) -> None:
80 | """Evaluate the model located at RESULTS_DIR on MIXTURE."""
81 |
82 | eval_data = "race_topk_batch6to10"
83 |
84 | data_version = "v9"
85 | model_type = "sbic_commonsense_morality_joint_all_proportional"
86 | check_point = 1264700
87 | lr = 0.0001
88 | bs = 16
89 | bucket_name = "ai2-tpu-europe-west4"
90 | models_dir = f"gs://{bucket_name}/projects/liweij/mosaic-commonsense-morality/model/{data_version}/" \
91 | f"unicorn-pt/{model_type}/lr-{lr}_bs-{bs}"
92 | training_type = model_type.split("_")[-3]  # e.g. "joint" for the model_type above
93 |
94 | # Run evaluation.
95 | model = t5.models.MtfModel(
96 | model_dir=models_dir,
97 | tpu=tpu_name,
98 | tpu_topology=tpu_topology,
99 | model_parallelism=model_parallelism,
100 | batch_size=batch_size,
101 | sequence_length={"inputs": 512, "targets": 128},
102 | learning_rate_schedule=None,
103 | save_checkpoints_steps=5000,
104 | keep_checkpoint_max=None,
105 | iterations_per_loop=100,
106 | )
107 |
108 | predict_joint_inputs_paths = ["gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/" \
109 | f"data/qualitative_eval/{training_type}/" + eval_data + "_qualitative_eval.tsv"]
110 | predict_joint_outputs_paths = [
111 | models_dir.replace("model", "preds") + "/raw/" + eval_data + "_qualitative_eval.tsv"]
112 |
113 | for i in range(len(predict_joint_inputs_paths)):
114 | predict_joint_inputs_path = predict_joint_inputs_paths[i]
115 | predict_joint_outputs_path = predict_joint_outputs_paths[i]
116 |
117 | # Ignore any logging so that we only see the model's answers to the questions.
118 | with tf_verbosity_level('ERROR'):
119 | model.batch_size = 8 # Min size for small model on v2-8 with parallelism 1.
120 | model.predict(
121 | input_file=predict_joint_inputs_path,
122 | output_file=predict_joint_outputs_path,
123 | # Select the most probable output token at each step.
124 | temperature=0,
125 | checkpoint_steps=check_point,
126 | )
127 |
128 |
129 | if __name__ == "__main__":
130 | predict()
131 |
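132 | # Example invocation (illustrative; the model, checkpoint, and data paths are
133 | # hard-coded above, so only the TPU flags are supplied on the command line):
134 | #   python predict.py --tpu-name <tpu-name> --tpu-topology v3-32 --batch-size 64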
--------------------------------------------------------------------------------
/src/delphi_plus/train/rates.py:
--------------------------------------------------------------------------------
1 | """
2 | Mixing rates
3 | """
4 |
5 | import seqio
6 |
7 | def equal_rate(task: seqio.Task):
8 | """Mix the datasets in equal amounts.
9 |
10 | Parameters
11 | ----------
12 | task : seqio.Task
13 | The task.
14 |
15 | Returns
16 | -------
17 | float
18 | The constant: ``1.0``.
19 | """
20 | return 1.0
21 |
22 |
23 | def proportional_rate(task: seqio.Task):
24 | """Mix the datasets proportionally.
25 |
26 | Parameters
27 | ----------
28 | task : seqio.Task
29 | The task.
30 |
31 | Returns
32 | -------
33 | float
34 | The number of examples in the task's training set.
35 | """
36 | return float(task.num_input_examples("train"))
37 |
38 |
39 | # constants
40 |
41 | MIXING_RATES = {
42 | "equal": equal_rate,
43 | "proportional": proportional_rate,
44 | }
45 | """A dictionary mapping mixing rates' names to their implementations."""
46 |
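47 | # Example (illustrative, mirroring how mixtures.py uses these rates via
48 | # rates.MIXING_RATES): tasks in a mixture are sampled in proportion to their
49 | # training-set sizes.
50 | #
51 | #   seqio.MixtureRegistry.add(
52 | #       "norm_bank",
53 | #       ["moral_acceptability", "moral_agreement", "moral_comparison"],
54 | #       default_rate=MIXING_RATES["proportional"],
55 | #   )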
--------------------------------------------------------------------------------
/src/delphi_plus/train/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Util functions for fine-tuning and evaluating models
3 | """
4 | import seqio
5 | import pandas as pd
6 | from google.cloud import storage
7 | import tensorflow_datasets as tfds
8 | import tensorflow as tf
9 |
10 |
11 | def create_folder(client, bucket, destination_folder_name):
12 | """
13 | Create a folder in Google Cloud Storage if it does not already exist
14 | """
15 | if not storage.Blob(bucket=bucket, name=destination_folder_name).exists(client):
16 | blob = bucket.blob(destination_folder_name)
17 | blob.upload_from_string('')
18 | print('Created: {}'.format(destination_folder_name))
19 | else:
20 | print('Exists: {}'.format(destination_folder_name))
21 |
22 |
23 | def print_task_examples(task_name, split="validation", num_ex=1):
24 | """
25 | Print examples from tasks
26 | """
27 | print("#" * 20, task_name, "#" * 20)
28 | task = seqio.TaskRegistry.get(task_name)
29 | ds = task.get_dataset(split=split, sequence_length={"inputs": 512, "targets": 128})
30 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))):
31 | print(i, ex)
32 | print("test", task.num_input_examples("test"))
33 | print("train", task.num_input_examples("train"))
34 | print("validation", task.num_input_examples("validation"))
35 |
36 |
37 | def print_mixture_examples(mixture_name, split="validation", num_ex=1):
38 | """
39 | Print examples from mixtures
40 | """
41 | print("#" * 20, mixture_name, "#" * 20)
42 | mixture = seqio.MixtureRegistry.get(mixture_name)
43 | ds = mixture.get_dataset(split=split,
44 | sequence_length={"inputs": 512, "targets": 128})
45 |
46 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))):
47 | print(i, ex)
48 | print("test", mixture.num_input_examples("test"))
49 | print("train", mixture.num_input_examples("train"))
50 | print("validation", mixture.num_input_examples("validation"))
51 |
52 |
53 | def get_num_elements_csv(file_name):
54 | """
55 | Get the total number of elements in a given csv/tsv file
56 | """
57 | df = pd.read_csv(file_name, delimiter="\t")
58 | return df.shape[0]
59 |
60 |
61 | def get_num_elements_split(split_paths):
62 | """
63 | Get the number of elements in each split of a dataset
64 | """
65 | num_elements_split = {}
66 | for split, path in split_paths.items():
67 | num_elements_split[split] = get_num_elements_csv(path)
68 | return num_elements_split
69 |
70 |
71 | def get_result_check_points(result_prefix, split, eval_data_type, after_check_point=-1):
72 | """
73 | Get the model checkpoints that do not yet have prediction files for the designated data split
74 | """
75 | client = storage.Client()
76 | bucket_name = "ai2-tpu-europe-west4"
77 | result_prefix = result_prefix.split(bucket_name + "/")[-1] + "/"
78 |
79 | check_points = []
80 | done_check_points = []
81 | for blob in client.list_blobs(bucket_name, prefix=result_prefix):
82 | blob_name = str(blob).split("/")[-1]
83 | if ".meta" in blob_name:
84 | check_point = int(blob_name.split(".meta")[0].split("-")[-1])
85 | if check_point > after_check_point:
86 | check_points.append(check_point)
87 |
88 | print("-" * 10, "checkpoints all", "-" * 10)
89 | print(check_points)
90 |
91 | for blob in client.list_blobs(bucket_name, prefix=result_prefix + f"{split}_eval/"):
92 | blob_name = str(blob).split("/")[-1]
93 | if "_predictions" in blob_name and eval_data_type in blob_name and "_predictions_clean" not in blob_name:
94 | check_point_done = int(blob_name.split("_predictions")[0].split("_")[-1])
95 | # check_point_done = int(blob_name.split("_")[0].split("_")[-1])
96 | if check_point_done in check_points:
97 | done_check_points.append(check_point_done)
98 | check_points.remove(check_point_done)
99 |
100 | print("-" * 10, "checkpoints done", "-" * 10)
101 | print(done_check_points)
102 | return check_points
103 |
104 |
105 | def validate_path(results_dir, pretrained_model=None, PRETRAINED_MODELS=None):
106 | """
107 | Validate result path
108 | """
109 | if PRETRAINED_MODELS is not None:
110 | if not results_dir.startswith("gs://"):
111 | raise ValueError(f"RESULTS_DIR ({results_dir}) must be a GCS path.")
112 |
113 | if pretrained_model.startswith("gs://"):
114 | if not tf.io.gfile.exists(pretrained_model):
115 | raise IOError(
116 | f"--pretrained-model ({pretrained_model}) does not exist."
117 | )
118 | else:
119 | if pretrained_model not in PRETRAINED_MODELS:
120 | raise ValueError(
121 | f"--pretrained-model ({pretrained_model}) not recognized. It"
122 | f" must either be a GCS path or one of"
123 | f' {", ".join(PRETRAINED_MODELS.keys())}.')
124 | else:
125 | if not results_dir.startswith("gs://"):
126 | raise ValueError(f"RESULTS_DIR ({results_dir}) must be a GCS path.")
127 | elif not tf.io.gfile.exists(results_dir):
128 | raise IOError(f"RESULTS_DIR ({results_dir}) doesn't exist.")
129 |
130 |
131 | def print_arguments(result_path, results_dir, mixture, split, pretrained_model,
132 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism,
133 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate,
134 | tpu_name, tpu_topology, tasks, continue_finetune):
135 | print("=" * 10, "results_dir")
136 | print(results_dir)
137 |
138 | print("=" * 10, "mixture")
139 | print(mixture)
140 |
141 | print("=" * 10, "split")
142 | print(split)
143 |
144 | print("=" * 10, "pretrained_model")
145 | print(pretrained_model)
146 |
147 | print("=" * 10, "pretrained_checkpoint_step")
148 | print(pretrained_checkpoint_step)
149 |
150 | print("=" * 10, "n_steps")
151 | print(n_steps)
152 |
153 | print("=" * 10, "batch_size")
154 | print(batch_size)
155 |
156 | print("=" * 10, "model_parallelism")
157 | print(model_parallelism)
158 |
159 | print("=" * 10, "save_checkpoints_steps")
160 | print(save_checkpoints_steps)
161 |
162 | print("=" * 10, "n_checkpoints_to_keep")
163 | print(n_checkpoints_to_keep)
164 |
165 | print("=" * 10, "learning_rate")
166 | print(learning_rate)
167 |
168 | print("=" * 10, "tpu_name")
169 | print(tpu_name)
170 |
171 | print("=" * 10, "tpu_topology")
172 | print(tpu_topology)
173 |
174 | print("=" * 10, "result_path")
175 | print(result_path)
176 |
177 | print("=" * 10, "data_version")
178 | print(tasks.data_version)
179 |
180 | print("=" * 10, "continue_finetune")
181 | print(continue_finetune)
182 |
183 |
184 | def get_result_path(
185 | results_dir: str,
186 | pretrained_model: str,
187 | mixture: str,
188 | learning_rate: float,
189 | batch_size: int
190 | ) -> str:
191 | """
192 | Get a result path given arguments
193 | """
194 | result_path = results_dir + \
195 | "/" + pretrained_model + \
196 | "/" + mixture + \
197 | f"/lr-{learning_rate}_bs-{batch_size}"
198 | return result_path
199 |
200 |
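201 | # Example (illustrative):
202 | #   get_result_path("gs://bucket/results", "unicorn-pt", "norm_bank", 0.0001, 16)
203 | #   -> "gs://bucket/results/unicorn-pt/norm_bank/lr-0.0001_bs-16"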
--------------------------------------------------------------------------------