├── .gitignore ├── README.md ├── assets ├── bias_results.png ├── delphi_hybrid.png ├── examples.png ├── main_results.png ├── norm_bank_content.png └── overall.png ├── data └── datasheet.md └── src ├── delphi ├── evaluate │ ├── evaluate_dynahate_v9.py │ ├── evaluate_dynahate_yesno.py │ ├── evaluate_ethics_converted.py │ ├── evaluate_ethics_v9.py │ ├── evaluate_joint_st_wild_v9.py │ ├── evaluate_latenthatred_v9.py │ ├── evaluate_latenthatred_yesno.py │ ├── evaluate_utils.py │ └── select_check_point_wild_v9.py └── train │ ├── evaluate.py │ ├── fine-tune.py │ ├── mixtures.py │ ├── predict.py │ ├── rates.py │ ├── tasks.py │ └── util.py ├── delphi_hybrid ├── collective_reasoning │ ├── ITW │ │ ├── BaseRuler.py │ │ ├── EnsembleRuler.py │ │ └── Ruler.py │ ├── ITW_NormBank │ │ ├── ITWBaseRuler.py │ │ ├── NormBankBaseRuler.py │ │ └── Ruler.py │ └── NormBank │ │ ├── BaseRuler.py │ │ ├── EnsembleRuler.py │ │ └── Ruler.py ├── components │ ├── COMETGenerator.py │ ├── CacheHandler │ │ ├── COMETCacheHandler.py │ │ ├── CacheHandler.py │ │ ├── DelphiCacheHandler.py │ │ └── ParaphraseCacheHandler.py │ ├── CompositionalityParser.py │ ├── DelphiScorer.py │ ├── GPT3Scorer.py │ ├── LMScorer.py │ ├── MoralSaliencyKeywordCounter.py │ ├── MoralSaliencyKeywordIdentifier.py │ ├── Paraphraser.py │ ├── WANLIScorer.py │ ├── bank.py │ ├── constants.py │ ├── main_utils.py │ └── utils.py └── prepare_data │ ├── compile_data_v6.py │ ├── compile_demo_data.py │ ├── compile_gold_labels.py │ ├── filter_paraphrases.py │ └── prepare_data.py ├── delphi_plus ├── evalute │ ├── evaluate_declare_only.py │ ├── evaluate_distribution.py │ ├── evaluate_utils.py │ ├── select_declare_only.py │ └── select_distribution.py └── train │ ├── evaluate.py │ ├── fine-tune.py │ ├── mixtures.py │ ├── predict.py │ ├── rates.py │ ├── tasks.py │ └── util.py └── utils └── text2class.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | data/* 10 | results/ 11 | old/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An Empirical Investigation of Machines' Capabilities for Moral Judgment with the Delphi Experiment 2 | 3 | 4 | 5 | 6 | 7 | 8 | **Authors:** 9 | [Liwei Jiang](https://liweijiang.me), 10 | [Jena D. Hwang](https://jenahwang.github.io), 11 | [Chandra Bhagavatula](https://www.chandrab.page), 12 | [Ronan Le Bras](https://rlebras.github.io), 13 | [Jenny Liang](https://jennyliang.me), 14 | [Sydney Levine](https://sites.google.com/site/sydneymlevine/), 15 | [Jesse Dodge](https://jessedodge.github.io), 16 | [Keisuke Sakaguchi](https://keisuke-sakaguchi.github.io), 17 | [Maxwell Forbes](https://maxwellforbes.com), 18 | [Jack Hessel](https://jmhessel.com), 19 | [Jon Borchardt](https://www.linkedin.com/in/borchardt/), 20 | [Taylor Sorensen](https://tsor13.github.io), 21 | [Saadia Gabriel](https://saadiagabriel.com), 22 | [Yulia Tsvetkov](https://homes.cs.washington.edu/~yuliats/), 23 | [Oren Etzioni](https://homes.cs.washington.edu/~etzioni/site/), 24 | [Maarten Sap](http://maartensap.com), 25 | [Regina Rini](https://reginarini.net), 26 | [Yejin Choi](https://homes.cs.washington.edu/~yejin/) 27 | 28 | 29 | > As our society adopts increasingly powerful AI systems for pervasive use, there are growing concerns about machine morality---or lack thereof. Millions of users already rely upon the outputs of AI systems, such as chatbots, as decision aids. Meanwhile, AI researchers continue to grapple with the challenge of aligning these systems with human morality and values. In response to this challenge, we build and test Delphi, an open-source AI system trained to predict human moral judgments. The computational framework of Delphi is grounded in the philosophical moral framework proposed by the prominent moral philosopher John Rawls. Our results speak to the promises and limits of machines' capabilities to learn about human morality. On the one hand, Delphi demonstrates improved generalization capabilities over those exhibited by off-the-shelf neural language models. At the same time, Delphi's failures also underscore important challenges in this arena. For instance, Delphi has limited cultural awareness and is susceptible to pervasive biases. Despite these shortcomings, we demonstrate several compelling use cases of Delphi, including incorporating it as a component within an ensemble of AI systems. Finally, we computationally demonstrate the potential of Rawls' prospect of hybrid approaches for reliable moral reasoning, inspiring future research in computational morality. 30 | 31 | ## The Theoretical and Computational Frameworks of Delphi 32 | 33 | 39 | 40 | 41 | 42 | > (a) The theoretical framework of ethics proposed by the prominent moral philosopher John Rawls. In 1951, Rawls proposed a “decision procedure of ethics” that takes a bottom-up approach to capturing patterns of human ethics by crowdsourcing moral opinions from a wide variety of people. Later, in 1971, Rawls complemented this theoretical procedure with top-down constraints in his most famous work, A Theory of Justice. Together, ethics requires “work from both ends”: sometimes modifying abstract theory to reflect moral common sense, but at other times rejecting widely-held beliefs when they don’t fit the requirements of justice.
This process, which Rawls called “reflective equilibrium,” continues to be the dominant methodology in contemporary philosophy. (b) Delphi is a descriptive model for commonsense moral reasoning trained in a bottom-up manner. Delphi is taught with the Commonsense Norm Bank, a compiled moral textbook customized for machines that covers a wide range of morally salient situations. Delphi is fine-tuned from Unicorn, a T5-11B-based neural language model specialized in commonsense question answering. Delphi takes in a query and responds with a yes/no or free-form answer. 43 | 44 |
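For reference, the evaluation scripts under `src/delphi/evaluate` interact with Delphi through a plain text-to-text interface: inputs are prefixed with a task token such as `[moral_single]: `, and predictions for the v10/v11 model variants wrap a class label and an open-text judgment in `[class]`/`[text]` tags (earlier versions use `<class>`-style tags). The snippet below is a minimal, illustrative sketch of composing a query and parsing a prediction in this format; the example strings are hypothetical, and the exact label set depends on the task and model version.

```python
# Illustrative sketch only (not part of the released pipeline).
# Mirrors the string format assumed by the scripts in src/delphi/evaluate.

def build_query(situation: str) -> str:
    # Queries are prefixed with a task token, e.g. "[moral_single]: ".
    return f"[moral_single]: {situation}"

def parse_prediction(pred: str) -> tuple[int, str]:
    # Predictions look like "[class]<label>[/class] [text]<judgment>[/text]".
    label = int(pred.split("[/class]")[0].split("[class]")[-1])
    text = pred.split("[text]")[-1].split("[/text]")[0]
    return label, text

print(build_query("ignoring a phone call from my friend"))
# Hypothetical model output, shown for illustration only:
print(parse_prediction("[class]-1[/class] [text]It's rude[/text]"))
```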
45 | Overview of Commonsense Norm Bank Content 46 | 47 | > Representative N-grams cover topics including people, relationships, actions, life & society, cognition, and others. The lemmatized and normalized 4-grams used for the topic analysis are bolded. Auxiliary words from the original form of the data instances that are not used in the topic analysis are left unbolded. 48 | 49 | 50 | 51 |
52 | 53 |
54 | Examples from Delphi 55 | 56 | > Delphi shows an impressive ability to generalize to unseen situations beyond the Commonsense Norm Bank, and robustly adjusts its judgments as the context changes. Label colors indicate Delphi’s classification results (green: positive, gray: neutral, red: negative). Textual labels come from Delphi’s open-text responses. 57 | 58 | 59 | 60 |
61 | 62 | 63 |
64 | Main Results of Delphi 65 | 66 | > (a) Delphi achieves better performance on Norm Bank compared with GPT-3 baselines. (b) The effect of the size of the base T5 model. (c) Ablation results showing that the scale of the training data improves Delphi’s learning. (d) Ablation results showing that the compositionality of training instances improves Delphi’s learning. (e) Delphi, with minimal supervision, outperforms baseline models on hate speech detection under both in-distribution and out-of-distribution settings. (f) Plugging Delphi into language generation models improves the prosocial implication scores of the generated stories without sacrificing language quality. (g) Delphi outperforms other baselines at transferring knowledge to specific, theoretically motivated moral frameworks. 67 | 68 | 69 | 70 |
71 | 72 | 73 |
74 | Social Bias Evaluation Results of Delphi 75 | 76 | > (a) Results for the Universal Declaration of Human Rights probing, including the top identities that Delphi shows bias against, their levels of bias, and the average % error for each identity group. (b) Delphi’s and Delphi+’s performance under the current-world and ideal-world settings. Statistical significance tests are performed between Delphi under the current-world setting and the other models or settings. 77 | 78 | 79 | 80 |
81 | 82 | 83 | ## An Illustration of the Delphi-Hybrid Framework and an Example Output Moral Constraint Graph 84 | 85 | 86 | 92 | 93 | 94 | 95 | 96 | > (a) A hybrid system that incorporates an optional symbolically guided reasoning mechanism to complement the neural-language-model-based Delphi. (b) An example of the moral constraint graph produced by Delphi-Hybrid for the event “Mass genocide for greater good.” Nodes denote judgments derived either from top-down moral principles or from the bottom-up Delphi. Edges denote logical violations (i.e., identity, entailment, and contradiction) between nodes. ❌ denotes inconsistent nodes identified by the constrained optimization step. Note that each top-down moral principle may result in multiple nodes depending on engineering details (e.g., the same rule “Do not kill” applied at the full event level or at the constituent level). The final judgment is negative. 97 | 98 | 99 | ## Codebase Structure 100 | 101 | This codebase contains the training and evaluation code for Delphi, Delphi+, and the Delphi-Hybrid system. 102 | 103 | ### Delphi: 104 | 105 | - `src/delphi/evaluate`: scripts for evaluating the Delphi models on the yes/no QA and free-form QA tasks, as well as on downstream tasks such as hate speech detection. 106 | 107 | - `src/delphi/train`: scripts for fine-tuning T5 for Delphi. 108 | 109 | ### Delphi+: 110 | 111 | - `src/delphi_plus/evaluate`: scripts for evaluating the Delphi+ models on the yes/no QA and free-form QA tasks, as well as on downstream tasks such as hate speech detection. 112 | 113 | - `src/delphi_plus/train`: scripts for fine-tuning T5 for Delphi+. 114 | 115 | ### Delphi-Hybrid: 116 | 117 | - `src/delphi_hybrid/collective_reasoning`: codebase for the collective reasoning component of Delphi-Hybrid. 118 | 119 | - `src/delphi_hybrid/components`: components of the Delphi-Hybrid system. 120 | 121 | - `src/delphi_hybrid/prepare_data`: scripts for preparing test data for the Delphi-Hybrid experiments. 122 | 123 | ### Datasheet: 124 | 125 | - `data/datasheet.md`: the datasheet for the Commonsense Norm Bank dataset. 126 | 127 | ## Data and Model Access 128 | You can access the Commonsense Norm Bank dataset by [filling out this form](https://forms.gle/VoAVuPUJFNChWhSj8). 129 | 130 | To access the Delphi model checkpoints or the API, please reach out to Liwei Jiang at [lwjiang@cs.washington.edu](mailto:lwjiang@cs.washington.edu). 131 | 132 | If you find our paper or data useful, please cite the paper: 133 | ``` 134 | @article{jiang2025delphi, 135 | author = {Liwei Jiang and Jena D. Hwang and Chandra Bhagavatula and Ronan Le Bras and Jenny T.
Liang and Sydney Levine and Jesse Dodge and Keisuke Sakaguchi and Maxwell Forbes and Jack Hessel and Jon Borchardt and Taylor Sorensen and Saadia Gabriel and Yulia Tsvetkov and Oren Etzioni and Maarten Sap and Regina Rini and Yejin Choi}, 136 | title = {Investigating machine moral judgement through the Delphi experiment}, 137 | journal = {Nature Machine Intelligence}, 138 | volume = {7}, 139 | pages = {145--160}, 140 | year = {2025}, 141 | doi = {10.1038/s42256-024-00700-y} 142 | } 143 | 144 | ``` 145 | 146 | -------------------------------------------------------------------------------- /assets/bias_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/bias_results.png -------------------------------------------------------------------------------- /assets/delphi_hybrid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/delphi_hybrid.png -------------------------------------------------------------------------------- /assets/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/examples.png -------------------------------------------------------------------------------- /assets/main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/main_results.png -------------------------------------------------------------------------------- /assets/norm_bank_content.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/norm_bank_content.png -------------------------------------------------------------------------------- /assets/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liweijiang/delphi/cc7130578fb10f4db122ac537c1b387bc02b7837/assets/overall.png -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_dynahate_v9.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score 2 | import sys 3 | import numpy as np 4 | sys.path.append("script/evaluate") 5 | 6 | from evaluate_utils import * 7 | 8 | def get_gold_class(round_id, data_split): 9 | """ 10 | Get gold inputs and targets class labels; format (1 ) 11 | """ 12 | if data_split == "validation": 13 | data_split = "dev" 14 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/dynahate/{round_id}/{data_split}.tsv" 15 | df_inputs = pd.read_csv(data_base_path, sep="\t") 16 | 17 | inputs_all = list(df_inputs["inputs"]) 18 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 19 | 20 | targets_all = list(df_inputs["targets"]) 21 | targets = [int(i.split(" ") 22 | [0].split("")[-1]) for i in targets_all] 23 | return inputs, targets 24 | 25 | 26 | def get_pred_class(bucket, base_path, task_name, check_point, base_model): 27 | """ 28 | Get preds class labels 29 | """ 30 | preds_blob = bucket.get_blob( 31 | base_path 
+ f"{task_name}_{check_point}_predictions") 32 | preds_blob_list = preds_blob.download_as_string().decode( 33 | 'utf-8').split("\n")[1:] 34 | 35 | preds_class = [] 36 | for i in preds_blob_list: 37 | try: 38 | if "v10" in base_model or "v11" in base_model: 39 | preds_class.append( 40 | int(i.split("[/class]")[0].split("[class]")[-1])) 41 | else: 42 | preds_class.append( 43 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1])) 44 | except: 45 | print("output form not identifiable:", i) 46 | preds_class.append(99) 47 | return preds_class 48 | 49 | 50 | def main_get_accuracy(base_path, data_split, training_data_type, round_ids=None, check_points=None, 51 | is_print=False, all_list=None, is_save_results=True): 52 | base_path += f"{data_split}_eval/" 53 | bucket_name = base_path.split("/")[0] 54 | result_prefix = "/".join(base_path.split("/")[1:]) 55 | base_path = "/".join(base_path.split("/")[1:]) 56 | base_model = base_path.split("/")[-5] 57 | training_data = base_path.split("/")[-4] 58 | 59 | client = storage.Client() 60 | bucket = client.get_bucket(bucket_name) 61 | 62 | if base_model == "v10-delphi": 63 | base_path = base_path.replace(training_data, "new_" + training_data) 64 | 65 | if check_points == None: 66 | check_points = get_check_points( 67 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 68 | # check_points.sort(reverse=True) 69 | for check_point in check_points: 70 | all_inputs = [] 71 | all_targets = [] 72 | all_preds = [] 73 | all_f1s = [] 74 | all_accs = [] 75 | all_round_ids = [] 76 | for round_id in round_ids: 77 | if training_data_type == "": 78 | task_name = f"dynahate_round_{round_id}" 79 | else: 80 | task_name = f"dynahate_round_{round_id}_{training_data_type}" 81 | 82 | if base_model == "v10-delphi": 83 | task_name = "new_" + task_name 84 | 85 | inputs, targets = get_gold_class(round_id, data_split) 86 | preds = get_pred_class( 87 | bucket, base_path, task_name, check_point, base_model) 88 | if base_model == "v10-delphi": 89 | preds_new = [t - 1 for t in preds] 90 | preds = preds_new 91 | 92 | all_round_ids += [round_id] * len(preds) 93 | all_inputs += inputs 94 | all_targets += targets 95 | all_preds += preds 96 | 97 | f1 = f1_score(targets, preds, average='macro') 98 | acc = get_accuracy(targets, preds, accuracy_type="exact") 99 | 100 | all_f1s.append(f1) 101 | all_accs.append(acc) 102 | if is_print: 103 | print( 104 | f"round ({round_id}) {check_point}: f1 -- {f1} | accuracy -- {acc}") 105 | 106 | f1 = f1_score(all_targets, all_preds, average='macro') 107 | acc = get_accuracy(all_targets, all_preds, accuracy_type="exact") 108 | print(f"round (all) {check_point}: f1 -- {f1} | accuracy -- {acc}") 109 | 110 | if is_save_results: 111 | df_data = pd.DataFrame() 112 | df_data["round_id"] = all_round_ids 113 | df_data["input"] = all_inputs 114 | df_data["target"] = all_targets 115 | df_data["pred"] = all_preds 116 | 117 | df_data.to_csv( 118 | f"{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t") 119 | 120 | return all_list 121 | 122 | 123 | def main_r1_st(training_data): 124 | base_model_2_ckpt = { 125 | "dynahate_round_1_st": {"v11-delphi-declare": [1290200], 126 | "v10-delphi": [1323500], 127 | "v9-delphi": [1341200], 128 | "v9-delphi-new": [1264700], 129 | "unicorn-pt": [1127000], 130 | "11B": [1102000]}, 131 | 132 | "dynahate_round_1_st_100_shot": {"v11-delphi-declare": [1290200], 133 | "v10-delphi": [1349000], 134 | "v9-delphi": [1325900], 135 | "v9-delphi-new": [1285100], 136 | "unicorn-pt": [1106600], 137 | "11B": [1076500]}, 138 | 
} 139 | 140 | print(training_data) 141 | # "v11-delphi-declare", "v10-delphi" "v10-delphi", "v9-delphi", "unicorn-pt", "11B" 142 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]: 143 | print("=" * 10, base_model) 144 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 145 | 146 | check_points = base_model_2_ckpt[training_data][base_model] 147 | main_get_accuracy(base_path, "validation", "st", [1], check_points) 148 | main_get_accuracy(base_path, "test", "st", [ 149 | 1, 2, 3, 4], check_points, is_print=True) 150 | # 151 | 152 | 153 | def main_all_st(training_data): 154 | base_model_2_ckpt = { 155 | "dynahate_all_st": {"v11-delphi-declare": [1290200], # 1290200 156 | "v10-delphi": [1292900], 157 | "v9-delphi": [1402400], 158 | "v9-delphi-new": [1387100], 159 | "unicorn-pt": [1157600], 160 | "11B": [1132600]}, 161 | 162 | "dynahate_all_st_100_shot": {"v11-delphi-declare": [1331000], 163 | "v10-delphi": [1354100], 164 | "v9-delphi": [1392200], 165 | "v9-delphi-new": [1407500], 166 | "unicorn-pt": [1147400], 167 | "11B": [1076500]}, 168 | } 169 | 170 | all_list = [] 171 | 172 | print(training_data) 173 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]: 174 | print("=" * 10, base_model) 175 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 176 | # check_points = None 177 | check_points = base_model_2_ckpt[training_data][base_model] 178 | main_get_accuracy(base_path, "validation", "st", [ 179 | 1, 2, 3, 4], check_points, is_print=False) 180 | all_list = main_get_accuracy(base_path, "test", "st", [ 181 | 1, 2, 3, 4], check_points, is_print=True, all_list=all_list) 182 | 183 | return all_list 184 | 185 | 186 | if __name__ == "__main__": 187 | main_r1_st("dynahate_round_1_st_100_shot") 188 | main_all_st("dynahate_all_st_100_shot") 189 | -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_dynahate_yesno.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score 2 | import sys 3 | import numpy as np 4 | sys.path.append("script/evaluate") 5 | from evaluate_utils import * 6 | 7 | 8 | def get_gold_class(round_id, data_split, training_data_type): 9 | """ 10 | Get gold inputs and targets class labels; format (1 ) 11 | """ 12 | if data_split == "validation": 13 | data_split = "dev" 14 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/dynahate/{round_id}_{training_data_type}/{data_split}.tsv" 15 | df_inputs = pd.read_csv(data_base_path, sep="\t") 16 | 17 | inputs_all = list(df_inputs["inputs"]) 18 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 19 | 20 | targets_all = list(df_inputs["targets"]) 21 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) 22 | for i in targets_all] 23 | return inputs, targets 24 | 25 | 26 | def get_pred_class(bucket, base_path, task_name, check_point, base_model): 27 | """ 28 | Get preds class labels 29 | """ 30 | preds_blob = bucket.get_blob( 31 | base_path + f"{task_name}_{check_point}_predictions") 32 | preds_blob_list = preds_blob.download_as_string().decode( 33 | 'utf-8').split("\n")[1:] 34 | 35 | preds_class = [] 36 | for i in preds_blob_list: 37 | try: 38 | preds_class.append( 39 | int(i.split("[/class]")[0].split("[class]")[-1])) 40 | except: 41 | 
print("output form not identifiable:", i) 42 | preds_class.append(99) 43 | return preds_class 44 | 45 | 46 | def main_get_accuracy(base_path, data_split, training_data_type, round_ids=None, check_points=None, is_print=False): 47 | base_path += f"{data_split}_eval/" 48 | bucket_name = base_path.split("/")[0] 49 | result_prefix = "/".join(base_path.split("/")[1:]) 50 | base_path = "/".join(base_path.split("/")[1:]) 51 | base_model = base_path.split("/")[-5] 52 | training_data = base_path.split("/")[-4] 53 | 54 | client = storage.Client() 55 | bucket = client.get_bucket(bucket_name) 56 | 57 | if check_points == None: 58 | check_points = get_check_points( 59 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 60 | 61 | for check_point in check_points: # check_points 62 | all_inputs = [] 63 | all_targets = [] 64 | all_preds = [] 65 | all_f1s = [] 66 | all_accs = [] 67 | for round_id in round_ids: 68 | task_name = f"dynahate_round_{round_id}_{training_data_type}" 69 | 70 | inputs, targets = get_gold_class( 71 | round_id, data_split, training_data_type) 72 | preds = get_pred_class( 73 | bucket, base_path, task_name, check_point, base_model) 74 | targets, preds = remove_unknown_elements(targets, preds) 75 | f1 = f1_score(targets, preds, average='macro') 76 | acc = get_accuracy(targets, preds, accuracy_type="exact") 77 | 78 | all_inputs += inputs 79 | all_targets += targets 80 | all_preds += preds 81 | all_f1s.append(f1) 82 | all_accs.append(acc) 83 | if is_print: 84 | print( 85 | f"round ({round_id}) {check_point}: f1 -- {f1} | accuracy -- {acc}") 86 | 87 | f1 = f1_score(all_targets, all_preds, average='macro') 88 | acc = get_accuracy(all_targets, all_preds, accuracy_type="exact") 89 | print( 90 | f"round {str(round_ids)} {check_point}: f1 -- {f1} | accuracy -- {acc}") 91 | 92 | 93 | def main_r1_yesno(training_data, training_data_type): 94 | base_model_2_ckpt = { 95 | "dynahate_round_1_yesno": {"v11-delphi-declare": [1280000], 96 | "v10-delphi": [], 97 | "v9-delphi": [], 98 | "unicorn-pt": [], 99 | "11B": []}, 100 | 101 | "dynahate_round_1_yesno_100_shot": {"v11-delphi-declare": [1290200], 102 | "v10-delphi": [], 103 | "v9-delphi": [], 104 | "unicorn-pt": [], 105 | "11B": []}, 106 | 107 | "dynahate_round_1_yesno_class_only": {"v11-delphi-declare": [1290200], 108 | "v10-delphi": [], 109 | "v9-delphi": [], 110 | "unicorn-pt": [1132100], 111 | "11B": []}, 112 | 113 | "dynahate_round_1_yesno_class_only_100_shot": {"v11-delphi-declare": [1234100], 114 | "v10-delphi": [], 115 | "v9-delphi": [], 116 | "unicorn-pt": [1035200], 117 | "11B": []}, 118 | 119 | "dynahate_round_1_discriminate": {"v11-delphi-declare": [1331000], 120 | "v10-delphi": [], 121 | "v9-delphi": [], 122 | "unicorn-pt": [], 123 | "11B": []}, 124 | 125 | "dynahate_round_1_discriminate_100_shot": {"v11-delphi-declare": [1269800], 126 | "v10-delphi": [], 127 | "v9-delphi": [], 128 | "unicorn-pt": [1086200], 129 | "11B": []}, 130 | } 131 | 132 | print(training_data) 133 | # "v11-delphi-declare", "v10-delphi" "v10-delphi", "v9-delphi", "unicorn-pt", "11B" 134 | for base_model in ["v11-delphi-declare", "unicorn-pt"]: 135 | print("=" * 10, base_model) 136 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 137 | check_points = None 138 | main_get_accuracy(base_path, "validation", training_data_type, round_ids=[ 139 | 1], check_points=check_points, is_print=False) 140 | 141 | 142 | def main_all_yesno(training_data, training_data_type): 143 | 
base_model_2_ckpt = { 144 | "dynahate_all_yesno": {"v11-delphi-declare": [], 145 | "v10-delphi": [], 146 | "v9-delphi": [], 147 | "unicorn-pt": [], 148 | "11B": []}, 149 | 150 | "dynahate_all_yesno_100_shot": {"v11-delphi-declare": [1254500], 151 | "v10-delphi": [], 152 | "v9-delphi": [], 153 | "unicorn-pt": [], 154 | "11B": []}, 155 | 156 | "dynahate_all_yesno_class_only": {"v11-delphi-declare": [], 157 | "v10-delphi": [], 158 | "v9-delphi": [], 159 | "unicorn-pt": [], 160 | "11B": []}, 161 | 162 | "dynahate_all_yesno_class_only_100_shot": {"v11-delphi-declare": [1336100], 163 | "v10-delphi": [], 164 | "v9-delphi": [], 165 | "unicorn-pt": [1101500], 166 | "11B": []}, 167 | 168 | "dynahate_all_discriminate": {"v11-delphi-declare": [1351400], 169 | "v10-delphi": [], 170 | "v9-delphi": [], 171 | "unicorn-pt": [], 172 | "11B": []}, 173 | 174 | "dynahate_all_discriminate_100_shot": {"v11-delphi-declare": [1239200], 175 | "v10-delphi": [], 176 | "v9-delphi": [], 177 | "unicorn-pt": [1070900], 178 | "11B": []}, 179 | } 180 | 181 | print(training_data) 182 | for base_model in ["v11-delphi-declare", 183 | "unicorn-pt"]: 184 | print("=" * 10, base_model) 185 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 186 | check_points = None 187 | main_get_accuracy(base_path, "validation", training_data_type, round_ids=[ 188 | 1, 2, 3, 4], check_points=check_points, is_print=False) 189 | 190 | 191 | if __name__ == "__main__": 192 | main_r1_yesno() 193 | -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_ethics_converted.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | acc_type = {"ethics_deontology": 4, 6 | "ethics_justice": 4, 7 | "ethics_virtue": 5, 8 | "ethics_util": "exact", 9 | "ethics_cm": "exact"} 10 | 11 | 12 | def get_gold_class(training_data, data_split): 13 | """ 14 | Get gold inputs and targets class labels; format ([class]1[/class] [text] [/text]) 15 | """ 16 | if data_split == "validation": 17 | data_split = "test" 18 | elif data_split == "test": 19 | data_split = "test_hard" 20 | 21 | task_name = training_data.split("_")[1] 22 | 23 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/ethics/ethics_converted/{task_name}/{data_split}.tsv" 24 | df_inputs = pd.read_csv(data_base_path, sep="\t") 25 | df_inputs["targets"] = df_inputs["targets"].astype(str) 26 | 27 | inputs_all = list(df_inputs["inputs"]) 28 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 29 | 30 | targets_all = list(df_inputs["targets"]) 31 | targets = [int(i.split("[/class]")[0].split("[class]")[-1]) 32 | for i in targets_all] 33 | return inputs, targets 34 | 35 | 36 | def get_pred_class(bucket, base_path, training_data, check_point, base_model): 37 | """ 38 | Get preds class labels 39 | """ 40 | preds_blob = bucket.get_blob( 41 | base_path + f"{training_data}_{check_point}_predictions") 42 | preds_blob_list = preds_blob.download_as_string().decode( 43 | 'utf-8').split("\n")[1:] 44 | 45 | preds_class = [] 46 | for i in preds_blob_list: 47 | try: 48 | if "v10" in base_model or "v11" in base_model or "unicorn-pt" in base_model: 49 | preds_class.append( 50 | int(i.split("[/class]")[0].split("[class]")[-1])) 51 | else: 52 | preds_class.append( 53 | int(i.split(" ⁇ 
/class>")[0].split(" ⁇ class>")[-1])) 54 | except: 55 | print("output form not identifiable:", i) 56 | preds_class.append(99) 57 | return preds_class 58 | 59 | 60 | def get_ethics_accuracy(targets, preds, accuracy_type="exact"): 61 | accuracies = [] 62 | for i in range(len(targets)): 63 | t_c = targets[i] 64 | p_c = preds[i] 65 | if t_c != 99 and p_c != 99: 66 | if accuracy_type == "exact": 67 | accuracies.append(int(t_c == p_c)) 68 | elif accuracy_type == "non conflict": 69 | accuracies.append(int(t_c * p_c >= 0)) 70 | else: 71 | accuracies.append( 72 | not int((t_c == -1 or p_c == -1) and (t_c * p_c != 1))) 73 | 74 | if type(accuracy_type) == type(0): 75 | group_acc = [] 76 | for i in range(0, len(accuracies), accuracy_type): 77 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3]) == 1: 78 | group_acc.append(1) 79 | else: 80 | group_acc.append(0) 81 | return sum(group_acc) / len(group_acc) 82 | else: 83 | return sum(accuracies) / len(accuracies) 84 | 85 | 86 | def main_get_accuracy(base_path, data_split, check_points=None): 87 | base_model = base_path.split("/")[-4] 88 | training_data = base_path.split("/")[-3] 89 | if base_model == "v10-delphi": 90 | base_path = base_path.replace(training_data, "new_" + training_data) 91 | 92 | base_path += f"{data_split}_eval/" 93 | bucket_name = base_path.split("/")[0] 94 | result_prefix = "/".join(base_path.split("/")[1:]) 95 | base_path = "/".join(base_path.split("/")[1:]) 96 | training_data = base_path.split("/")[-4] 97 | if "100_shot" in training_data: 98 | training_data = training_data.replace("_100_shot", "") 99 | base_model = base_path.split("/")[-5] 100 | 101 | client = storage.Client() 102 | bucket = client.get_bucket(bucket_name) 103 | 104 | if check_points == None: 105 | check_points = get_check_points( 106 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 107 | 108 | for check_point in check_points: 109 | inputs, targets = get_gold_class(training_data, data_split) 110 | preds = get_pred_class( 111 | bucket, base_path, training_data, check_point, base_model) 112 | if base_model == "v10-delphi" and "util" not in training_data: 113 | preds_new = [t - 1 for t in preds] 114 | preds = preds_new 115 | 116 | index_to_remove = [] 117 | for i, p in enumerate(preds): 118 | if p not in [-1, 1]: 119 | index_to_remove.append(i) 120 | print("remove:", i, p) 121 | targets = np.delete(targets, index_to_remove).tolist() 122 | preds = np.delete(preds, index_to_remove).tolist() 123 | 124 | acc = get_ethics_accuracy( 125 | targets, preds, accuracy_type=acc_type["_".join(training_data.split("_")[:2])]) 126 | print(f"{check_point}: accuracy -- {acc}") 127 | 128 | 129 | def main_all(training_data): 130 | print("-" * 30, f"{training_data}", "-" * 30) 131 | base_model_2_ckpt = {"ethics_cm": {"v10-delphi": [], 132 | "v9-delphi": [1448300], 133 | "unicorn-pt": [1213700], 134 | "11B": []}, 135 | 136 | "ethics_deontology": {"v10-delphi": [], 137 | "v9-delphi": [], 138 | "unicorn-pt": [], 139 | "11B": []}, 140 | 141 | "ethics_justice": {"v10-delphi": [], 142 | "v9-delphi": [1356500], 143 | "unicorn-pt": [1137200], 144 | "11B": []}, 145 | 146 | "ethics_virtue": {"v10-delphi": [], 147 | "v9-delphi": [1356500], 148 | "unicorn-pt": [1050500], 149 | "11B": []}, 150 | 151 | "ethics_util": {"v10-delphi": [], 152 | "v9-delphi": [1351400], 153 | "unicorn-pt": [1035200], 154 | "11B": []}, 155 | 156 | 157 | "ethics_cm_100_shot": {"v10-delphi": [], 158 | "v9-delphi": [1315700], 159 | "unicorn-pt": [1055600], 160 | "11B": []}, 161 | 162 | 
"ethics_deontology_100_shot": {"v10-delphi": [], 163 | "v9-delphi": [1274900], 164 | "unicorn-pt": [1030100], 165 | "11B": []}, 166 | 167 | "ethics_justice_100_shot": {"v10-delphi": [], 168 | "v9-delphi": [1290200], 169 | "unicorn-pt": [1065800], 170 | "11B": []}, 171 | 172 | "ethics_virtue_100_shot": {"v10-delphi": [], 173 | "v9-delphi": [1295300], 174 | "unicorn-pt": [1060700], 175 | "11B": []}, 176 | 177 | "ethics_util_100_shot": {"v10-delphi": [], 178 | "v9-delphi": [1295300], 179 | "unicorn-pt": [1076000], 180 | "11B": []}, 181 | } 182 | 183 | for base_model in ["v11-delphi-declare", "unicorn-pt"]: 184 | print("=" * 10, base_model) 185 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 186 | check_points = None 187 | main_get_accuracy(base_path, "validation", check_points) 188 | 189 | 190 | if __name__ == "__main__": 191 | main_all("ethics_cm_converted_class_only") 192 | main_all("ethics_deontotlogy_converted_class_only") 193 | main_all("ethics_justice_converted_class_only") 194 | main_all("ethics_virtue_converted_class_only") 195 | -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_ethics_v9.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | 6 | acc_type = {"ethics_deontology": 4, 7 | "ethics_justice": 4, 8 | "ethics_virtue": 5, 9 | "ethics_util": "exact", 10 | "ethics_cm": "exact"} 11 | 12 | 13 | def get_gold_class(training_data, data_split): 14 | """ 15 | Get gold inputs and targets class labels; format (1 ) 16 | """ 17 | if data_split == "validation": 18 | data_split = "test" 19 | elif data_split == "test": 20 | data_split = "test_hard" 21 | 22 | task_name = training_data.split("_")[-1] 23 | 24 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/ethics/ethics_st/{task_name}/{data_split}.tsv" 25 | df_inputs = pd.read_csv(data_base_path, sep="\t") 26 | df_inputs["targets"] = df_inputs["targets"].astype(str) 27 | 28 | inputs_all = list(df_inputs["inputs"]) 29 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 30 | 31 | targets_all = list(df_inputs["targets"]) 32 | targets = [int(i.split("")[0].split("")[-1]) 33 | for i in targets_all] 34 | return inputs, targets 35 | 36 | 37 | def get_pred_class(bucket, base_path, training_data, check_point, base_model): 38 | """ 39 | Get preds class labels 40 | """ 41 | preds_blob = bucket.get_blob( 42 | base_path + f"{training_data}_{check_point}_predictions") 43 | preds_blob_list = preds_blob.download_as_string().decode( 44 | 'utf-8').split("\n")[1:] 45 | 46 | preds_class = [] 47 | for i in preds_blob_list: 48 | try: 49 | if "v10" in base_model: 50 | preds_class.append( 51 | int(i.split("[/class]")[0].split("[class]")[-1])) 52 | else: 53 | preds_class.append( 54 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1])) 55 | except: 56 | print("output form not identifiable:", i) 57 | preds_class.append(99) 58 | return preds_class 59 | 60 | 61 | def get_ethics_accuracy(targets, preds, accuracy_type="exact"): 62 | accuracies = [] 63 | for i in range(len(targets)): 64 | t_c = targets[i] 65 | p_c = preds[i] 66 | if t_c != 99 and p_c != 99: 67 | if accuracy_type == "exact": 68 | accuracies.append(int(t_c == p_c)) 69 | elif accuracy_type == "non conflict": 70 | accuracies.append(int(t_c * p_c >= 0)) 71 | else: 
72 | accuracies.append( 73 | not int((t_c == -1 or p_c == -1) and (t_c * p_c != 1))) 74 | 75 | if accuracy_type == 4: 76 | group_acc = [] 77 | for i in range(0, len(accuracies), accuracy_type): 78 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3]) == 1: 79 | group_acc += [1] * accuracy_type 80 | else: 81 | group_acc += [0] * accuracy_type 82 | return sum(group_acc) / len(group_acc) 83 | elif accuracy_type == 5: 84 | group_acc = [] 85 | for i in range(0, len(accuracies), accuracy_type): 86 | if (accuracies[i] * accuracies[i + 1] * accuracies[i + 2] * accuracies[i + 3] * accuracies[i + 4]) == 1: 87 | group_acc += [1] * accuracy_type 88 | else: 89 | group_acc += [0] * accuracy_type 90 | return sum(group_acc) / len(group_acc) 91 | else: 92 | return sum(accuracies) / len(accuracies) 93 | 94 | 95 | def main_get_accuracy(base_path, data_split, check_points=None, is_save_results=True): 96 | base_model = base_path.split("/")[-4] 97 | training_data = base_path.split("/")[-3] 98 | if base_model == "v10-delphi": 99 | base_path = base_path.replace(training_data, "new_" + training_data) 100 | 101 | base_path += f"{data_split}_eval/" 102 | bucket_name = base_path.split("/")[0] 103 | result_prefix = "/".join(base_path.split("/")[1:]) 104 | base_path = "/".join(base_path.split("/")[1:]) 105 | training_data = base_path.split("/")[-4] 106 | if "100_shot" in training_data: 107 | training_data = training_data.replace("_100_shot", "") 108 | base_model = base_path.split("/")[-5] 109 | 110 | client = storage.Client() 111 | bucket = client.get_bucket(bucket_name) 112 | 113 | if check_points == None: 114 | check_points = get_check_points( 115 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 116 | 117 | for check_point in check_points: 118 | inputs, targets = get_gold_class(training_data, data_split) 119 | preds = get_pred_class( 120 | bucket, base_path, training_data, check_point, base_model) 121 | if base_model == "v10-delphi" and "util" not in training_data: 122 | preds_new = [t - 1 for t in preds] 123 | preds = preds_new 124 | index_to_remove = [] 125 | for i, p in enumerate(preds): 126 | if p not in [-1, 1, 2]: 127 | print("bad:", i, p) 128 | 129 | acc = get_ethics_accuracy( 130 | targets, preds, accuracy_type=acc_type[training_data.replace("new_", "")]) 131 | print(f"{check_point}: accuracy -- {acc}") 132 | 133 | if is_save_results: 134 | df_data = pd.DataFrame() 135 | df_data["input"] = inputs 136 | df_data["target"] = targets 137 | df_data["pred"] = preds 138 | 139 | df_data.to_csv( 140 | f"100_shot-{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t") 141 | 142 | 143 | def main_all(training_data): 144 | print("-" * 30, f"{training_data}", "-" * 30) 145 | base_model_2_ckpt = {"ethics_cm": {"v10-delphi": [], 146 | "v9-delphi": [1448300], 147 | "v9-delphi-new": [1290200], 148 | "unicorn-pt": [1213700], 149 | "11B": [1153000]}, 150 | 151 | "ethics_deontology": {"v10-delphi": [], 152 | "v9-delphi": [1361600], 153 | "v9-delphi-new": [1315700], 154 | "unicorn-pt": [1157600], 155 | "11B": [1387600]}, 156 | 157 | "ethics_justice": {"v10-delphi": [], 158 | "v9-delphi": [1356500], 159 | "v9-delphi-new": [1249400], 160 | "unicorn-pt": [1137200], 161 | "11B": [1112200]}, 162 | 163 | "ethics_virtue": {"v10-delphi": [], 164 | "v9-delphi": [1356500], 165 | "v9-delphi-new": [1336100], 166 | "unicorn-pt": [1050500], 167 | "11B": [1040800]}, 168 | 169 | "ethics_util": {"v10-delphi": [], 170 | "v9-delphi": [1351400], 171 | "v9-delphi-new": [1249400], 172 | 
"unicorn-pt": [1035200], 173 | "11B": [1015300]}, 174 | 175 | 176 | "ethics_cm_100_shot": {"v10-delphi": [], 177 | "v9-delphi": [1315700], 178 | "v9-delphi-new": [1264700], 179 | "unicorn-pt": [1055600], 180 | "11B": [1112200]}, 181 | 182 | "ethics_deontology_100_shot": {"v10-delphi": [], 183 | "v9-delphi": [1274900], 184 | "v9-delphi-new": [1315700], 185 | "unicorn-pt": [1030100], 186 | "11B": [1275400]}, 187 | 188 | "ethics_justice_100_shot": {"v10-delphi": [], 189 | "v9-delphi": [1290200], 190 | "v9-delphi-new": [1331000], 191 | "unicorn-pt": [1065800], 192 | "11B": [1504900]}, 193 | 194 | "ethics_virtue_100_shot": {"v10-delphi": [], 195 | "v9-delphi": [1295300], 196 | "v9-delphi-new": [1264700], 197 | "unicorn-pt": [1060700], 198 | "11B": [1020400]}, 199 | 200 | "ethics_util_100_shot": {"v10-delphi": [], 201 | "v9-delphi": [1295300], 202 | "v9-delphi-new": [1341200], 203 | "unicorn-pt": [1076000], 204 | "11B": [1275400]}, 205 | } 206 | 207 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]: 208 | print("=" * 10, base_model) 209 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 210 | check_points = base_model_2_ckpt[training_data][base_model] 211 | main_get_accuracy(base_path, "validation", check_points) 212 | main_get_accuracy(base_path, "test", check_points) 213 | 214 | 215 | if __name__ == "__main__": 216 | main_all("ethics_cm_100_shot") 217 | main_all("ethics_deontology_100_shot") 218 | main_all("ethics_justice_100_shot") 219 | main_all("ethics_virtue_100_shot") 220 | main_all("ethics_util_100_shot") 221 | -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_latenthatred_v9.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import precision_recall_fscore_support 2 | import sys 3 | import numpy as np 4 | sys.path.append("script/evaluate") 5 | from evaluate_utils import * 6 | 7 | 8 | def get_gold_class(data_split): 9 | """ 10 | Get gold inputs and targets class labels; format (1 ) 11 | """ 12 | if data_split == "validation": 13 | data_split = "dev" 14 | 15 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v9_downstream/latenthatred/latenthatred_st/{data_split}.tsv" 16 | df_inputs = pd.read_csv(data_base_path, sep="\t") 17 | df_inputs["targets"] = df_inputs["targets"].astype(str) 18 | 19 | inputs_all = list(df_inputs["inputs"]) 20 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 21 | 22 | targets_all = list(df_inputs["targets"]) 23 | targets = [int(i.split("")[0].split("")[-1]) 24 | for i in targets_all] 25 | return inputs, targets 26 | 27 | 28 | def get_pred_class(bucket, base_path, training_data, check_point, base_model): 29 | """ 30 | Get preds class labels 31 | """ 32 | preds_blob = bucket.get_blob( 33 | base_path + f"{training_data}_{check_point}_predictions") 34 | preds_blob_list = preds_blob.download_as_string().decode( 35 | 'utf-8').split("\n")[1:] 36 | 37 | preds_class = [] 38 | for i in preds_blob_list: 39 | try: 40 | if "v10" in base_model or "v11" in base_model: 41 | preds_class.append( 42 | int(i.split("[/class]")[0].split("[class]")[-1])) 43 | else: 44 | preds_class.append( 45 | int(i.split(" ⁇ /class>")[0].split(" ⁇ class>")[-1])) 46 | except: 47 | print("output form not identifiable:", i) 48 | preds_class.append(99) 49 | return preds_class 50 | 51 | 52 | def main_get_accuracy(base_path, 
data_split, check_points=None, is_save_results=True): 53 | base_model = base_path.split("/")[-4] 54 | training_data = base_path.split("/")[-3] 55 | if base_model == "v10-delphi": 56 | base_path = base_path.replace(training_data, "new_" + training_data) 57 | task_name = "new_latenthatred" 58 | else: 59 | task_name = "latenthatred" 60 | 61 | base_path += f"{data_split}_eval/" 62 | bucket_name = base_path.split("/")[0] 63 | result_prefix = "/".join(base_path.split("/")[1:]) 64 | base_path = "/".join(base_path.split("/")[1:]) 65 | # training_data = base_path.split("/")[-4] 66 | base_model = base_path.split("/")[-5] 67 | 68 | client = storage.Client() 69 | bucket = client.get_bucket(bucket_name) 70 | 71 | if check_points == None: 72 | check_points = get_check_points( 73 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 74 | 75 | for check_point in check_points: 76 | inputs, targets = get_gold_class(data_split) 77 | preds = get_pred_class( 78 | bucket, base_path, task_name, check_point, base_model) 79 | if base_model == "v10-delphi": 80 | preds_new = [t - 1 for t in preds] 81 | preds = preds_new 82 | 83 | if is_save_results: 84 | df_data = pd.DataFrame() 85 | df_data["input"] = inputs 86 | df_data["target"] = targets 87 | df_data["pred"] = preds 88 | 89 | df_data.to_csv( 90 | f"zero_shot-{training_data}-{base_model}-{data_split}.tsv", index=False, sep="\t") 91 | 92 | 93 | def main_all_0001(training_data): 94 | print("-" * 30, f"{training_data}", "-" * 30) 95 | base_model_2_ckpt = {"latenthatred": { 96 | "v10-delphi": [], 97 | "v9-delphi": [], 98 | "unicorn-pt": [], 99 | "11B": [], }, 100 | 101 | "latenthatred_100_shot": { 102 | "v10-delphi": [], 103 | "v9-delphi": [], 104 | "unicorn-pt": [], 105 | "11B": [], } 106 | } 107 | 108 | for base_model in ["v9-delphi", "unicorn-pt"]: 109 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0001_bs-16/" 110 | check_points = None 111 | main_get_accuracy(base_path, "validation", check_points) 112 | 113 | 114 | def main_zero_shot(training_data): 115 | print("-" * 30, f"{training_data}", "-" * 30) 116 | base_model_2_ckpt = {"dynahate_round_1_st": {"v10-delphi": [], 117 | "v9-delphi": [1341200], 118 | "v9-delphi-new": [1264700], 119 | "unicorn-pt": [1127000], 120 | "11B": [1102000]}, 121 | 122 | "dynahate_all_st": {"v10-delphi": [], 123 | "v9-delphi": [1402400], 124 | "v9-delphi-new": [1387100], 125 | "unicorn-pt": [1157600], 126 | "11B": [1132600]}, 127 | 128 | "dynahate_round_1_st_100_shot": {"v10-delphi": [], 129 | "v9-delphi": [1325900], 130 | "unicorn-pt": [1106600], 131 | "11B": [1076500]}, 132 | 133 | "dynahate_all_st_100_shot": {"v10-delphi": [], 134 | "v9-delphi": [1392200], 135 | "unicorn-pt": [1147400], 136 | "11B": [1076500]} 137 | } 138 | 139 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]: 140 | print("=" * 10, base_model) 141 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 142 | check_points = base_model_2_ckpt[training_data][base_model] 143 | main_get_accuracy(base_path, "test", check_points) 144 | 145 | 146 | def main_all(training_data): 147 | print("-" * 30, f"{training_data}", "-" * 30) 148 | base_model_2_ckpt = {"latenthatred": {"v11-delphi-declare": [1356500], 149 | "v10-delphi": [1318400], 150 | "v9-delphi": [1397300], 151 | "v9-delphi-new": [1300400], 152 | "unicorn-pt": [1157600], 153 | "11B": [1071400], }, 154 | 155 | 
"latenthatred_100_shot": {"v11-delphi-declare": [1274900], 156 | "v10-delphi": [1328600], 157 | "v9-delphi": [1285100], 158 | "v9-delphi-new": [1259600], 159 | "unicorn-pt": [1055600], 160 | "11B": [1102000], } 161 | } 162 | 163 | for base_model in ["v9-delphi-new", "unicorn-pt", "11B"]: 164 | print("=" * 10, base_model) 165 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 166 | check_points = base_model_2_ckpt[training_data][base_model] 167 | main_get_accuracy(base_path, "test", check_points) 168 | 169 | 170 | if __name__ == "__main__": 171 | main_zero_shot("dynahate_round_1_st") 172 | -------------------------------------------------------------------------------- /src/delphi/evaluate/evaluate_latenthatred_yesno.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import precision_recall_fscore_support 2 | import sys 3 | import numpy as np 4 | sys.path.append("script/evaluate") 5 | from evaluate_utils import * 6 | 7 | 8 | def get_gold_class(data_split): 9 | """ 10 | Get gold inputs and targets class labels; format (1 ) 11 | """ 12 | if data_split == "validation": 13 | data_split = "dev" 14 | 15 | data_base_path = f"gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/data/v11_downstream/latenthatred/latenthatred_yesno/{data_split}.tsv" 16 | df_inputs = pd.read_csv(data_base_path, sep="\t") 17 | df_inputs["targets"] = df_inputs["targets"].astype(str) 18 | 19 | inputs_all = list(df_inputs["inputs"]) 20 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 21 | 22 | targets_all = list(df_inputs["targets"]) 23 | targets = [int(i.split("[/class]")[0].split("[class]")[-1]) 24 | for i in targets_all] 25 | return inputs, targets 26 | 27 | 28 | def get_pred_class(bucket, base_path, training_data, check_point, base_model): 29 | """ 30 | Get preds class labels 31 | """ 32 | preds_blob = bucket.get_blob( 33 | base_path + f"{training_data}_{check_point}_predictions") 34 | preds_blob_list = preds_blob.download_as_string().decode( 35 | 'utf-8').split("\n")[1:] 36 | 37 | preds_class = [] 38 | for i in preds_blob_list: 39 | try: 40 | preds_class.append( 41 | int(i.split("[/class]")[0].split("[class]")[-1])) 42 | except: 43 | print("output form not identifiable:", i) 44 | preds_class.append(99) 45 | return preds_class 46 | 47 | 48 | def main_get_accuracy(base_path, data_split, check_points=None): 49 | base_model = base_path.split("/")[-4] 50 | training_data = base_path.split("/")[-3] 51 | 52 | base_path += f"{data_split}_eval/" 53 | bucket_name = base_path.split("/")[0] 54 | result_prefix = "/".join(base_path.split("/")[1:]) 55 | base_path = "/".join(base_path.split("/")[1:]) 56 | training_data = base_path.split("/")[-4] 57 | base_model = base_path.split("/")[-5] 58 | 59 | if "100_shot" in training_data: 60 | training_data = training_data.replace("_100_shot", "") 61 | 62 | client = storage.Client() 63 | bucket = client.get_bucket(bucket_name) 64 | 65 | if check_points == None: 66 | check_points = get_check_points( 67 | client, bucket_name, result_prefix, after_check_point=-1)[1:] 68 | 69 | for check_point in check_points: 70 | inputs, targets = get_gold_class(data_split) 71 | preds = get_pred_class( 72 | bucket, base_path, training_data, check_point, base_model) 73 | targets, preds = remove_unknown_elements(targets, preds) 74 | scores = precision_recall_fscore_support( 75 | targets, preds, average='binary') 76 | acc = 
get_accuracy(targets, preds, accuracy_type="exact") 77 | 78 | print(check_point, scores, acc) 79 | 80 | 81 | def main_all(training_data): 82 | print("-" * 30, f"{training_data}", "-" * 30) 83 | base_model_2_ckpt = {"latenthatred": {"v11-delphi-declare": [], 84 | "v10-delphi": [], 85 | "v9-delphi": [], 86 | "unicorn-pt": [], 87 | "11B": [], }, 88 | 89 | "latenthatred_100_shot": {"v11-delphi-declare": [], 90 | "v10-delphi": [], 91 | "v9-delphi": [], 92 | "unicorn-pt": [], 93 | "11B": [], } 94 | } 95 | 96 | for base_model in ["v11-delphi-declare", "unicorn-pt"]: 97 | print("=" * 10, base_model) 98 | base_path = f"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/finetune/{base_model}/{training_data}/lr-0.0002_bs-16/" 99 | check_points = None 100 | main_get_accuracy(base_path, "validation", check_points) 101 | 102 | 103 | if __name__ == "__main__": 104 | main_all("latenthatred_yesno") 105 | main_all("latenthatred_yesno_class_only") 106 | -------------------------------------------------------------------------------- /src/delphi/train/evaluate.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """ 4 | Evaluate the model checkpoint 5 | """ 6 | 7 | import mixtures 8 | import tasks 9 | import t5 10 | import os 11 | import sys 12 | import util 13 | import seqio 14 | import click 15 | import logging 16 | import tensorflow.compat.v1 as tf 17 | 18 | print("python", sys.version) 19 | print("t5", t5.__version__) 20 | print("tf", tf.__version__) 21 | print("seqio", seqio.__version__) 22 | 23 | tf.disable_v2_behavior() 24 | 25 | # N.B. We must import tasks and mixtures here so that they are registered and available for evaluation. 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | @click.command() 31 | @click.argument("mixture", type=str) 32 | @click.argument("results_dir", type=str) 33 | # The name of the TPU. Defaults to the TPU_NAME environment variable. 34 | @click.argument("tpu-name", type=str) 35 | # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable. 36 | @click.argument("tpu-topology", type=str) 37 | @click.argument("split", type=str) 38 | @click.argument("checkpoint", type=int) 39 | @click.option( 40 | "--model-parallelism", 41 | type=int, 42 | default=8, 43 | help="The degree of model parallelism to use. Defaults to 8.", 44 | ) 45 | def evaluate( 46 | mixture: str, 47 | results_dir: str, 48 | split: str, 49 | checkpoint: int, 50 | model_parallelism: int, 51 | tpu_name: str, 52 | tpu_topology: str, 53 | ) -> None: 54 | """ 55 | Evaluate the model located at RESULTS_DIR on MIXTURE. 
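    Example invocation (illustrative only; the bucket path and checkpoint step
    below are hypothetical placeholders, and the mixture must be registered in
    tasks.py/mixtures.py):

        python evaluate.py sbic_commonsense_morality_joint_all_proportional \
            gs://your-bucket/path/to/results $TPU_NAME v3-32 validation 1264700 \
            --model-parallelism 8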
56 | """ 57 | 58 | print(tpu_name) 59 | print(tpu_topology) 60 | 61 | # Initialize arguments 62 | if tpu_topology == "v3-32": 63 | batch_size = 16 64 | model_parallelism = 8 65 | elif tpu_topology == "v3-8": 66 | batch_size = 8 67 | model_parallelism = 8 68 | else: 69 | print("ERROR: tpu_topology invalid") 70 | return 71 | 72 | # Validate arguments 73 | util.validate_path(results_dir) 74 | 75 | checkpoints = util.get_result_check_points( 76 | results_dir, split, "ethics_virtue") 77 | checkpoints = util.get_result_check_points( 78 | results_dir, split, "latenthatred") 79 | 80 | print("-" * 10, "checkpoints todo", "-" * 10) 81 | 82 | if checkpoint == 100: 83 | checkpoints_to_eval = None 84 | elif checkpoint == 0: 85 | checkpoints_to_eval = checkpoints 86 | else: 87 | checkpoints_to_eval = [checkpoint] 88 | 89 | print(checkpoints_to_eval) 90 | 91 | # Run evaluation 92 | model = t5.models.MtfModel( 93 | model_dir=results_dir, 94 | tpu=tpu_name, 95 | tpu_topology=tpu_topology, 96 | model_parallelism=model_parallelism, 97 | batch_size=batch_size, 98 | sequence_length={"inputs": 512, "targets": 128}, 99 | learning_rate_schedule=None, 100 | save_checkpoints_steps=5000, 101 | keep_checkpoint_max=None, 102 | iterations_per_loop=100, 103 | ) 104 | 105 | model.eval( 106 | mixture_or_task_name=mixture, 107 | checkpoint_steps=checkpoints_to_eval, 108 | split=split, 109 | ) 110 | 111 | 112 | if __name__ == "__main__": 113 | evaluate() 114 | -------------------------------------------------------------------------------- /src/delphi/train/fine-tune.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """Fine-tune T5 based models.""" 4 | 5 | import mixtures 6 | import tasks 7 | import warnings 8 | import util 9 | import t5 10 | import sys 11 | import seqio 12 | import click 13 | import logging 14 | import tensorflow.compat.v1 as tf 15 | 16 | print("python", sys.version) 17 | print("t5", t5.__version__) 18 | print("tf", tf.__version__) 19 | print("seqio", seqio.__version__) 20 | 21 | # We must import tasks and mixtures here so that the tasks and mixtures are registered and available for training. 
22 | 
23 | logger = logging.getLogger(__name__)
24 | 
25 | v = tf.compat.v1.logging.FATAL
26 | tf.compat.v1.logging.set_verbosity(v)
27 | tf.disable_v2_behavior()
28 | 
29 | config = tf.ConfigProto()
30 | config.gpu_options.allow_growth = True
31 | session = tf.InteractiveSession(config=config)
32 | 
33 | warnings.filterwarnings("ignore", category=DeprecationWarning)
34 | 
35 | PRETRAINED_MODELS = {
36 |     "small": ("gs://t5-data/pretrained_models/small/", -1),
37 |     "base": ("gs://t5-data/pretrained_models/base/", -1),
38 |     "large": ("gs://t5-data/pretrained_models/large/", -1),
39 |     "3B": ("gs://t5-data/pretrained_models/3B/", -1),
40 |     "11B": ("gs://t5-data/pretrained_models/11B/", -1),
41 |     "unicorn-pt": ("gs://ai2-mosaic-public/projects/rainbow/v1.0/unicorns/lr-2e-3_batch-size-32/", -1),
42 |     "v9-delphi": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/unicorn-pt/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-16/", 1264700),
43 |     "v9-delphi-new": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/unicorn-pt/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-16/", 1239200),
44 |     "v9-delphi-large": ("gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v9/large/sbic_commonsense_morality_joint_all_proportional/lr-0.0001_bs-8/", 1643100),
45 | }
46 | 
47 | 
48 | @click.command()
49 | @click.argument("mixture", type=str)
50 | @click.argument("results_dir", type=str)
51 | # The name of the TPU. Defaults to the TPU_NAME environment variable.
52 | @click.argument("tpu-name", type=str)
53 | # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable.
54 | @click.argument("tpu-topology", type=str)
55 | @click.argument("pretrained-model", type=str)
56 | @click.option(
57 |     "--split",
58 |     type=str,
59 |     default="train",
60 |     help="The split on which to train. Defaults to 'train'.",
61 | )
62 | @click.option(
63 |     "--n-steps",
64 |     type=int,
65 |     default=600000,
66 |     help="The number of gradient updates. Defaults to 600,000.",
67 | )
68 | @click.option(
69 |     "--save-checkpoints-steps",
70 |     type=int,
71 |     default=5000,
72 |     help=(
73 |         "The number of steps to take before saving a checkpoint. Defaults to"
74 |         " 5000."
75 |     ),
76 | )
77 | @click.option(
78 |     "--n-checkpoints-to-keep",
79 |     type=int,
80 |     default=300,
81 |     help=(
82 |         "The number of checkpoints to keep during fine-tuning. Defaults"
83 |         " to 300."
84 |     ),
85 | )
86 | @click.option(
87 |     "--learning-rate",
88 |     type=float,
89 |     default=2e-4,
90 |     help="The learning rate to use for training. Defaults to 2e-4.",
91 | )
92 | @click.option(
93 |     "--continue_finetune",
94 |     type=bool,
95 |     default=False,
96 |     help="Whether to continue training from an existing checkpoint.",
97 | )
98 | def fine_tune(
99 |     mixture: str,
100 |     results_dir: str,
101 |     split: str,
102 |     pretrained_model: str,
103 |     n_steps: int,
104 |     learning_rate: float,
105 |     save_checkpoints_steps: int,
106 |     n_checkpoints_to_keep: int,
107 |     tpu_name: str,
108 |     tpu_topology: str,
109 |     continue_finetune: bool,
110 | ) -> None:
111 |     """
112 |     Fine-tune the model on MIXTURE, writing results to RESULTS_DIR.
113 | """ 114 | 115 | # Initialize arguments 116 | if tpu_topology == "v3-32": 117 | batch_size = 16 118 | model_parallelism = 32 119 | elif tpu_topology == "v3-8": 120 | batch_size = 8 121 | model_parallelism = 8 122 | else: 123 | print("ERROR: tpu_topology invalid") 124 | return 125 | 126 | pretrained_checkpoint_step = -1 127 | 128 | # Get result path given arguments 129 | result_path = util.get_result_path( 130 | results_dir, pretrained_model, mixture, learning_rate, batch_size) 131 | 132 | # Validate path 133 | util.validate_path(results_dir, pretrained_model, PRETRAINED_MODELS) 134 | 135 | # Process arguments 136 | if pretrained_model in PRETRAINED_MODELS: 137 | pretrained_model, pretrained_checkpoint_step = PRETRAINED_MODELS[pretrained_model] 138 | 139 | # If the training stops before finishing and we want to continue from the last checkpoint 140 | if continue_finetune: 141 | pretrained_model = result_path 142 | 143 | # Print arguments 144 | util.print_arguments(result_path, results_dir, mixture, split, pretrained_model, 145 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism, 146 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate, 147 | tpu_name, tpu_topology, tasks, continue_finetune) 148 | 149 | # Run fine-tuning 150 | model = t5.models.MtfModel( 151 | model_dir=result_path, 152 | tpu=tpu_name, 153 | tpu_topology=tpu_topology, 154 | model_parallelism=model_parallelism, 155 | batch_size=batch_size, 156 | sequence_length={"inputs": 512, "targets": 128}, 157 | learning_rate_schedule=learning_rate, 158 | save_checkpoints_steps=save_checkpoints_steps, 159 | keep_checkpoint_max=n_checkpoints_to_keep, 160 | iterations_per_loop=100, 161 | ) 162 | 163 | model.finetune( 164 | mixture_or_task_name=mixture, 165 | pretrained_model_dir=pretrained_model, 166 | pretrained_checkpoint_step=pretrained_checkpoint_step, 167 | finetune_steps=n_steps, 168 | split=split, 169 | ) 170 | 171 | 172 | if __name__ == "__main__": 173 | fine_tune() 174 | -------------------------------------------------------------------------------- /src/delphi/train/mixtures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data mixtures 3 | """ 4 | 5 | import os 6 | import t5 7 | import tasks 8 | import rates 9 | import seqio 10 | import functools 11 | 12 | import util 13 | 14 | # ################### register mixtures ################### 15 | # seqio.MixtureRegistry.add( 16 | # "commonsense_morality_joint_all_proportional", 17 | # ["moral_acceptability", 18 | # "moral_agreement", 19 | # "moral_comparison"], 20 | # default_rate=rates.MIXING_RATES["proportional"] 21 | # ) 22 | # util.print_mixture_examples("commonsense_morality_joint_all_proportional") 23 | # 24 | # seqio.MixtureRegistry.add( 25 | # "commonsense_morality_separate_all_proportional", 26 | # ["moral_acceptability_class", 27 | # "moral_acceptability_text", 28 | # "moral_agreement_class", 29 | # "moral_agreement_text", 30 | # "moral_comparison_class"], 31 | # default_rate=rates.MIXING_RATES["proportional"] 32 | # ) 33 | # util.print_mixture_examples("commonsense_morality_separate_all_proportional", num_ex=1) 34 | 35 | 36 | # seqio.MixtureRegistry.add( 37 | # "sbic_commonsense_morality_joint_comparison_double_all_proportional", 38 | # ["sbic_moral_acceptability", 39 | # "sbic_moral_agreement", 40 | # "sbic_moral_comparison_double"], 41 | # default_rate=rates.MIXING_RATES["proportional"] 42 | # ) 43 | # 
util.print_mixture_examples("sbic_commonsense_morality_joint_comparison_double_all_proportional") 44 | # 45 | # 46 | # seqio.MixtureRegistry.add( 47 | # "sbic_commonsense_morality_separate_all_proportional", 48 | # ["sbic_moral_acceptability_class", 49 | # "sbic_moral_acceptability_text", 50 | # "sbic_moral_agreement_class", 51 | # "sbic_moral_agreement_text", 52 | # "sbic_moral_comparison_class"], 53 | # default_rate=rates.MIXING_RATES["proportional"] 54 | # ) 55 | # util.print_mixture_examples("sbic_commonsense_morality_separate_all_proportional") 56 | # 57 | # 58 | # seqio.MixtureRegistry.add( 59 | # "commonsense_morality_separate_wo_agreement_class_all_proportional", 60 | # ["moral_acceptability_class", 61 | # "moral_acceptability_text", 62 | # "moral_agreement_text", 63 | # "moral_comparison_class"], 64 | # default_rate=rates.MIXING_RATES["proportional"] 65 | # ) 66 | # util.print_mixture_examples("commonsense_morality_separate_wo_agreement_class_all_proportional") 67 | # 68 | # 69 | # seqio.MixtureRegistry.add( 70 | # "sbic_commonsense_morality_separate_wo_agreement_class_all_proportional", 71 | # ["sbic_moral_acceptability_class", 72 | # "sbic_moral_acceptability_text", 73 | # "sbic_moral_agreement_text", 74 | # "sbic_moral_comparison_class"], 75 | # default_rate=rates.MIXING_RATES["proportional"] 76 | # ) 77 | # util.print_mixture_examples("sbic_commonsense_morality_separate_wo_agreement_class_all_proportional") 78 | # 79 | # 80 | # seqio.MixtureRegistry.add( 81 | # "commonsense_morality_separate_text_only_all_proportional", 82 | # ["moral_acceptability_text", 83 | # "moral_agreement_text"], 84 | # default_rate=rates.MIXING_RATES["proportional"] 85 | # ) 86 | # util.print_mixture_examples("commonsense_morality_separate_text_only_all_proportional") 87 | # 88 | # 89 | # seqio.MixtureRegistry.add( 90 | # "sbic_commonsense_morality_separate_text_only_all_proportional", 91 | # ["sbic_moral_acceptability_text", 92 | # "sbic_moral_agreement_text"], 93 | # default_rate=rates.MIXING_RATES["proportional"] 94 | # ) 95 | # util.print_mixture_examples("sbic_commonsense_morality_separate_text_only_all_proportional") 96 | # , 1, 10, 30, 60, 90 97 | 98 | 99 | seqio.MixtureRegistry.add( 100 | "sbic_commonsense_morality_joint_all_proportional", 101 | ["sbic_moral_acceptability", 102 | "sbic_moral_agreement", 103 | "sbic_moral_comparison"], 104 | default_rate=rates.MIXING_RATES["proportional"] 105 | ) 106 | util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional") 107 | 108 | # ################## ethics values raw (for fine-tuning) ################## 109 | # seqio.MixtureRegistry.add( 110 | # "ethics_values_raw_with_cm_long_all_proportional", 111 | # ["ethics_cm_long_raw", 112 | # "ethics_deontology_raw", 113 | # "ethics_justice_raw", 114 | # "ethics_util_raw", 115 | # "ethics_virtue_raw"], 116 | # default_rate=rates.MIXING_RATES["proportional"] 117 | # ) 118 | # 119 | # seqio.MixtureRegistry.add( 120 | # "ethics_values_raw_with_cm_overall_all_proportional", 121 | # ["ethics_cm_raw", 122 | # "ethics_deontology_raw", 123 | # "ethics_justice_raw", 124 | # "ethics_util_raw", 125 | # "ethics_virtue_raw"], 126 | # default_rate=rates.MIXING_RATES["proportional"] 127 | # ) 128 | # 129 | # seqio.MixtureRegistry.add( 130 | # "ethics_values_raw_without_cm_all_proportional", 131 | # ["ethics_deontology_raw", 132 | # "ethics_justice_raw", 133 | # "ethics_util_raw", 134 | # "ethics_virtue_raw"], 135 | # default_rate=rates.MIXING_RATES["proportional"] 136 | # ) 137 | # 138 | # 139 | # 
################## ethics values (for pre-train on ethics) ################## 140 | # seqio.MixtureRegistry.add( 141 | # "ethics_values_with_cm_long_all_proportional", 142 | # ["ethics_cm_long", 143 | # "ethics_deontology", 144 | # "ethics_justice", 145 | # "ethics_util", 146 | # "ethics_virtue"], 147 | # default_rate=rates.MIXING_RATES["proportional"] 148 | # ) 149 | # 150 | # seqio.MixtureRegistry.add( 151 | # "ethics_values_with_cm_overall_all_proportional", 152 | # ["ethics_cm", 153 | # "ethics_deontology", 154 | # "ethics_justice", 155 | # "ethics_util", 156 | # "ethics_virtue"], 157 | # default_rate=rates.MIXING_RATES["proportional"] 158 | # ) 159 | # 160 | # seqio.MixtureRegistry.add( 161 | # "ethics_values_without_cm_all_proportional", 162 | # ["ethics_deontology", 163 | # "ethics_justice", 164 | # "ethics_util", 165 | # "ethics_virtue"], 166 | # default_rate=rates.MIXING_RATES["proportional"] 167 | # ) 168 | 169 | 170 | ################## commonsense norm bank + ethics values ################## 171 | 172 | # seqio.MixtureRegistry.add( 173 | # "sbic_joint_norm_bank_ethics_all_proportional", 174 | # ["sbic_moral_acceptability", 175 | # "sbic_moral_agreement", 176 | # "sbic_moral_comparison", 177 | # "ethics_cm", 178 | # "ethics_deontology", 179 | # "ethics_justice", 180 | # "ethics_util", 181 | # "ethics_virtue"], 182 | # default_rate=rates.MIXING_RATES["proportional"] 183 | # ) 184 | # # util.print_mixture_examples("sbic_joint_norm_bank_ethics_all_proportional") 185 | 186 | # seqio.MixtureRegistry.add( 187 | # "sbic_commonsense_morality_joint_all_proportional_demo_v4", 188 | # ["sbic_moral_acceptability", 189 | # "sbic_moral_agreement", 190 | # "sbic_moral_comparison", 191 | # "demo_v4"], 192 | # default_rate=rates.MIXING_RATES["proportional"] 193 | # ) 194 | # util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional_demo_v4") 195 | # 196 | 197 | 198 | ################# commonsense norm bank + ablations ################## 199 | proportions = [0.01] # , 1, 10, 30, 60, 90, "base" 200 | for proportion in proportions: 201 | seqio.MixtureRegistry.add( 202 | f"sbic_commonsense_morality_joint_all_proportional_new_{proportion}", 203 | [f"sbic_moral_acceptability_{proportion}", 204 | f"sbic_moral_agreement_{proportion}", 205 | f"sbic_moral_comparison_{proportion}"], 206 | default_rate=rates.MIXING_RATES["proportional"] 207 | ) 208 | util.print_mixture_examples( 209 | f"sbic_commonsense_morality_joint_all_proportional_new_{proportion}") 210 | 211 | 212 | ################## commonsense norm bank + wild ablations ################## 213 | # proportions = [10, 20, 40, 60, 80, 100] 214 | # for p in proportions: 215 | # seqio.MixtureRegistry.add( 216 | # f"sbic_commonsense_morality_joint_all_proportional_wild_{p}", 217 | # [ f"wild_train_{p}", 218 | # "sbic_moral_acceptability", 219 | # "sbic_moral_agreement", 220 | # "sbic_moral_comparison"], 221 | # default_rate=rates.MIXING_RATES["proportional"] 222 | # ) 223 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_{p}") 224 | 225 | # seqio.MixtureRegistry.add( 226 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100", 227 | # [ f"wild_train_woz_100", 228 | # "sbic_moral_acceptability", 229 | # "sbic_moral_agreement", 230 | # "sbic_moral_comparison"], 231 | # default_rate=rates.MIXING_RATES["proportional"] 232 | # ) 233 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100") 234 | 235 | 236 | # seqio.MixtureRegistry.add( 237 | # 
"wild_hard_test", 238 | # [ "race_test", 239 | # "gender_test"], 240 | # default_rate=rates.MIXING_RATES["proportional"] 241 | # ) 242 | # util.print_mixture_examples(f"wild_hard_test") 243 | 244 | 245 | seqio.MixtureRegistry.add( 246 | "all", 247 | [f"wild_train_100", 248 | "sbic_moral_acceptability", 249 | "sbic_moral_agreement", 250 | "sbic_moral_comparison", 251 | "race_test", 252 | "gender_test"], 253 | default_rate=rates.MIXING_RATES["proportional"] 254 | ) 255 | util.print_mixture_examples(f"all") 256 | 257 | 258 | seqio.MixtureRegistry.add( 259 | "hard_all", 260 | [f"wild_train_100", 261 | "race_test", 262 | "gender_test"], 263 | default_rate=rates.MIXING_RATES["proportional"] 264 | ) 265 | util.print_mixture_examples(f"hard_all") 266 | 267 | 268 | seqio.MixtureRegistry.add( 269 | "race_gender", 270 | ["race_test", 271 | "gender_test"], 272 | default_rate=rates.MIXING_RATES["proportional"] 273 | ) 274 | 275 | dynahate = False 276 | if dynahate: 277 | # seqio.MixtureRegistry.add( 278 | # "dynahate_all", 279 | # [f"dynahate_round_1", 280 | # f"dynahate_round_2", 281 | # f"dynahate_round_3", 282 | # f"dynahate_round_4", ], 283 | # default_rate=rates.MIXING_RATES["proportional"] 284 | # ) 285 | # 286 | # seqio.MixtureRegistry.add( 287 | # "dynahate_all_100_shot", 288 | # [f"dynahate_round_1_100_shot", 289 | # f"dynahate_round_2_100_shot", 290 | # f"dynahate_round_3_100_shot", 291 | # f"dynahate_round_4_100_shot", ], 292 | # default_rate=rates.MIXING_RATES["proportional"] 293 | # ) 294 | 295 | # seqio.MixtureRegistry.add( 296 | # "dynahate_all_nat", 297 | # [f"dynahate_round_1_nat", 298 | # f"dynahate_round_2_nat", 299 | # f"dynahate_round_3_nat", 300 | # f"dynahate_round_4_nat", ], 301 | # default_rate=rates.MIXING_RATES["proportional"] 302 | # ) 303 | # 304 | # seqio.MixtureRegistry.add( 305 | # "dynahate_all_nat_100_shot", 306 | # [f"dynahate_round_1_nat_100_shot", 307 | # f"dynahate_round_2_nat_100_shot", 308 | # f"dynahate_round_3_nat_100_shot", 309 | # f"dynahate_round_4_nat_100_shot", ], 310 | # default_rate=rates.MIXING_RATES["proportional"] 311 | # ) 312 | 313 | seqio.MixtureRegistry.add( 314 | "dynahate_all_st", 315 | [f"dynahate_round_1_st", 316 | f"dynahate_round_2_st", 317 | f"dynahate_round_3_st", 318 | f"dynahate_round_4_st",], 319 | default_rate=rates.MIXING_RATES["proportional"] 320 | ) 321 | 322 | seqio.MixtureRegistry.add( 323 | "dynahate_all_st_100_shot", 324 | [f"dynahate_round_1_st_100_shot", 325 | f"dynahate_round_2_st_100_shot", 326 | f"dynahate_round_3_st_100_shot", 327 | f"dynahate_round_4_st_100_shot",], 328 | default_rate=rates.MIXING_RATES["proportional"] 329 | ) 330 | 331 | # seqio.MixtureRegistry.add( 332 | # "dynahate_all_st_clean", 333 | # [ f"dynahate_round_1_st_clean", 334 | # f"dynahate_round_2_st_clean", 335 | # f"dynahate_round_3_st_clean", 336 | # f"dynahate_round_4_st_clean",], 337 | # default_rate=rates.MIXING_RATES["proportional"] 338 | # ) 339 | # 340 | # seqio.MixtureRegistry.add( 341 | # "dynahate_all_st_clean_100_shot", 342 | # [ f"dynahate_round_1_st_clean_100_shot", 343 | # f"dynahate_round_2_st_clean_100_shot", 344 | # f"dynahate_round_3_st_clean_100_shot", 345 | # f"dynahate_round_4_st_clean_100_shot",], 346 | # default_rate=rates.MIXING_RATES["proportional"] 347 | # ) 348 | # 349 | # seqio.MixtureRegistry.add( 350 | # "dynahate_all_bc", 351 | # [ f"dynahate_round_1_bc", 352 | # f"dynahate_round_2_bc", 353 | # f"dynahate_round_3_bc", 354 | # f"dynahate_round_4_bc",], 355 | # default_rate=rates.MIXING_RATES["proportional"] 356 | # ) 357 | 
# 358 | # seqio.MixtureRegistry.add( 359 | # "dynahate_all_bc_100_shot", 360 | # [ f"dynahate_round_1_bc_100_shot", 361 | # f"dynahate_round_2_bc_100_shot", 362 | # f"dynahate_round_3_bc_100_shot", 363 | # f"dynahate_round_4_bc_100_shot",], 364 | # default_rate=rates.MIXING_RATES["proportional"] 365 | # ) 366 | 367 | # seqio.MixtureRegistry.add( 368 | # "declare_only", 369 | # [f"freeform", 370 | # f"yesno"], 371 | # default_rate=rates.MIXING_RATES["proportional"] 372 | # ) 373 | -------------------------------------------------------------------------------- /src/delphi/train/predict.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """Evaluate the model on the rainbow datasets.""" 4 | 5 | import t5 6 | import logging 7 | import click 8 | import tensorflow.compat.v1 as tf 9 | 10 | # Improve logging. 11 | from contextlib import contextmanager 12 | 13 | 14 | tf.disable_v2_behavior() 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def getSubstringBetweenMarkers(source_string, start_marker, end_marker): 20 | start = source_string.find(start_marker) + len(start_marker) 21 | end = source_string.find(end_marker) 22 | return source_string[start: end] 23 | 24 | 25 | @contextmanager 26 | def tf_verbosity_level(level): 27 | og_level = tf.logging.get_verbosity() 28 | tf.logging.set_verbosity(level) 29 | yield 30 | tf.logging.set_verbosity(og_level) 31 | 32 | 33 | @click.command() 34 | @click.option( 35 | "--batch-size", 36 | type=int, 37 | default=64, 38 | help=( 39 | "The batch size to use for prediction. For efficient prediction on the" 40 | " TPU, choose a multiple of either 8 or 128. Defaults to 64." 41 | ), 42 | ) 43 | @click.option( 44 | "--model-parallelism", 45 | type=int, 46 | default=8, 47 | help="The degree of model parallelism to use. Defaults to 8.", 48 | ) 49 | @click.option( 50 | "--tpu-name", 51 | type=str, 52 | default="de-tpu-5", 53 | required=True, 54 | help="The name of the TPU. Defaults to the TPU_NAME environment variable.", 55 | ) 56 | @click.option( 57 | "--tpu-topology", 58 | type=str, 59 | default="v3-32", 60 | required=True, 61 | help=( 62 | "The topology of the TPU. Defaults to the TPU_TOPOLOGY environment" 63 | " variable." 64 | ), 65 | ) 66 | def predict( 67 | batch_size: int, 68 | model_parallelism: int, 69 | tpu_name: str, 70 | tpu_topology: str, 71 | ) -> None: 72 | """Evaluate the model located at RESULTS_DIR on MIXTURE.""" 73 | eval_data = "UNDHR.idty.0" 74 | 75 | data_version = "v11" 76 | model_type = "distribution" 77 | check_point = 1249400 78 | 79 | lr = 0.0001 80 | bs = 16 81 | bucket_name = "ai2-tpu-europe-west4" 82 | models_dir = f"gs://{bucket_name}/projects/liweij/mosaic-commonsense-morality/model/{data_version}/" \ 83 | f"unicorn-pt/{model_type}/lr-{lr}_bs-{bs}" 84 | training_type = "joint" 85 | 86 | # Run evaluation. 
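    # NOTE: eval_data, data_version, model_type, check_point, and the GCS bucket above are
    # hard-coded for the authors' setup; point them at your own data and fine-tuned
    # checkpoint before running prediction.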
87 | model = t5.models.MtfModel( 88 | model_dir=models_dir, 89 | tpu=tpu_name, 90 | tpu_topology=tpu_topology, 91 | model_parallelism=model_parallelism, 92 | batch_size=batch_size, 93 | sequence_length={"inputs": 512, "targets": 128}, 94 | learning_rate_schedule=None, 95 | save_checkpoints_steps=5000, 96 | keep_checkpoint_max=None, 97 | iterations_per_loop=100, 98 | ) 99 | 100 | predict_joint_inputs_paths = ["gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/" 101 | f"data/qualitative_eval/{training_type}/" + eval_data + "_qualitative_eval.tsv"] 102 | predict_joint_outputs_paths = [ 103 | models_dir.replace("model", "preds") + "/raw/" + eval_data + "_qualitative_eval.tsv"] 104 | 105 | for i in range(len(predict_joint_inputs_paths)): 106 | predict_joint_inputs_path = predict_joint_inputs_paths[i] 107 | predict_joint_outputs_path = predict_joint_outputs_paths[i] 108 | 109 | # Ignore any logging so that we only see the model's answers to the questions. 110 | with tf_verbosity_level('ERROR'): 111 | # Min size for small model on v2-8 with parallelism 1. 112 | model.batch_size = 8 113 | model.predict( 114 | input_file=predict_joint_inputs_path, 115 | output_file=predict_joint_outputs_path, 116 | # Select the most probable output token at each step. 117 | temperature=0, 118 | checkpoint_steps=check_point, 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | predict() 124 | -------------------------------------------------------------------------------- /src/delphi/train/rates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixing rates 3 | """ 4 | 5 | import seqio 6 | 7 | def equal_rate(task: seqio.Task): 8 | """Mix the datasets in equal amounts. 9 | 10 | Parameters 11 | ---------- 12 | task : t5.data.Task 13 | The task. 14 | 15 | Returns 16 | ------- 17 | float 18 | The constant: ``1.0``. 19 | """ 20 | return 1.0 21 | 22 | 23 | def proportional_rate(task: seqio.Task): 24 | """Mix the datasets proportionally. 25 | 26 | Parameters 27 | ---------- 28 | task : t5.data.Task 29 | The task. 30 | 31 | Returns 32 | ------- 33 | float 34 | The number of examples in the task's training set. 
35 | """ 36 | return float(task.num_input_examples("train")) 37 | 38 | 39 | # constants 40 | 41 | MIXING_RATES = { 42 | "equal": equal_rate, 43 | "proportional": proportional_rate, 44 | } 45 | """A dictionary mapping mixing rates' names to their implementations.""" 46 | -------------------------------------------------------------------------------- /src/delphi/train/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for fine-tuning and evaluating models 3 | """ 4 | import seqio 5 | import pandas as pd 6 | from google.cloud import storage 7 | import tensorflow_datasets as tfds 8 | import tensorflow as tf 9 | 10 | 11 | def create_folder(client, bucket, destination_folder_name): 12 | """ 13 | Create a folder in Google Cloud Storage if such folder doesn't exist already 14 | """ 15 | if not storage.Blob(bucket=bucket, name=destination_folder_name).exists(client): 16 | blob = bucket.blob(destination_folder_name) 17 | blob.upload_from_string('') 18 | print('Created: {}'.format(destination_folder_name)) 19 | else: 20 | print('Exists: {}'.format(destination_folder_name)) 21 | 22 | 23 | def print_task_examples(task_name, split="validation", num_ex=1): 24 | """ 25 | Print examples from tasks 26 | """ 27 | print("#" * 20, task_name, "#" * 20) 28 | task = seqio.TaskRegistry.get(task_name) 29 | ds = task.get_dataset(split=split, sequence_length={ 30 | "inputs": 512, "targets": 128}) 31 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))): 32 | print(i, ex) 33 | print("test", task.num_input_examples("test")) 34 | print("train", task.num_input_examples("train")) 35 | print("validation", task.num_input_examples("validation")) 36 | 37 | 38 | def print_mixture_examples(mixture_name, split="validation", num_ex=1): 39 | """ 40 | Print examples from mixtures 41 | """ 42 | print("#" * 20, mixture_name, "#" * 20) 43 | mixture = seqio.MixtureRegistry.get(mixture_name) 44 | ds = mixture.get_dataset(split=split, 45 | sequence_length={"inputs": 512, "targets": 128}) 46 | 47 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))): 48 | print(i, ex) 49 | print("test", mixture.num_input_examples("test")) 50 | print("train", mixture.num_input_examples("train")) 51 | print("validation", mixture.num_input_examples("validation")) 52 | 53 | 54 | def get_num_elements_csv(file_name): 55 | """ 56 | Get the total number of elements in a given csv/tsv file 57 | """ 58 | df = pd.read_csv(file_name, delimiter="\t") 59 | return df.shape[0] 60 | 61 | 62 | def get_num_elements_split(split_paths): 63 | """ 64 | Get the number of elements in each split of a dataset 65 | """ 66 | num_elements_split = {} 67 | for split, path in split_paths.items(): 68 | num_elements_split[split] = get_num_elements_csv(path) 69 | return num_elements_split 70 | 71 | 72 | def get_result_check_points(result_prefix, split, eval_data_type, after_check_point=-1): 73 | """ 74 | Get a list of model checkpoints that haven't generated on the designated data split yet 75 | """ 76 | client = storage.Client() 77 | bucket_name = "ai2-tpu-europe-west4" 78 | result_prefix = result_prefix.split(bucket_name + "/")[-1] + "/" 79 | 80 | check_points = [] 81 | done_check_points = [] 82 | for blob in client.list_blobs(bucket_name, prefix=result_prefix): 83 | blob_name = str(blob).split("/")[-1] 84 | if ".meta" in blob_name: 85 | check_point = int(blob_name.split(".meta")[0].split("-")[-1]) 86 | if check_point > after_check_point: 87 | check_points.append(check_point) 88 | 89 | print("-" * 10, "checkpoints 
all", "-" * 10) 90 | print(check_points) 91 | 92 | for blob in client.list_blobs(bucket_name, prefix=result_prefix + f"{split}_eval/"): 93 | blob_name = str(blob).split("/")[-1] 94 | if "_predictions" in blob_name and eval_data_type in blob_name and "_predictions_clean" not in blob_name: 95 | check_point_done = int(blob_name.split( 96 | "_predictions")[0].split("_")[-1]) 97 | # check_point_done = int(blob_name.split("_")[0].split("_")[-1]) 98 | if check_point_done in check_points: 99 | done_check_points.append(check_point_done) 100 | check_points.remove(check_point_done) 101 | 102 | print("-" * 10, "checkpoints done", "-" * 10) 103 | print(done_check_points) 104 | return check_points 105 | 106 | 107 | def validate_path(results_dir, pretrained_model=None, PRETRAINED_MODELS=None): 108 | """ 109 | Validate result path 110 | """ 111 | if PRETRAINED_MODELS != None: 112 | if not results_dir.startswith("gs://"): 113 | raise ValueError( 114 | f"RESULTS_DIR ({results_dir}) must be a GCS path.") 115 | 116 | if pretrained_model.startswith("gs://"): 117 | if not tf.io.gfile.exists(pretrained_model): 118 | raise IOError( 119 | f"--pretrained-model ({pretrained_model}) does not exist." 120 | ) 121 | else: 122 | if pretrained_model not in PRETRAINED_MODELS: 123 | raise ValueError( 124 | f"--pretrained-model ({pretrained_model}) not recognized. It" 125 | f" must either be a GCS path or one of" 126 | f' {", ".join(PRETRAINED_MODELS.keys())}.') 127 | else: 128 | if not results_dir.startswith("gs://"): 129 | raise ValueError( 130 | f"RESULTS_DIR ({results_dir}) must be a GCS path.") 131 | elif not tf.io.gfile.exists(results_dir): 132 | raise IOError(f"RESULTS_DIR ({results_dir}) doesn't exist.") 133 | 134 | 135 | def print_arguments(result_path, results_dir, mixture, split, pretrained_model, 136 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism, 137 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate, 138 | tpu_name, tpu_topology, tasks, continue_finetune): 139 | print("=" * 10, "results_dir") 140 | print(results_dir) 141 | 142 | print("=" * 10, "mixture") 143 | print(mixture) 144 | 145 | print("=" * 10, "split") 146 | print(split) 147 | 148 | print("=" * 10, "pretrained_model") 149 | print(pretrained_model) 150 | 151 | print("=" * 10, "pretrained_checkpoint_step") 152 | print(pretrained_checkpoint_step) 153 | 154 | print("=" * 10, "n_steps") 155 | print(n_steps) 156 | 157 | print("=" * 10, "batch_size") 158 | print(batch_size) 159 | 160 | print("=" * 10, "model_parallelism") 161 | print(model_parallelism) 162 | 163 | print("=" * 10, "save_checkpoints_steps") 164 | print(save_checkpoints_steps) 165 | 166 | print("=" * 10, "n_checkpoints_to_keep") 167 | print(n_checkpoints_to_keep) 168 | 169 | print("=" * 10, "learning_rate") 170 | print(learning_rate) 171 | 172 | print("=" * 10, "tpu_name") 173 | print(tpu_name) 174 | 175 | print("=" * 10, "tpu_topology") 176 | print(tpu_topology) 177 | 178 | print("=" * 10, "result_path") 179 | print(result_path) 180 | 181 | print("=" * 10, "data_version") 182 | print(tasks.data_version) 183 | 184 | print("=" * 10, "continue_finetune") 185 | print(continue_finetune) 186 | 187 | 188 | def get_result_path( 189 | results_dir: str, 190 | pretrained_model: str, 191 | mixture: str, 192 | learning_rate: float, 193 | batch_size: int 194 | ) -> str: 195 | """ 196 | Get a result path given arguments 197 | """ 198 | result_path = results_dir + \ 199 | "/" + pretrained_model + \ 200 | "/" + mixture + \ 201 | f"/lr-{learning_rate}_bs-{batch_size}" 202 | 
return result_path 203 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/COMETGenerator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from treelib import Node, Tree 5 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast 6 | 7 | sys.path.append(os.getcwd()) 8 | from src.delphi_hybrid.components.bank import * 9 | 10 | 11 | class COMETGenerator(): 12 | def __init__(self, model_name="gpt2-xl-atomic2020", device_id=0, server="beaker"): # beaker_batch, local 13 | if model_name == "gpt2-xl-atomic2020": 14 | if server == "beaker_batch": 15 | base_path_atomic_2020 = "/model/atomic2020" 16 | else: 17 | base_path_atomic_2020 = "/net/nfs.cirrascale/mosaic/liweij/model/atomic2020" 18 | 19 | CUDA_DEVICE = f"cuda:{device_id}" if torch.cuda.is_available() else 'cpu' 20 | self.device = torch.device(CUDA_DEVICE) 21 | print(f"COMETGenerator device: {self.device}") 22 | self.model_name = model_name 23 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-xl") 24 | self.model = GPT2LMHeadModel.from_pretrained(base_path_atomic_2020, 25 | pad_token_id=self.tokenizer.eos_token_id).to(self.device) 26 | self.add_special_tokens() 27 | 28 | self.BOS_TOKEN = self.tokenizer.bos_token 29 | self.EOS_TOKEN = self.tokenizer.eos_token 30 | self.GEN_TOKEN = "[GEN]" 31 | 32 | def add_special_tokens(self): 33 | self.tokenizer.add_special_tokens({ 34 | 'eos_token': '[EOS]', 35 | 'additional_special_tokens': [ 36 | 'LocationOfAction', 37 | 'HinderedBy', 38 | 'HasFirstSubevent', 39 | 'NotHasProperty', 40 | 'NotHasA', 41 | 'HasA', 42 | 'AtLocation', 43 | 'NotCapableOf', 44 | 'CausesDesire', 45 | 'HasPainCharacter', 46 | 'NotDesires', 47 | 'MadeUpOf', 48 | 'InstanceOf', 49 | 'SymbolOf', 50 | 'xReason', 51 | 'isAfter', 52 | 'HasPrerequisite', 53 | 'UsedFor', 54 | 'MadeOf', 55 | 'MotivatedByGoal', 56 | 'Causes', 57 | 'oEffect', 58 | 'CreatedBy', 59 | 'ReceivesAction', 60 | 'NotMadeOf', 61 | 'xWant', 62 | 'PartOf', 63 | 'DesireOf', 64 | 'HasPainIntensity', 65 | 'xAttr', 66 | 'DefinedAs', 67 | 'oReact', 68 | 'xIntent', 69 | 'HasSubevent', 70 | 'oWant', 71 | 'HasProperty', 72 | 'IsA', 73 | 'HasSubEvent', 74 | 'LocatedNear', 75 | 'Desires', 76 | 'isFilledBy', 77 | 'isBefore', 78 | 'InheritsFrom', 79 | 'xNeed', 80 | 'xEffect', 81 | 'xReact', 82 | 'HasLastSubevent', 83 | 'RelatedTo', 84 | 'CapableOf', 85 | 'NotIsA', 86 | 'ObjectUse', 87 | '[GEN]' 88 | ] 89 | }) 90 | self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 91 | self.model.resize_token_embeddings(len(self.tokenizer)) 92 | 93 | def __name__(self): 94 | return self.model_name 95 | 96 | def generate(self, head_event, relation): 97 | input_string = head_event + " " + relation + " [GEN]" 98 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids 99 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True) 100 | 101 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0]) 102 | tail_event = decoded_sequence.split(self.GEN_TOKEN)[-1].split(self.EOS_TOKEN)[0] 103 | return tail_event 104 | 105 | def generate_beam(self, head_event, relation, num_beams=5, num_return_sequences=5, max_length=100): 106 | input_string = head_event + " " + relation + " [GEN]" 107 | tokenized_input = self.tokenizer(input_string, 108 | max_length=max_length, 109 | truncation=True, 110 | return_tensors='pt').to(self.device) 111 | 
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 112 | 113 | outputs = self.model.generate(input_ids=tokenized_input.input_ids, 114 | attention_mask=tokenized_input.attention_mask, 115 | # output_scores=True, 116 | # return_dict_in_generate=True, 117 | num_beams=num_beams, 118 | max_length=max_length, 119 | num_return_sequences=num_return_sequences,) 120 | 121 | decoded_sequences = self.tokenizer.batch_decode(outputs) 122 | 123 | tail_events = [ds.split(self.GEN_TOKEN)[-1].split(self.EOS_TOKEN)[0] for ds in decoded_sequences] 124 | 125 | return tail_events 126 | 127 | def generate_all_relations(self, event): 128 | comet_inferences = {} 129 | for relation in comet_relations: 130 | tail_events = self.generate_beam(event, relation) 131 | comet_inferences[relation] = tail_events 132 | return comet_inferences 133 | 134 | if __name__ == "__main__": 135 | comet_generator = COMETGenerator(device_id=0) 136 | 137 | head_events = ["a bear", 138 | "being a stupid bear", 139 | "performing genocide", 140 | "a protected bear"] 141 | 142 | # for head_event in head_events: 143 | for relation in comet_relations: 144 | tail_events = comet_generator.generate_beam(head_events[0], relation) 145 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/CacheHandler/COMETCacheHandler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from tqdm import tqdm 5 | 6 | sys.path.append(os.getcwd()) 7 | from src.delphi_hybrid.components.utils import * 8 | from src.delphi_hybrid.components.COMETGenerator import * 9 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import * 10 | 11 | class COMETCacheHandler(CacheHandler): 12 | def __init__(self, filename="comet_subset", cache_dir="cache", device_id=0): 13 | if filename != None and "comet" not in filename: 14 | print("ERROR: wrong cache file!") 15 | super().__init__("comet", cache_dir, filename) 16 | self.comet_generator = COMETGenerator(device_id=device_id) 17 | 18 | def _generate_instance(self, event): 19 | return self.comet_generator.generate_all_relations(event) 20 | 21 | 22 | if __name__ == "__main__": 23 | events = list(read_json(data_base_path + f"cache_norm_bank/all_sequences.json").keys()) 24 | print(events) 25 | 26 | cache_handler = COMETCacheHandler(filename="norm_bank_comet", cache_dir="cache_norm_bank") 27 | for event in tqdm(events): 28 | comet_instance = cache_handler.save_instance(event) 29 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/CacheHandler/CacheHandler.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: CacheHandler.py 3 | Note: The parent class of all CacheHandlers 4 | """ 5 | 6 | import os 7 | import sys 8 | import abc 9 | import json 10 | from tqdm import tqdm 11 | 12 | sys.path.append(os.getcwd()) 13 | from src.delphi_hybrid.components.utils import * 14 | 15 | class CacheHandler(): 16 | __metaclass__ = abc.ABCMeta 17 | 18 | def __init__(self, cache_type, cache_dir="cache", filename=None): 19 | self.cache_type = cache_type 20 | self.filename = filename 21 | self.cache_dir = cache_dir 22 | self._load_cache() 23 | print(f"{cache_type} cache size: {len(self.cache)}") 24 | 25 | def __name__(self): 26 | return self.cache_type + "CacheHandler" 27 | 28 | def _get_cache_path(self): 29 | if self.filename == None: 30 | path = data_base_path + f"{self.cache_dir}/{self.cache_type}.json" 
31 | else: 32 | path = data_base_path + f"{self.cache_dir}/{self.filename}.json" 33 | return path 34 | 35 | def _load_cache(self): 36 | cache_file_path = self._get_cache_path() 37 | cache_dir_path = "/".join(cache_file_path.split("/")[:-1]) 38 | if not os.path.exists(cache_file_path): 39 | ensure_dir(cache_dir_path) 40 | with open(cache_file_path, 'w') as f: 41 | json.dump({}, f) 42 | print(f"Loading {self.cache_type} cache from ...") 43 | print(cache_file_path) 44 | self.cache = read_json(cache_file_path) 45 | print(f"...done loading {self.cache_type} cache!") 46 | 47 | def _save_cache(self): 48 | cache_file_path = self._get_cache_path() 49 | with open(cache_file_path, 'w') as f: 50 | json.dump(self.cache, f) 51 | print("Saved cache to", cache_file_path) 52 | 53 | def _add_instance(self, cache_key, instance, is_save): 54 | self.cache[cache_key] = instance 55 | if is_save: 56 | self._save_cache() 57 | 58 | @abc.abstractmethod 59 | def _generate_instance(self, event): 60 | pass 61 | 62 | def fetch_instance(self, event): 63 | if event not in self.cache or (event in self.cache and self.cache[event] == None): 64 | return None 65 | else: 66 | return self.cache[event] 67 | 68 | def get_instance(self, event, is_save=False): 69 | if event not in self.cache or (event in self.cache and self.cache[event] == None): 70 | instance = self._generate_instance(event) 71 | self._add_instance(event, instance, is_save) 72 | else: 73 | # print("[Note] Instance exists in cache!") 74 | instance = self.cache[event] 75 | return instance 76 | 77 | def update_instance(self, event, is_save=True): 78 | """ 79 | Regenerate instance (no matter if it's already in the cache or not) and update the cache 80 | """ 81 | instance = self._generate_instance(event) 82 | self._add_instance(event, instance, is_save=is_save) 83 | return instance 84 | 85 | def save_instance(self, event): 86 | return self.get_instance(event, is_save=True) 87 | 88 | # def save_all_instance(self, events): 89 | # for event in tqdm(events): 90 | # self.save_instance(event) 91 | 92 | def save_all_instance(self, events, save_interval=1): 93 | for i, event in enumerate(tqdm(events)): 94 | if type(event) != type(""): 95 | continue 96 | self.get_instance(event) 97 | if i % save_interval == 0: 98 | self._save_cache() 99 | self._save_cache() 100 | 101 | def regenerate_all_instance(self, events, save_interval=1): 102 | for i, event in enumerate(tqdm(events)): 103 | if type(event) != type(""): 104 | continue 105 | self.update_instance(event, is_save=False) 106 | if i % save_interval == 0: 107 | self._save_cache() 108 | self._save_cache() 109 | 110 | if __name__ == "__main__": 111 | delphi_cache_handler = CacheHandler("delphi_scores") 112 | delphi_cache = delphi_cache_handler.cache 113 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/CacheHandler/DelphiCacheHandler.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.getcwd()) 6 | from src.delphi_hybrid.components.DelphiScorer import * 7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import * 8 | 9 | class DelphiCacheHandler(CacheHandler): 10 | def __init__(self, filename=None, cache_dir="cache", model="t5-11b-1239200", device_id=0, server="local"): #"beaker_batch" 11 | if filename != None and "delphi" not in filename: 12 | print("ERROR: wrong cache file!") 13 | super().__init__("delphi_scores", filename=filename, cache_dir=cache_dir) 14 | 
self.delphi_generator = DelphiScorer(model=model, device_id=device_id, server=server)
15 | 
16 |     def _generate_instance(self, event):
17 |         class_label, probs, text_label = self.delphi_generator.generate_with_score(event)
18 |         return {"class_label": class_label,
19 |                 "prob_1": probs[0],
20 |                 "prob_0": probs[1],
21 |                 "prob_minus_1": probs[2],
22 |                 "text_label": text_label}
23 | 
24 | if __name__ == "__main__":
25 |     events = []
26 |     for split in ["test", "validation"]:
27 |         input_file = data_base_path + f"cache_norm_bank/events/clean_{split}.moral_acceptability.tsv"
28 |         df_data = pd.read_csv(input_file, sep="\t")
29 |         events += df_data["clean_event"].tolist()
30 | 
31 |     cache_handler = DelphiCacheHandler(filename="norm_bank_delphi", cache_dir="cache_norm_bank")
32 |     for event in tqdm(events):
33 |         comet_instance = cache_handler.save_instance(event)
34 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/CacheHandler/ParaphraseCacheHandler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | 
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.Paraphraser import *
7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
8 | 
9 | class ParaphraseCacheHandler(CacheHandler):
10 |     def __init__(self, filename=None, cache_dir="cache", num_paraphrases=8):
11 |         super().__init__("paraphrases", cache_dir, filename)
12 |         self.paraphraser = Paraphraser()
13 |         self.num_paraphrases = num_paraphrases
14 | 
15 |     def _clean_paraphrases(self, event, paraphrases):
16 |         paraphrases = list(set(paraphrases))
17 |         if event in paraphrases:
18 |             paraphrases.remove(event)
19 | 
20 |         for paraphrase in list(paraphrases):  # iterate over a copy so removals don't skip items
21 |             is_qualified = self.paraphraser.qualify_paraphrase(event, paraphrase)
22 |             if not is_qualified:
23 |                 paraphrases.remove(paraphrase)
24 |         return paraphrases
25 | 
26 |     def save_instance(self, event):
27 |         instance = []
28 |         if event in self.cache:
29 |             instance = self.cache[event]
30 | 
31 |         if len(instance) < self.num_paraphrases:
32 |             instance += self.paraphraser.generate_paraphrases(event, num_paraphrases=self.num_paraphrases)["paraphrases"]
33 |             instance = list(set(instance))
34 |             self._add_instance(event, instance, is_save=True)
35 |         else:
36 |             print("[Note] Enough paraphrases in cache!")
37 |             print("Num paraphrases:", len(instance))
38 |         return instance
39 | 
40 |     def update_instance(self, event, is_save=True):
41 |         return None
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     parser = argparse.ArgumentParser(description="Generate answers with GPT-3.")
46 |     parser.add_argument("--section_id", type=int, default="section_id")
47 |     args = parser.parse_args()
48 | 
49 |     input_file = data_base_path + "cache_norm_bank/events/clean_test.moral_acceptability.tsv"
50 |     df_data = pd.read_csv(input_file, sep="\t")
51 |     events = df_data["clean_event"].tolist()
52 | 
53 |     section_id = 0
54 | 
55 |     cache_handler = ParaphraseCacheHandler(filename=f"paraphrases/norm_bank_paraphrases_{section_id}",
56 |                                            cache_dir="cache_norm_bank",
57 |                                            num_paraphrases=10)
58 |     for e in tqdm(events):
59 |         print(len(e))
60 |         cache_handler.save_instance(e)
61 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/CompositionalityParser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 | 
5 | sys.path.append(os.getcwd())
6 | from
src.delphi_hybrid.components.utils import * 7 | 8 | 9 | class CompositionalityParser(): 10 | def __init__(self): 11 | self.to_exclude_list = ["able", "want", "refuse", "try", "due", "go", 12 | "supposed", "claim", "pretend", 13 | # https://grammar.collinsdictionary.com/us/easy-learning/which-verbs-are-followed-by-the-to-infinitive-in-english 14 | "agree", "arrange", "attempt", "choose", "decide", "fail", "hope", "learn", "manage", "offer", "plan", "seem", "come", 15 | "how", "when", "which", "what", "struggle", "remember"] # "lie", 16 | 17 | self.phrases_to_replace = {"in order to": "to", "so that": "so"} 18 | 19 | self.relative_pronounce_list = [ 20 | "that", "which", "who", "whom", "where"] 21 | 22 | self.for_exclude_list = ["pay"] # "lie", 23 | 24 | self.adj_exclude_list = ["nuclear"] 25 | self.conj_additional_list = [ 26 | "otherwise", "to", "for", "so", "because", "by", "or"] # , "if", "when" 27 | 28 | def fix_event(self, event): 29 | for pr in self.phrases_to_replace: 30 | if pr in event: 31 | event = event.replace(pr, self.phrases_to_replace[pr]) 32 | return event 33 | 34 | def organize_subevents(self, tokens, lemmatized_tokens, poss, deps, idxs, idxs_to_segment, original_tokens): 35 | all_segments = [] 36 | last_idx = 0 37 | for current_idx in idxs_to_segment: 38 | if last_idx in idxs_to_segment: 39 | all_segments.append([tokens[last_idx], 40 | [tokens[last_idx]], 41 | [lemmatized_tokens[last_idx]], 42 | [poss[last_idx]], 43 | [deps[last_idx]], 44 | [idxs[last_idx]], 45 | [original_tokens[last_idx]], 46 | "conjunction"]) 47 | last_idx += 1 48 | all_segments.append([" ".join(tokens[last_idx:current_idx]), 49 | tokens[last_idx:current_idx], 50 | lemmatized_tokens[last_idx:current_idx], 51 | poss[last_idx:current_idx], 52 | deps[last_idx:current_idx], 53 | idxs[last_idx:current_idx], 54 | original_tokens[last_idx:current_idx], 55 | "content"]) 56 | last_idx = current_idx 57 | 58 | if last_idx in idxs_to_segment: 59 | all_segments.append([tokens[last_idx], 60 | [tokens[last_idx]], 61 | [lemmatized_tokens[last_idx]], 62 | [poss[last_idx]], 63 | [deps[last_idx]], 64 | [idxs[last_idx]], 65 | [original_tokens[last_idx]], 66 | "conjunction"]) 67 | last_idx += 1 68 | all_segments.append([" ".join(tokens[last_idx:]), 69 | tokens[last_idx:], 70 | lemmatized_tokens[last_idx:], 71 | poss[last_idx:], 72 | deps[last_idx:], 73 | idxs[last_idx:], 74 | original_tokens[last_idx:], 75 | "content"]) 76 | segments = [segment for segment in all_segments if segment[0] != ""] 77 | return segments 78 | 79 | def get_subevents(self, tokens, lemmatized_tokens, poss, deps): 80 | original_tokens = [None for _ in range(len(tokens))] 81 | conjs_idxs_global = [i for i, pos in enumerate(poss) if ( 82 | (len(pos) > 3 and pos[-4:] == "CONJ") or lemmatized_tokens[i] in self.conj_additional_list)] 83 | conjs_idxs_global += [len(original_tokens) - 1] 84 | 85 | idxs = [i for i in range(len(tokens))] 86 | idxs_to_segment = [] 87 | for i, token in enumerate(tokens): 88 | pos = poss[i] 89 | 90 | if (len(pos) > 3 and pos[-4:] == "CONJ") or token in self.conj_additional_list: 91 | if i != 0 and lemmatized_tokens[i - 1] not in self.to_exclude_list \ 92 | and poss[i - 1] not in ["ADJ"]: 93 | # print(pos, token) 94 | conj_idx_local = conjs_idxs_global.index(i) 95 | # print(conj_idx_local) 96 | 97 | if conj_idx_local < (len(conjs_idxs_global) - 1): 98 | next_conj_idx_global = conjs_idxs_global[conj_idx_local + 1] 99 | 100 | # if there's no verb in the next sequence, then don't segment 101 | if "VERB" in poss[i: next_conj_idx_global 
+ 1] \ 102 | or "be" in lemmatized_tokens[i: next_conj_idx_global + 1]: 103 | # print(pos, token) 104 | idxs_to_segment.append(i) 105 | 106 | subevents = self.organize_subevents(tokens, lemmatized_tokens, poss, deps, 107 | idxs, idxs_to_segment, original_tokens) 108 | return subevents 109 | 110 | def _get_relative_clause(self, subevent): 111 | tokens = subevent[1] 112 | poss = subevent[3] 113 | 114 | for i, token in enumerate(tokens): 115 | if token in self.relative_pronounce_list and i != 0 and poss[i - 1] in ["NOUN"] \ 116 | or token in ["who", "whom"]: 117 | return [ 118 | [" ".join(tokens[:i])] + [l[:i] 119 | for l in subevent[1:-1]] + ["content"], 120 | [" ".join(tokens[i:i+1])] + [l[i:i+1] 121 | for l in subevent[1:-1]] + ["relative pronoun"], 122 | [" ".join(tokens[i+1:])] + [l[i+1:] 123 | for l in subevent[1:-1]] + ["relative clause"] 124 | ] 125 | 126 | def _get_all_subevents(self, event): 127 | event = self.fix_event(event) 128 | parsed_event = parse_sequence(event, is_dependency_parse=True) 129 | 130 | tokens = parsed_event["tokens"]["tokens_list"] 131 | lemmatized_tokens = parsed_event["lemmatized_tokens"]["tokens_list"] 132 | poss = [_token[2] for _token in parsed_event["tokens"]["tokens_dict"]] 133 | deps = [_token[3] for _token in parsed_event["tokens"]["tokens_dict"]] 134 | 135 | return self.get_subevents(tokens, lemmatized_tokens, poss, deps) 136 | 137 | def glue_subevents(self, subevents): 138 | glued_subevent = subevents[0] 139 | 140 | for e in subevents[1:]: 141 | for i, c in enumerate(e[1:-1]): 142 | glued_subevent[i + 1] += c 143 | 144 | glued_subevent[0] = " ".join(glued_subevent[1]) 145 | glued_subevent[-1] = "content" 146 | 147 | return glued_subevent 148 | 149 | def _get_parsed_event(self, event, is_simple=True): 150 | subevents = self._get_all_subevents(event) 151 | parsed_event = {} 152 | 153 | relative_clause = self._get_relative_clause(subevents[0]) 154 | if relative_clause != None: 155 | parsed_event["main_action"] = {"main_clause": relative_clause[0], 156 | "relative_pronoun": relative_clause[1], "relative_clause": relative_clause[2]} 157 | else: 158 | parsed_event["main_action"] = {"main_clause": subevents[0]} 159 | 160 | if len(subevents) > 1: 161 | parsed_event["connector"] = subevents[1] 162 | parsed_event["other_action"] = self.glue_subevents(subevents[2:]) 163 | 164 | if not is_simple: 165 | return parsed_event 166 | else: 167 | main_action = parsed_event["main_action"]["main_clause"][0] 168 | relative_pronoun = None 169 | relative_clause = None 170 | connector = None 171 | other_action = None 172 | if "relative_pronoun" in parsed_event["main_action"]: 173 | relative_pronoun = parsed_event["main_action"]["relative_pronoun"][0] 174 | relative_clause = parsed_event["main_action"]["relative_clause"][0] 175 | 176 | if "connector" in parsed_event: 177 | connector = parsed_event["connector"][0] 178 | other_action = parsed_event["other_action"][0] 179 | 180 | return {"main_action": main_action, 181 | "relative_pronoun": relative_pronoun, 182 | "relative_clause": relative_clause, 183 | "connector": connector, 184 | "other_action": other_action} 185 | 186 | def get_parsed_event(self, event): 187 | parsed_events = self._get_parsed_event(event, is_simple=True) 188 | 189 | main_event = parsed_events["main_action"] 190 | main_event_main_clause = parsed_events["main_action"] 191 | main_event_relative_pronoun = parsed_events["relative_pronoun"] 192 | main_event_relative_clause = parsed_events["relative_clause"] 193 | adjunct_event = parsed_events["other_action"] 194 
| 195 | if main_event_main_clause != None and main_event_relative_clause != None: 196 | main_event = main_event_main_clause + " " + \ 197 | main_event_relative_pronoun + " " + main_event_relative_clause 198 | 199 | return {"main_event": main_event, 200 | "main_event_main_clause": main_event_main_clause, 201 | "main_event_relative_clause": main_event_relative_clause, 202 | "adjunct_event": adjunct_event, } 203 | 204 | 205 | if __name__ == "__main__": 206 | compositionality_parser = CompositionalityParser() 207 | 208 | events = [] 209 | for split in ["test", "validation"]: 210 | input_file = data_base_path + \ 211 | f"cache_norm_bank/events/clean_{split}.moral_acceptability.tsv" 212 | df_data = pd.read_csv(input_file, sep="\t") 213 | events += df_data["clean_event"].tolist() 214 | 215 | data_to_save = {} 216 | for event in tqdm(events): 217 | parsed_event = compositionality_parser.get_parsed_event(event) 218 | data_to_save[event] = parsed_event 219 | 220 | save_json(data_base_path + f"cache_norm_bank/constituents.json", data_to_save) 221 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/DelphiScorer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | 4 | import torch 5 | from scipy.special import softmax 6 | from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config 7 | 8 | 9 | class DelphiScorer: 10 | def __init__(self, device_id=1, model="t5-11b-1239200", server="beaker_batch"): 11 | CUDA_DEVICE = f"cuda:{device_id}" if torch.cuda.is_available() else 'cpu' 12 | self.device = torch.device(CUDA_DEVICE) 13 | print(f"DelphiScorer device: {self.device}", model) 14 | 15 | if model == "t5-large": 16 | MODEL_BASE = "t5-large" 17 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/large_commonsense_morality_hf" 18 | self.class_token_pos = 4 19 | self.sep_tokens = [" /class> text>", " class>", " /text>"] 20 | 21 | elif model == "t5-11b": 22 | MODEL_BASE = "t5-11b" 23 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11b_commonsense_morality_hf" 24 | self.class_token_pos = 4 25 | self.sep_tokens = [" /class> text>", " class>", " /text>"] 26 | 27 | elif model == "t5-11b-1239200": 28 | MODEL_BASE = "t5-11b" 29 | if server == "beaker_batch": 30 | MODEL_LOCATION = "/model/delphi11b" 31 | else: 32 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11b_commonsense_morality_1239200_hf" 33 | self.class_token_pos = 4 34 | self.sep_tokens = [" /class> text>", " class>", " /text>"] 35 | 36 | elif model == "v11_distribution": 37 | MODEL_BASE = "t5-11b" 38 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/v11_distribution_hf" 39 | self.class_token_pos = 3 40 | self.sep_tokens = ["[/class] [text]", "[class]", "[/text]"] 41 | 42 | elif model == "11B_001": 43 | MODEL_BASE = "t5-11b" 44 | MODEL_LOCATION = "/net/nfs.cirrascale/mosaic/liweij/model/11B_001_hf" 45 | self.class_token_pos = 4 46 | self.sep_tokens = [" /class> text>", " class>", " /text>"] 47 | 48 | else: 49 | print("ERROR: model doesn't exist") 50 | return 51 | 52 | self.model = T5ForConditionalGeneration.from_pretrained(MODEL_LOCATION) 53 | self.model.to(self.device) 54 | self.tokenizer = T5Tokenizer.from_pretrained(MODEL_BASE, model_max_length=512) 55 | 56 | def score(self, input_string, normalize=None): 57 | input_string = f"[moral_single]: {input_string}" 58 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids 59 | outputs = 
self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True) 60 | 61 | probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][self.class_token_pos][0].softmax(0))] 62 | 63 | class1_prob = sum([v[1].item() for v in probs if v[0] == "1"]) 64 | class0_prob = sum([v[1].item() for v in probs if v[0] == "0"]) 65 | classminus1_prob = sum([v[1].item() for v in probs if v[0] == "-1"]) 66 | 67 | probs = [class1_prob, class0_prob, classminus1_prob] 68 | probs_sum = sum(probs) 69 | 70 | if normalize == "regular": 71 | probs = [p / probs_sum for p in probs] 72 | elif normalize == "softmax": 73 | probs = softmax(probs) 74 | 75 | return probs 76 | 77 | def generate(self, input_string): 78 | input_string = f"[moral_single]: {input_string}" 79 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids 80 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True) 81 | 82 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0]) 83 | class_label = int(decoded_sequence.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1]) 84 | text_label = decoded_sequence.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0] 85 | 86 | return class_label, text_label 87 | 88 | 89 | def generate_beam(self, 90 | input_string, 91 | num_beams=5, 92 | max_length=50, 93 | num_return_sequences=5,): 94 | input_string = f"[moral_single]: {input_string}" 95 | input_ids = self.tokenizer(input_string, 96 | max_length=16, 97 | truncation=True, 98 | return_tensors='pt').to(self.device).input_ids 99 | outputs = self.model.generate(input_ids, 100 | num_beams=num_beams, 101 | max_length=max_length, 102 | num_return_sequences=num_return_sequences,) 103 | 104 | decoded_sequences = self.tokenizer.batch_decode(outputs) 105 | 106 | class_labels = [ds.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1] for ds in decoded_sequences] 107 | text_labels = [ds.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0] for ds in decoded_sequences] 108 | 109 | return class_labels, text_labels 110 | 111 | 112 | def generate_with_score(self, input_string): 113 | input_string = f"[moral_single]: {input_string}" 114 | input_ids = self.tokenizer(input_string, max_length=512, return_tensors='pt').to(self.device).input_ids 115 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, return_dict_in_generate=True) 116 | 117 | probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][self.class_token_pos][0].softmax(0))] 118 | 119 | class1_prob = sum([v[1].item() for v in probs if v[0] == "1"]) 120 | class0_prob = sum([v[1].item() for v in probs if v[0] == "0"]) 121 | classminus1_prob = sum([v[1].item() for v in probs if v[0] == "-1"]) 122 | 123 | probs = [class1_prob, class0_prob, classminus1_prob] 124 | # probs_sum = sum(probs) 125 | 126 | decoded_sequence = self.tokenizer.decode(outputs["sequences"][0]) 127 | class_label = int(decoded_sequence.split(self.sep_tokens[0])[0].split(self.sep_tokens[1])[-1]) 128 | text_label = decoded_sequence.split(self.sep_tokens[0])[-1].split(self.sep_tokens[2])[0] 129 | 130 | return class_label, probs, text_label 131 | 132 | 133 | def generate_with_score_comparison(self, action1, action2): 134 | input_string = f"[moral_pair]: {action1} {action2}" 135 | input_ids = self.tokenizer(input_string, return_tensors='pt').to(self.device).input_ids 136 | outputs = self.model.generate(input_ids, max_length=200, output_scores=True, 
return_dict_in_generate=True)
137 | 
138 |         probs = [(self.tokenizer.decode(i), x) for (i, x) in enumerate(outputs['scores'][0][0].softmax(0))]
139 | 
140 |         class1_prob = sum([v[1].item() for v in probs if v[0] == "1"])
141 |         class2_prob = sum([v[1].item() for v in probs if v[0] == "2"])
142 | 
143 |         probs = [class1_prob, class2_prob]
144 |         # probs_sum = sum(probs)
145 | 
146 |         class_label = int(self.tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))
147 |         # print(class1_prob, class2_prob, class_label)
148 | 
149 |         return class_label, probs
150 | 
151 | 
152 | if __name__ == '__main__':
153 |     # delphi_scorer = DelphiScorer(device_id=0, model="t5-large")
154 |     delphi_scorer = DelphiScorer(device_id=0, model="t5-11b-1239200")
155 |     # print(delphi_scorer.generate_with_score("killing a cat"))
156 |     # e_list = ["killing a cat",
157 |     #           "killing a dog",
158 |     #           "killing a child"]
159 |     #
160 |     # for e in e_list:
161 |     #     print(delphi_scorer.generate_with_score(e))
162 | 
163 |     print(delphi_scorer.generate_with_score_comparison("killing a cat", "killing an evil cat"))
164 | 
165 | 
166 |     # for i in range(5):
167 |     #     print(delphi_scorer.generate_beam("Gay parents adopting a child"))
168 | 
169 |     # for seq in ["Gay parents adopting a child", "quentin tarantino", "Can I kill a terrorist?", "Protecting China"]:
170 |     #     seq =
171 |     #     print(delphi_scorer.generate(seq))
172 | 
173 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/GPT3Scorer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import openai
4 | 
5 | from src.delphi_hybrid.components.LMScorer import *
6 | 
7 | 
8 | class GPT3Scorer(LMScorer):
9 |     def __init__(self, model_name="text-davinci-003"):  # "text-davinci-002"
10 |         super().__init__(model_name)
11 | 
12 |         self.MODEL_NAME = model_name
13 |         openai.api_key = os.environ["OPENAI_API_KEY"]
14 |         self.CONDITIONED_GEN_TOKEN = "<|endoftext|>"
15 | 
16 |     def correct_grammatical_error(self, input_sequence):
17 |         response = openai.Completion.create(
18 |             model=self.MODEL_NAME,
19 |             prompt="Correct this to standard English:\n\n" + input_sequence,
20 |             temperature=0,
21 |             max_tokens=60,
22 |             top_p=1.0,
23 |             frequency_penalty=0.0,
24 |             presence_penalty=0.0
25 |         )
26 |         corrected_sequence = response["choices"][0]["text"].split("\n\n")[-1]
27 |         return corrected_sequence
28 | 
29 |     def get_input_perplexity(self, input_sequence):
30 |         """
31 |         Helper function to get the perplexity of a sequence from GPT3
32 |         """
33 |         # formatted_input_sequence = CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:]
34 |         # formatted_input_sequence = CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:] + "."
35 | # formatted_input_sequence = CONDITIONED_GEN_TOKEN + "'" + input_sequence[0].upper() + input_sequence[1:] + ".'" 36 | formatted_input_sequence = self.CONDITIONED_GEN_TOKEN + input_sequence 37 | return self._get_input_perplexity(formatted_input_sequence) 38 | 39 | def _get_input_perplexity(self, formatted_input_sequence): 40 | """ 41 | Helper function to get the perplexity of a sequence from GPT3 42 | """ 43 | response = openai.Completion.create( 44 | model=self.MODEL_NAME, 45 | prompt=formatted_input_sequence, 46 | # prompt=formatted_input_sequence 47 | max_tokens=0, 48 | logprobs=1, 49 | echo=True 50 | ) 51 | # print(formatted_input_sequence) 52 | 53 | tokens_logprobs = response["choices"][0]["logprobs"]["token_logprobs"][1:] 54 | num_tokens = len(tokens_logprobs) 55 | sequence_cross_entropy = -sum(tokens_logprobs) / num_tokens 56 | perplexity = math.exp(sequence_cross_entropy) 57 | return perplexity 58 | 59 | def get_input_perplexity_combo(self, input_sequence, return_all_ppl=False): 60 | """ 61 | Get a combined perplexity averaged across multiple input sequence formats. 62 | """ 63 | formatted_input_sequences = [ 64 | self.CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:], 65 | self.CONDITIONED_GEN_TOKEN + input_sequence[0].upper() + input_sequence[1:] + ".", 66 | self.CONDITIONED_GEN_TOKEN + '"' + input_sequence[0].upper() + input_sequence[1:] + '"', 67 | ] 68 | 69 | return self._get_input_perplexity_combo(formatted_input_sequences, return_all_ppl) 70 | 71 | def get_input_logprob_gpt3(self, input_sequence): 72 | """ 73 | Helper function to get the perplexity of a sequence from GPT3 74 | """ 75 | formatted_input_sequence = self.CONDITIONED_GEN_TOKEN + '"' + input_sequence[0].upper() + input_sequence[1:] + '"' 76 | response = openai.Completion.create( 77 | model=self.MODEL_NAME, 78 | prompt=formatted_input_sequence, 79 | max_tokens=0, 80 | logprobs=1, 81 | echo=True 82 | ) 83 | 84 | tokens_logprobs = response["choices"][0]["logprobs"]["token_logprobs"][1:] 85 | return sum(tokens_logprobs) 86 | 87 | if __name__ == "__main__": 88 | gpt3_scorer = GPT3Scorer() 89 | 90 | # premise = "killing a bear to save your child." 
91 |     # print(gpt3_scorer.get_input_perplexity_combo(premise))
92 | 
93 |     # premise = "killing a stupid bear to save your child"
94 |     # print(gpt3_scorer.get_paraphrases(premise))
95 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/LMScorer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast
4 | 
5 | 
6 | class LMScorer:  # minimal base class; model-specific subclasses (e.g., GPT3Scorer) implement the scoring methods
7 |     def __init__(self, model_name):
8 |         print("Initialize model:", model_name)
9 | 
10 |     def get_input_perplexity(self, input_sequence):
11 |         return
12 | 
13 |     def _get_input_perplexity(self, formatted_input_sequence):
14 |         return
15 | 
16 |     def _get_input_perplexity_combo(self, formatted_input_sequences, return_all_ppl=False):
17 |         all_ppl = []
18 |         for formatted_input_sequence in formatted_input_sequences:
19 |             ppl = self._get_input_perplexity(formatted_input_sequence)
20 |             all_ppl.append(ppl)
21 | 
22 |         if return_all_ppl:
23 |             return formatted_input_sequences, (all_ppl + [(sum(all_ppl) / len(all_ppl))])
24 |         else:
25 |             return (sum(all_ppl) / len(all_ppl))
26 | 
27 |     def get_input_perplexity_combo(self, input_sequence, return_all_ppl=False):
28 |         return
29 | 
30 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/MoralSaliencyKeywordCounter.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import sys
4 | 
5 | sys.path.append(os.getcwd())
6 | from src.delphi_hybrid.components.DelphiScorer import *
7 | from src.delphi_hybrid.components.CacheHandler.CacheHandler import *
8 | 
9 | class MoralSaliencyKeywordCounter():
10 |     def __init__(self, filename=data_base_path + f"cache/keywords.json"):
11 |         self.keywords_cache = read_json(filename)
12 | 
13 |     def get_events_keyword_counts(self, e, keyword):
14 |         e_keywords = self.keywords_cache[e]
15 | 
16 |         all_keywords_selected = all_keywords_categorize[keyword]
17 |         keywords_counts = {k: 0 for k in all_keywords_selected}
18 | 
19 |         for k in all_keywords_selected:
20 |             if k in e_keywords:
21 |                 keywords_counts[k] = e_keywords[k]
22 |         return keywords_counts
23 | 
24 |     def get_all_top_level_keyword_count(self, keywords_map):
25 |         return sum([keywords_map[k] for k in keywords_map])
26 | 
27 |     def get_event_top_level_keyword_count(self, keywords_counts):
28 |         return sum([keywords_counts[k] for k in keywords_counts])
29 | 
30 |     def get_top_level_keywords_counts_map(self, top_level_keyword):
31 |         top_level_keywords_counts_map = {}  # event -> total count of keywords under this top-level category
32 |         for e in self.keywords_cache:
33 |             keywords_counts = self.get_events_keyword_counts(e, top_level_keyword)
34 |             top_level_keywords_counts_map[e] = self.get_all_top_level_keyword_count(keywords_counts)
35 |         return top_level_keywords_counts_map
36 | 
37 | if __name__ == "__main__":
38 |     keyword_counter = MoralSaliencyKeywordCounter()
39 |     # instance = cache_handler.get_instance("Killing a bear")
40 | 
41 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/Paraphraser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 | 
5 | import pandas as pd
6 | import re
7 | import random
8 | from src.delphi_hybrid.components.utils import *
9 | from src.delphi_hybrid.components.WANLIScorer import WANLIScorer
10 | 
11 | # S: Putting my mother in a retirement home.
12 | # A: Placing my mother in a retirement home. 13 | # A: Killing someone because you are being kind. 14 | 15 | PROMPT_SOURCE = """ 16 | Paraphrase the given sentence. 17 | S: Torturing a cat if other people cannot see it. 18 | A: {} 19 | S: Not caring for Christmas. 20 | A: Being indifferent towards Christmas. 21 | S: Killing someone with kindness. 22 | A: Killing someone for showing your kindness. 23 | S: {} 24 | A: 25 | """ 26 | 27 | PARAPHRASE_OPTIONS = [ 28 | "Torturing a cat assuming that other people cannot see it.", 29 | "Torturing a cat on the assumption that other people cannot see it.", 30 | "Torturing a cat in case other people cannot see it.", 31 | "Torturing a cat, given that other people cannot see it.", 32 | "Torturing a cat if nobody can see it." 33 | ] 34 | 35 | 36 | class Paraphraser(): 37 | def __init__(self, model_name="text-davinci-003"): # "text-davinci-003", "text-curie-001" 38 | self.MODEL_NAME = model_name 39 | openai.api_key = os.environ["OPENAI_API_KEY"] 40 | self.CONDITIONED_GEN_TOKEN = "<|endoftext|>" 41 | self.nli_scorer = WANLIScorer() 42 | 43 | def fix_sentence(self, input_str): 44 | if input_str == "": # Don't change empty strings. 45 | return input_str 46 | if input_str[-1] in ["?", ".", "!"]: # Don't change if already okay. 47 | return input_str 48 | if input_str[-1] == ",": # Change trailing ',' to '.'. 49 | return input_str[:-1] + "." 50 | return input_str + "." 51 | 52 | def _qualify_paraphrase(self, action_doc, paraphrased_action_doc): 53 | action = action_doc.text.lower() 54 | paraphrased_action = paraphrased_action_doc.text.lower() 55 | 56 | num_tokens_action = len(action_doc) 57 | num_tokens_paraphrased_action = len(paraphrased_action_doc) 58 | 59 | action_lemma = " ".join([token.lemma_ for token in action_doc]).lower() 60 | paraphrased_action_lemma = " ".join([token.lemma_ for token in paraphrased_action_doc]).lower() 61 | 62 | if action_lemma in paraphrased_action_lemma and action_lemma != paraphrased_action_lemma: 63 | return False 64 | 65 | if any(paraphrased_action.startswith(prefix) for prefix in ["it's ", "it is "]) \ 66 | and not any(action.startswith(prefix) for prefix in ["it's ", "it is "]): 67 | return False 68 | 69 | if ":" not in action and ":" in paraphrased_action: 70 | return False 71 | 72 | if not re.search('[a-zA-Z]', paraphrased_action): 73 | return False 74 | 75 | if paraphrased_action.lower() in [action.lower(), "n/a"]: 76 | return False 77 | 78 | return abs(num_tokens_action - num_tokens_paraphrased_action) / num_tokens_action < 1 79 | 80 | def qualify_paraphrase(self, action, paraphrased_action): 81 | action_doc = nlp(action) 82 | paraphrased_action_doc = nlp(paraphrased_action) 83 | return self._qualify_paraphrase(action_doc, paraphrased_action_doc) 84 | 85 | def generate_paraphrases(self, action, num_paraphrases=1, max_iteration_count=20): 86 | fixed_action = self.fix_sentence(action.capitalize()) 87 | action_doc = nlp(action) 88 | 89 | paraphrases = [] 90 | iteration_count = 0 91 | while len(paraphrases) < num_paraphrases and iteration_count < max_iteration_count: 92 | _po = random.sample(PARAPHRASE_OPTIONS, 1)[0] 93 | _p = PROMPT_SOURCE.format(_po, fixed_action) 94 | 95 | # try: 96 | # response = gpt3_completion(_p, model_name="text-davinci-003", max_tokens=100, temperature=0.9, 97 | # logprobs=0, echo=False, num_outputs=10, top_p=1.0, best_of=1) 98 | # iteration_count += 1 99 | # except Exception as e: 100 | # continue 101 | 102 | response = gen_gpt3(_p, n=5, model=self.MODEL_NAME, temperature=0.8, max_tokens=400) 
103 |             iteration_count += 1
104 |             paraphrased_action = [c['text'].strip() for c in response['choices']]
105 | 
106 |             for paraphrase in paraphrased_action:
107 |                 paraphrase_doc = nlp(paraphrase)
108 |                 if self._qualify_paraphrase(action_doc, paraphrase_doc):
109 |                     paraphrases.append(paraphrase)
110 |             paraphrases = list(set(paraphrases))
111 |             print(len(paraphrases))
112 | 
113 |         return {
114 |             "action": action,
115 |             "fixed_action": fixed_action,
116 |             "paraphrases": paraphrases
117 |         }
118 | 
119 | if __name__ == "__main__":
120 |     # paraphraser = Paraphraser()
121 | 
122 |     # action = "genocide if it creates job"
123 |     # action = "Having sex with your best friend"
124 |     # action = "Putting my mother in a retirement home"
125 |     # action = "Killing a baby because you really want to"
126 |     # action = "Stabbing someone who opened your mail without permission"
127 |     # action = "Killing someone with kindness"
128 |     # action = "Robbing a bank because it is my job"
129 |     # action = "Let people play games and allow them to win lots of money"
130 |     # action = "My husband doesn't know that our daughter is not his"
131 |     # action = "Making money instead of saving someone"
132 |     # print(paraphraser.generate_paraphrases(action, num_paraphrases=6))
133 | 
134 |     # premise = "killing a bear to save your child."
135 |     # hypothesis = "killing a bear."
136 |     # prediction = wanli_scorer.get_scores(premise, hypothesis)
137 |     # print(premise, "|", hypothesis, ":", prediction)
138 |     openai.api_key = os.environ["OPENAI_API_KEY"]
139 |     print(gen_gpt3("_p", n=20, model="text-davinci-003", temperature=0.7, max_tokens=400))
140 | 
141 | 
142 | 
143 | 
144 | 
--------------------------------------------------------------------------------
/src/delphi_hybrid/components/WANLIScorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import RobertaTokenizer, RobertaForSequenceClassification
3 | 
4 | 
5 | class WANLIScorer:
6 | 
7 |     def __init__(self):
8 |         self.model = RobertaForSequenceClassification.from_pretrained('alisawuffles/roberta-large-wanli')
9 |         self.tokenizer = RobertaTokenizer.from_pretrained('alisawuffles/roberta-large-wanli')
10 | 
11 |     def get_scores(self, premise, hypothesis, is_return_score=True):
12 |         x = self.tokenizer(premise, hypothesis, return_tensors='pt', max_length=128, truncation=True)
13 |         logits = self.model(**x).logits
14 |         probs = logits.softmax(dim=1).squeeze(0)
15 |         label_id = torch.argmax(probs).item()
16 |         prediction = self.model.config.id2label[label_id]
17 | 
18 |         if is_return_score:
19 |             scores = {"contradiction": probs[0].item(),
20 |                       "entailment": probs[1].item(),
21 |                       "neutral": probs[2].item()}
22 |             return scores
23 |         return prediction
24 | 
25 | 
26 | if __name__ == "__main__":
27 |     wanli_scorer = WANLIScorer()
28 |     # premise = "killing a bear to save your child."
29 |     # hypothesis = "killing a bear."
30 |     # prediction = wanli_scorer.get_scores(premise, hypothesis)
31 |     # print(premise, "|", hypothesis, ":", prediction)
32 |     #
33 |     # hypothesis = "killing a child."
34 |     # prediction = wanli_scorer.get_scores(premise, hypothesis)
35 |     # print(premise, "|", hypothesis, ":", prediction)
36 |     #
37 |     # hypothesis = "save your child."
38 |     # prediction = wanli_scorer.get_scores(premise, hypothesis)
39 |     # print(premise, "|", hypothesis, ":", prediction)
40 |     #
41 |     # hypothesis = "bear your child."
42 | # prediction = wanli_scorer.get_scores(premise, hypothesis) 43 | # print(premise, "|", hypothesis, ":", prediction) 44 | # 45 | # hypothesis = "bear." 46 | # prediction = wanli_scorer.get_scores(premise, hypothesis) 47 | # print(premise, "|", hypothesis, ":", prediction) 48 | # 49 | # hypothesis = "save." 50 | # prediction = wanli_scorer.get_scores(premise, hypothesis) 51 | # print(premise, "|", hypothesis, ":", prediction) 52 | # 53 | # hypothesis = "killing." 54 | # prediction = wanli_scorer.get_scores(premise, hypothesis) 55 | # print(premise, "|", hypothesis, ":", prediction) 56 | 57 | 58 | premise = "Killing a bear if I attack it." 59 | hypothesis = "Killing a bear that attacks you" 60 | prediction = wanli_scorer.get_scores(premise, hypothesis) 61 | print(premise, "|", hypothesis, ":", prediction) 62 | 63 | hypothesis = "Killing a bear if I attack it." 64 | premise = "Killing a bear that attacks you" 65 | prediction = wanli_scorer.get_scores(premise, hypothesis) 66 | print(premise, "|", hypothesis, ":", prediction) 67 | -------------------------------------------------------------------------------- /src/delphi_hybrid/components/constants.py: -------------------------------------------------------------------------------- 1 | base_path = "/net/nfs.cirrascale/mosaic/liweij/delphi_algo/" 2 | data_base_path = base_path + "data/" 3 | results_base_path = base_path + "results/" 4 | -------------------------------------------------------------------------------- /src/delphi_hybrid/prepare_data/compile_demo_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import pandas as pd 5 | 6 | # sys.path.append(os.getcwd()) 7 | 8 | sys.path.append("/Users/liweijiang/Desktop/delphi_algo/scripts/utils") 9 | # print(os.getcwd()) 10 | 11 | from main_utils import * 12 | 13 | random.seed(10) 14 | 15 | 16 | def get_event_len(e): 17 | return len(e) 18 | 19 | def filter_by_keywords_text_label(e): 20 | kw_to_exclude = ["yes,", "no,"] 21 | return not any(kw in e.lower() for kw in kw_to_exclude) 22 | 23 | def filter_by_keywords_event(e): 24 | kw_to_exclude = ["?", "\"", "yeri", "izone", "aaaaaaaaaaaaaaaaaaaaaaaaaaaa", 25 | "delphi", "is a ", "hhhhhooooooo", "poiuytpoiuytrpoiuytrezapoiuytrezaqwxcvbnracist", 26 | "rhum", "..", "ben"] 27 | return not any(kw in e.lower() for kw in kw_to_exclude) 28 | 29 | def filter_by_startswith_keywords_event(e): 30 | kw_to_exclude = ["can", "is ", "should", "how", "what", "when", 31 | "are ", "why"] 32 | return not any(e.lower().startswith(kw) for kw in kw_to_exclude) 33 | 34 | def is_in_english(e): 35 | try: 36 | e.encode(encoding='utf-8').decode('ascii') 37 | except UnicodeDecodeError: 38 | return False 39 | else: 40 | return True 41 | 42 | def get_upper_letter_count(e): 43 | return sum(1 for c in (e[0].lower() + e[1:]) if c.isupper()) 44 | 45 | def main(): 46 | df_data_done = pd.read_csv("/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",") 47 | event_done = df_data_done["event"].tolist() 48 | 49 | df_data = pd.read_csv("/data/demo/demo_examples_102721_delphi.csv", sep="\t") 50 | df_data["action1_len"] = df_data["action1"].apply(get_event_len) 51 | df_data = df_data.dropna(subset=["action1", "text_label"]) 52 | 53 | df_data = df_data[~df_data["action1"].isin(event_done)] 54 | df_data = df_data[df_data["action1_len"] > 25] 55 | df_data = df_data[df_data["action1_len"] < 200] 56 | df_data = df_data[df_data["text_label"].apply(filter_by_keywords_text_label)] 57 | df_data = 
df_data[df_data["action1"].apply(filter_by_keywords_event)] 58 | df_data = df_data[df_data["action1"].apply(filter_by_startswith_keywords_event)] 59 | df_data = df_data[df_data["action1"].apply(is_in_english)] 60 | 61 | df_data["clean_event"] = df_data["action1"].apply(normalize_event) 62 | 63 | # print(df_data) any(kw in e.lower() for kw in remove_keywords_list) 64 | 65 | events_to_annotate = df_data["clean_event"].value_counts()[:10000].keys().tolist() 66 | # print(events_to_annotate) 67 | 68 | 69 | df_data["question1"] = df_data["clean_event"] 70 | 71 | df_data_selected = df_data[df_data["clean_event"].isin(events_to_annotate)] 72 | df_data_selected = df_data_selected[["action1", "class_label", "text_label", "clean_event", "question1"]] 73 | 74 | df_data_selected = df_data_selected.drop_duplicates(subset=["clean_event"]) 75 | df_data_selected.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", index=False) 76 | 77 | # df_data = pd.read_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", sep=",") 78 | # df_data = df_data.rename(columns={"action1": "raw_event"}) 79 | # df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4.csv", index=False) 80 | 81 | def compile_annotated_data(): 82 | df_data = pd.read_csv("/data/demo/mturk/result/demo_events_to_annotate_v4.csv", 83 | sep=",") 84 | 85 | df_data_done = pd.read_csv( 86 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",") 87 | event_done = df_data_done["event"].tolist() 88 | 89 | headers = ["Input.action1", "Input.class_label", 90 | "Input.text_label", "Input.clean_event", 91 | "Answer.controversial_1", "Answer.lewd_1", 92 | "Answer.feedback", "Answer.is_action_1", 93 | "Answer.make_sense_1", "Answer.privacy_1"] 94 | 95 | rename_map = {"Input.action1": "event", 96 | "Input.class_label": "class_label", 97 | "Input.text_label": "text_label", 98 | "Input.clean_event": "clean_event", 99 | "Answer.controversial_1": "is_controversial", 100 | "Answer.lewd_1": "is_lewd", 101 | "Answer.feedback": "feedback", 102 | "Answer.is_action_1": "is_action", 103 | "Answer.make_sense_1": "is_make_sense", 104 | "Answer.privacy_1": "is_privacy"} 105 | 106 | df_data = df_data[headers] 107 | df_data = df_data.rename(columns=rename_map) 108 | df_data = df_data[~df_data["event"].isin(event_done)] 109 | df_data = df_data[df_data["feedback"] == "{}"] 110 | df_data = df_data[df_data["is_make_sense"] == 1.0] 111 | df_data = df_data[df_data["is_lewd"] == -1.0] 112 | df_data = df_data[df_data["is_action"] == 1.0] 113 | df_data = df_data[df_data["is_privacy"] == -1.0] 114 | df_data = df_data[df_data["event"].apply(filter_by_keywords_event)] 115 | 116 | df_data["event_len"] = df_data["event"].apply(get_event_len) 117 | df_data = df_data[df_data["event_len"] < 120] 118 | df_data = df_data[df_data["text_label"].apply(filter_by_keywords_text_label)] 119 | 120 | df_data["upper_letter_count"] = df_data["event"].apply(get_upper_letter_count) 121 | df_data = df_data[df_data["upper_letter_count"] < 2] 122 | 123 | df_data = df_data.sort_values(by=["event_len"]) 124 | print(df_data["event_len"].mean()) 125 | 126 | df_data_non_controversial = df_data[df_data["is_controversial"] == -1.0] 127 | df_data_non_controversial = df_data_non_controversial.sample(n=2500, random_state=0) 128 | df_data_controversial = df_data[df_data["is_controversial"] == 1.0] 129 | df_data_controversial = df_data_controversial.sample(n=4000 - 
df_data_non_controversial.shape[0], random_state=0) 130 | 131 | df_data_selected = pd.concat([df_data_non_controversial, df_data_controversial], ignore_index=True, sort=False) 132 | df_data_selected.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", index=False) 133 | input_events = df_data_selected["event"].tolist() 134 | 135 | df_data_input = pd.DataFrame() 136 | for i in range(4): 137 | df_data_input[f"action{i+1}"] = input_events[i * 1000: (i+1) * 1000] 138 | df_data_input.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/input/demo_events_to_annotate_v4_selected_input.csv", 139 | index=False) 140 | 141 | # print(len(event_done), len(set(event_done))) 142 | # print(len(input_events), len(set(input_events))) 143 | # print(len(set(event_done + input_events))) 144 | 145 | 146 | def compile_v234_data(): 147 | df_data_v4 = pd.read_csv( 148 | "/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", sep=",") 149 | event_v4 = df_data_v4["event"].tolist() 150 | 151 | df_data_v2_v3 = pd.read_csv( 152 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",") 153 | event_v2_v3 = df_data_v2_v3["event"].tolist() 154 | # print(len(event_v4)) 155 | # print(len(event_v2_v3)) 156 | 157 | all_events = event_v2_v3 + event_v4 158 | random.shuffle(all_events) 159 | 160 | # train_event = all_events[0: int(0.4 * len(all_events))] 161 | # dev_event = all_events[int(0.4 * len(all_events)): int(0.7 * len(all_events))] 162 | # test_event = all_events[int(0.7 * len(all_events)):] 163 | 164 | train_split_labels = ["train" for _ in range(int(0.5 * len(all_events)) + 2)] \ 165 | + ["dev" for _ in range(int(0.25 * len(all_events)))] \ 166 | + ["test" for _ in range(int(0.25 * len(all_events)))] 167 | 168 | df_data = pd.DataFrame() 169 | df_data["event"] = all_events 170 | df_data["split"] = train_split_labels 171 | 172 | df_data["source"] = "v23" 173 | df_data.loc[df_data["event"].isin(event_v4), "source"] = "v4" 174 | 175 | # df_data["source"] = df_data["source"] 176 | df_data["clean_event"] = df_data["event"].apply(normalize_event) 177 | 178 | # print(df_data["source"].value_counts()) 179 | df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/split/event_only_v234.csv", index=False, sep="\t") 180 | # print(len(train_event)) 181 | # print(len(dev_event)) 182 | # print(len(test_event)) 183 | 184 | 185 | def compile_v5_data(): 186 | """ 187 | remove some controversial data from v4, and combine with v23 data to form v5 188 | """ 189 | df_data_v4 = pd.read_csv( 190 | "/data/demo/mturk/input/demo_events_to_annotate_v4_selected.csv", sep=",") 191 | df_data_v4 = df_data_v4[df_data_v4["is_controversial"] == -1.0] 192 | event_v4 = df_data_v4["event"].tolist() 193 | 194 | df_data_v2_v3 = pd.read_csv( 195 | "/data/demo/mturk/result/v2_v3_comprehensible_clean.csv", sep=",") 196 | event_v2_v3 = df_data_v2_v3["event"].tolist() 197 | # print(len(event_v4)) 198 | # print(len(event_v2_v3)) 199 | 200 | all_events = event_v2_v3 + event_v4 201 | random.shuffle(all_events) 202 | 203 | # train_event = all_events[0: int(0.4 * len(all_events))] 204 | # dev_event = all_events[int(0.4 * len(all_events)): int(0.7 * len(all_events))] 205 | # test_event = all_events[int(0.7 * len(all_events)):] 206 | 207 | train_split_labels = ["train" for _ in range(int(0.5 * len(all_events)) + 2)] \ 208 | + ["dev" for _ in range(int(0.25 * len(all_events)))] \ 209 | + ["test" for _ in range(int(0.25 * len(all_events)))] 210 | 211 | df_data = 
pd.DataFrame() 212 | df_data["event"] = all_events 213 | df_data["split"] = train_split_labels 214 | 215 | df_data["source"] = "v23" 216 | df_data.loc[df_data["event"].isin(event_v4), "source"] = "v4" 217 | 218 | df_data["clean_event"] = df_data["event"].apply(normalize_event) 219 | 220 | # df_data["source"] = df_data["source"] 221 | 222 | df_data.to_csv("/Users/liweijiang/Desktop/delphi_algo/data/demo/mturk/split/event_only_v5.csv", index=False, sep="\t") 223 | 224 | 225 | if __name__ == "__main__": 226 | # compile_v234_data() 227 | compile_v5_data() 228 | 229 | -------------------------------------------------------------------------------- /src/delphi_hybrid/prepare_data/compile_gold_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from collections import Counter 4 | 5 | sys.path.append(os.getcwd()) 6 | 7 | from scripts.utils.utils import * 8 | 9 | def get_maj_vote_class_label(raw_class_labels): 10 | class_label_counts = Counter(raw_class_labels) 11 | 12 | if len(set(raw_class_labels)) == 1: 13 | return raw_class_labels[0] 14 | elif len(set(raw_class_labels)) == 2: 15 | if -1 not in raw_class_labels: 16 | count_0 = class_label_counts[0] 17 | count_1 = class_label_counts[1] 18 | 19 | if count_0 > count_1: 20 | return 0 21 | else: 22 | return 1 23 | else: 24 | count_minus_1 = class_label_counts[-1] 25 | if 0 in class_label_counts: 26 | count_POS = class_label_counts[0] 27 | else: 28 | count_POS = class_label_counts[1] 29 | 30 | if count_minus_1 > count_POS: 31 | return -1 32 | else: 33 | if 0 in class_label_counts: 34 | return 0 35 | else: 36 | return 1 37 | else: 38 | count_minus_1 = class_label_counts[-1] 39 | count_0 = class_label_counts[0] 40 | count_1 = class_label_counts[1] 41 | 42 | if count_minus_1 == 3: 43 | return -1 44 | elif count_minus_1 == 2: 45 | if count_0 == 2: 46 | return 0 47 | elif count_1 == 2: 48 | return 1 49 | else: 50 | if count_0 == 3: 51 | return 0 52 | elif count_1 == 3: 53 | return 1 54 | else: 55 | return 0 56 | 57 | 58 | def get_agreement_rate(raw_class_labels): 59 | class_label_counts = Counter(raw_class_labels) 60 | 61 | if -1 in class_label_counts: 62 | count_minus_1 = class_label_counts[-1] 63 | if count_minus_1 in [2, 3]: 64 | return 3/5 65 | elif count_minus_1 in [5]: 66 | return 1 67 | else: 68 | return 4/5 69 | else: 70 | return 1 71 | 72 | 73 | def compile_gold_labels_v5(): 74 | data_path = "data/demo/mturk/result/v5.csv" 75 | df_data = pd.read_csv(data_path) 76 | 77 | # print(df_data) 78 | 79 | event_to_raw_labels = {} 80 | all_evants = [] 81 | for i, row in df_data.iterrows(): 82 | for j in range(4): 83 | event = row[f"Input.action{j+1}"] 84 | label = row[f"Answer.judgment_{j+1}"] 85 | 86 | all_evants.append(event) 87 | 88 | # if label not in [0, 1, -1]: 89 | # print(label) 90 | if event not in event_to_raw_labels: 91 | event_to_raw_labels[event] = [label] 92 | else: 93 | event_to_raw_labels[event].append(label) 94 | 95 | all_evants = list(set(all_evants)) 96 | print(len(event_to_raw_labels)) 97 | 98 | 99 | all_class_labels = [] 100 | all_agreement_rates = [] 101 | all_raw_class_labels = [] 102 | for event in all_evants: 103 | class_label = get_maj_vote_class_label(event_to_raw_labels[event]) 104 | agreement_rate = get_agreement_rate(event_to_raw_labels[event]) 105 | 106 | all_class_labels.append(class_label) 107 | all_agreement_rates.append(agreement_rate) 108 | all_raw_class_labels.append(event_to_raw_labels[event]) 109 | 110 | 111 | df_data_to_save = pd.DataFrame() 112 
| df_data_to_save["event"] = all_evants 113 | df_data_to_save["raw_class_labels"] = all_raw_class_labels 114 | df_data_to_save["agreement_rate"] = all_agreement_rates 115 | df_data_to_save["class_label"] = all_class_labels 116 | 117 | print(df_data_to_save) 118 | 119 | df_data_to_save.to_csv("data/demo/mturk/result/v5_clean.csv", index=False) 120 | 121 | 122 | if __name__ == "__main__": 123 | compile_gold_labels_v5() 124 | 125 | -------------------------------------------------------------------------------- /src/delphi_hybrid/prepare_data/filter_paraphrases.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | import argparse 5 | import pandas as pd 6 | 7 | sys.path.append(os.getcwd()) 8 | from src.delphi_hybrid.components.utils import * 9 | from src.delphi_hybrid.components.main_utils import * 10 | from src.delphi_hybrid.components.WANLIScorer import * 11 | 12 | def generate_nli(): 13 | wanli_scorer = WANLIScorer() 14 | paraphrases_cache = read_json(data_base_path + "cache/paraphrases.json") 15 | 16 | all_events = paraphrases_cache.keys() 17 | 18 | nli_to_save = read_json(data_base_path + "cache/nli.json") 19 | for premise in tqdm(all_events): 20 | for hypothesis in paraphrases_cache[premise]: 21 | if premise + " | " + hypothesis not in nli_to_save: 22 | prediction = wanli_scorer.get_scores(premise, hypothesis) 23 | print(premise, "|", hypothesis, ":", prediction) 24 | prediction["type"] = "event_paraphrase" 25 | nli_to_save[premise + " | " + hypothesis] = prediction 26 | 27 | if hypothesis + " | " + premise not in nli_to_save: 28 | prediction = wanli_scorer.get_scores(hypothesis, premise) 29 | print(hypothesis, "|", premise, ":", prediction) 30 | prediction["type"] = "paraphrase_event" 31 | nli_to_save[hypothesis + " | " + premise] = prediction 32 | 33 | save_json(data_base_path + "cache/nli.json", nli_to_save) 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="") 38 | 39 | parser.add_argument('--input_file', type=str, help="location of data file", 40 | default="data/demo/mturk/split/event_only_v5.csv") 41 | parser.add_argument('--device_id', type=int, help="device id", default=0) 42 | parser.add_argument('--total_num_device', type=int, 43 | default=8, help="total number device") 44 | args = parser.parse_args() 45 | 46 | df_data = pd.read_csv(args.input_file, sep="\t") 47 | all_events = df_data["clean_event"].tolist() 48 | 49 | generate_nli() 50 | -------------------------------------------------------------------------------- /src/delphi_plus/evalute/evaluate_declare_only.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | 6 | # ######################## moral acceptability/agreement class ######################## 7 | def get_gold_single_input_task_class(bucket, bucket_name, base_path, data_version, task_name, data_split): 8 | """ 9 | Get gold inputs and targets class labels 10 | """ 11 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 12 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_declare_only/{task_name}/{data_split}.tsv", sep="\t") 13 | 14 | inputs_all = list(df_inputs["inputs"]) 15 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 16 | 17 | targets_all = list(df_inputs["targets"]) 18 | 19 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) 
for i in targets_all] 20 | return inputs, targets 21 | 22 | 23 | def get_pred_single_input_task_class(bucket, base_path, task_name, check_point): 24 | """ 25 | Get preds class labels 26 | """ 27 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 28 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 29 | 30 | preds_class = [] 31 | for i in preds_blob_list: 32 | try: 33 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1])) 34 | except: 35 | print("output form not identifiable:", i) 36 | preds_class.append(1) 37 | 38 | return preds_class 39 | 40 | 41 | 42 | ######################## moral acceptability/agreement text ######################## 43 | def get_gold_single_input_task_text(bucket, bucket_name, base_path, data_version, task_name, data_split): 44 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 45 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_declare_only/{task_name}/{data_split}.tsv", sep="\t") 46 | inputs_all = list(df_inputs["inputs"]) 47 | inputs = [s.split("[moral_single]: ")[-1] for s in inputs_all] 48 | 49 | targets_all = list(df_inputs["targets"]) 50 | targets = [i.split("[/class] [text]")[1].split("[/text]")[0] for i in targets_all] 51 | return inputs, targets 52 | 53 | 54 | def get_pred_single_input_task_text(bucket, base_path, task_name, check_point): 55 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 56 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 57 | 58 | preds_text = [] 59 | for i in preds_blob_list: 60 | try: 61 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0]) 62 | except: 63 | print("output form not identifiable:", i) 64 | preds_text.append("") 65 | return preds_text 66 | 67 | 68 | ######################## main ######################## 69 | def main_get_accuracy(base_path, data_split, check_points=None, 70 | is_include_accept_class=True, 71 | is_include_accept_text=True, 72 | is_include_agree_class=True, 73 | is_include_agree_text=True, 74 | is_include_compare=True,): 75 | base_path += f"{data_split}_eval/" 76 | 77 | print("=" * 20) 78 | bucket_name = base_path.split("/")[0] 79 | result_prefix = "/".join(base_path.split("/")[1:]) 80 | data_version = base_path.split("/")[5] 81 | base_path = "/".join(base_path.split("/")[1:]) 82 | 83 | client = storage.Client() 84 | bucket = client.get_bucket(bucket_name) 85 | 86 | if check_points == None: 87 | check_points = get_check_points(client, bucket_name, result_prefix, after_check_point=-1)[1:] 88 | # check_points.sort(reverse=True) 89 | 90 | for check_point in check_points: 91 | print("=" * 40, check_point, "=" * 40) 92 | 93 | main_moral_acceptability(client, bucket, bucket_name, base_path, check_point, 94 | data_version, data_split, 95 | get_gold_single_input_task_class, 96 | get_pred_single_input_task_class, 97 | get_gold_single_input_task_text, 98 | get_pred_single_input_task_text, 99 | is_include_accept_class, 100 | is_include_accept_text) 101 | 102 | main_moral_agreement(client, bucket, bucket_name, base_path, check_point, 103 | data_version, data_split, 104 | get_gold_single_input_task_class, 105 | get_pred_single_input_task_class, 106 | get_gold_single_input_task_text, 107 | get_pred_single_input_task_text, 108 | is_include_agree_class, 109 | is_include_agree_text) 110 | 111 | 112 | if __name__ == "__main__": 113 | base_path = 
"ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/declare_only/lr-0.0001_bs-16/" 114 | # check_points = [1266200] 115 | check_points = None 116 | main_get_accuracy(base_path, "validation", check_points) 117 | # main_get_accuracy(base_path, "test", check_points) 118 | 119 | # base_path = "ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/11B/declare_only/lr-0.0001_bs-16/" 120 | # # check_points = [1266200] 121 | # check_points = None 122 | # main_get_accuracy(base_path, "validation", check_points) 123 | # # main_get_accuracy(base_path, "test", check_points) 124 | -------------------------------------------------------------------------------- /src/delphi_plus/evalute/evaluate_distribution.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | 6 | # ######################## moral acceptability/agreement class ######################## 7 | def get_gold_single_input_task_class(bucket, bucket_name, base_path, data_version, task_name, data_split): 8 | """ 9 | Get gold inputs and targets class labels 10 | """ 11 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 12 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/{task_name}/" 13 | f"{data_split}.{task_name}.tsv", sep="\t") 14 | 15 | inputs_all = list(df_inputs["inputs"]) 16 | inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 17 | 18 | targets_all = list(df_inputs["targets"]) 19 | targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all] 20 | return inputs, targets 21 | 22 | 23 | def get_pred_single_input_task_class(bucket, base_path, task_name, check_point): 24 | """ 25 | Get preds class labels 26 | """ 27 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 28 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 29 | 30 | preds_class = [] 31 | for i in preds_blob_list: 32 | try: 33 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1])) 34 | except: 35 | print("output form not identifiable:", i) 36 | preds_class.append(1) 37 | return preds_class 38 | 39 | 40 | def get_gold_single_input_task_class_wild_v11(bucket, bucket_name, base_path, data_version, task_name, data_split): 41 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 42 | 43 | if data_split == "validation": 44 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/" 45 | f"dev.tsv", sep="\t") 46 | # elif data_split == "test" and task_name == "general": 47 | elif data_split == "test": 48 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/" 49 | f"{task_name}.tsv", sep="\t") 50 | # elif data_split == "test" and task_name == "race": 51 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/" 52 | # f"race_test.tsv", sep="\t") 53 | # elif data_split == "test" and task_name == "gender": 54 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/" 55 | # f"gender_test.tsv", sep="\t") 56 | else: 57 | print("ERROR: not validation split") 58 | 59 | inputs_all = list(df_inputs["inputs"]) 60 | targets_all = list(df_inputs["targets"]) 61 | 62 | # inputs = [i.split("[moral_single]: ")[-1] for i in inputs_all] 63 | inputs = [] 64 | for _, i in enumerate(inputs_all): 65 | if type(i) != 
type(""): 66 | print("gold class input error:", _, i, targets_all[_]) 67 | inputs.append("") 68 | else: 69 | inputs.append(i.split("[moral_single]: ")[-1]) 70 | 71 | 72 | # targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all] 73 | targets = [] 74 | for _, i in enumerate(targets_all): 75 | if type(i) != type(""): 76 | print("gold class output error:", _, i, inputs_all[_]) 77 | targets.append(0) 78 | else: 79 | targets.append(int(i.split("[/class] [text]")[0].split("[class]")[-1])) 80 | 81 | return inputs, targets 82 | 83 | 84 | def get_pred_single_input_task_class_wild_v11(bucket, base_path, task_name, check_point): 85 | if task_name == "general_test": 86 | preds_blob = bucket.get_blob(base_path + f"wild_{check_point}_predictions") 87 | else: 88 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 89 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 90 | 91 | preds_class = [] 92 | for i in preds_blob_list: 93 | try: 94 | preds_class.append(int(i.split("[/class] [text]")[0].split("[class]")[-1])) 95 | except: 96 | print("output form not identifiable:", i) 97 | preds_class.append(1) 98 | return preds_class 99 | 100 | 101 | ######################## moral acceptability/agreement text ######################## 102 | def get_gold_single_input_task_text(bucket, bucket_name, base_path, data_version, task_name, data_split): 103 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 104 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/{task_name}/" 105 | f"{data_split}.{task_name}.tsv", sep="\t") 106 | inputs_all = list(df_inputs["inputs"]) 107 | inputs = [s.split("[moral_single]: ")[-1] for s in inputs_all] 108 | 109 | targets_all = list(df_inputs["targets"]) 110 | targets = [i.split("[/class] [text]")[1].split("[/text]")[0] for i in targets_all] 111 | return inputs, targets 112 | 113 | 114 | def get_pred_single_input_task_text(bucket, base_path, task_name, check_point): 115 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 116 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 117 | 118 | preds_text = [] 119 | for i in preds_blob_list: 120 | try: 121 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0]) 122 | except: 123 | print("output form not identifiable:", i) 124 | preds_text.append("") 125 | return preds_text 126 | 127 | 128 | def get_gold_single_input_task_text_wild_v11(bucket, bucket_name, base_path, data_version, task_name, data_split): 129 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 130 | if data_split == "validation": 131 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/dev.tsv", sep="\t") 132 | # elif data_split == "test" and task_name == "wild": 133 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/general_test.tsv", sep="\t") 134 | # elif data_split == "test" and task_name == "race_test": 135 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/race_test.tsv", sep="\t") 136 | # elif data_split == "test" and task_name == "gender_test": 137 | # df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/gender_test.tsv", sep="\t") 138 | elif data_split == "test": 139 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/wild/{task_name}.tsv", sep="\t") 140 | else: 141 | 
print("ERROR: not validation split") 142 | 143 | inputs_all = list(df_inputs["inputs"]) 144 | targets_all = list(df_inputs["targets"]) 145 | 146 | inputs = [] 147 | for _, i in enumerate(inputs_all): 148 | if type(i) != type(""): 149 | print("gold text input error:", _, i, targets_all[_]) 150 | inputs.append("") 151 | else: 152 | inputs.append(i.split("[moral_single]: ")[-1]) 153 | 154 | 155 | # targets = [int(i.split("[/class] [text]")[0].split("[class]")[-1]) for i in targets_all] 156 | targets = [] 157 | for _, i in enumerate(targets_all): 158 | if type(i) != type(""): 159 | print("gold text output error:", _, i, inputs_all[_]) 160 | targets.append(0) 161 | else: 162 | targets.append(i.split("[/class] [text]")[1].split("[/text]")[0]) 163 | 164 | return inputs, targets 165 | 166 | 167 | def get_pred_single_input_task_text_wild_v11(bucket, base_path, task_name, check_point): 168 | if task_name == "general_test": 169 | preds_blob = bucket.get_blob(base_path + f"wild_{check_point}_predictions") 170 | else: 171 | preds_blob = bucket.get_blob(base_path + f"{task_name}_{check_point}_predictions") 172 | 173 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 174 | 175 | preds_text = [] 176 | for i in preds_blob_list: 177 | try: 178 | preds_text.append(i.split("[/class] [text]")[1].split("[/text")[0]) 179 | except: 180 | print("output form not identifiable:", i) 181 | preds_text.append("") 182 | return preds_text 183 | 184 | 185 | ######################## moral comparison class ######################## 186 | def get_gold_moral_comparison_class(bucket, bucket_name, base_path, data_version, data_split): 187 | data_base_path = f"gs://{bucket_name}/" + "/".join(base_path.split("/")[:3]) + "/data" 188 | df_inputs = pd.read_csv(data_base_path + f"/{data_version}_distribution/moral_comparison/" 189 | f"{data_split}.moral_comparison.tsv", sep="\t") 190 | inputs_all = list(df_inputs["inputs"]) 191 | inputs = [s.split("[moral_pair]: ")[-1] for s in inputs_all] 192 | 193 | targets_all = list(df_inputs["targets"]) 194 | targets = [int(t) for t in targets_all] 195 | 196 | return inputs, targets 197 | 198 | 199 | def get_pred_moral_comparison_class(bucket, base_path, check_point): 200 | preds_blob = bucket.get_blob(base_path + f"moral_comparison_{check_point}_predictions") 201 | preds_blob_list = preds_blob.download_as_string().decode('utf-8').split("\n")[1:] 202 | return [int(i) for i in preds_blob_list] 203 | 204 | 205 | ######################## main ######################## 206 | def main_get_accuracy(base_path, data_split, check_points=None, 207 | is_include_accept_class=True, 208 | is_include_accept_text=True, 209 | is_include_agree_class=True, 210 | is_include_agree_text=True, 211 | is_include_compare=True,): 212 | base_path += f"{data_split}_eval/" 213 | 214 | bucket_name = base_path.split("/")[0] 215 | result_prefix = "/".join(base_path.split("/")[1:]) 216 | data_version = base_path.split("/")[5] 217 | model_type = base_path.split("/")[7] 218 | # eval_data = data_version + "_sbic_" + model_type.split("_")[-3] 219 | base_path = "/".join(base_path.split("/")[1:]) 220 | 221 | client = storage.Client() 222 | bucket = client.get_bucket(bucket_name) 223 | 224 | if check_points == None: 225 | check_points = get_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:] 226 | # check_points.sort(reverse=True) 227 | for check_point in check_points: 228 | print("=" * 40, check_point, "=" * 40) 229 | 230 | # main_moral_acceptability(client, bucket, bucket_name, 
base_path, check_point, 231 | # data_version, data_split, 232 | # get_gold_single_input_task_class, 233 | # get_pred_single_input_task_class, 234 | # get_gold_single_input_task_text, 235 | # get_pred_single_input_task_text, 236 | # is_include_accept_class, 237 | # is_include_accept_text, 238 | # task_name="moral_acceptability") 239 | # 240 | # main_moral_agreement(client, bucket, bucket_name, base_path, check_point, 241 | # data_version, data_split, 242 | # get_gold_single_input_task_class, 243 | # get_pred_single_input_task_class, 244 | # get_gold_single_input_task_text, 245 | # get_pred_single_input_task_text, 246 | # is_include_agree_class, 247 | # is_include_agree_text, 248 | # task_name="moral_agreement") 249 | # 250 | # if is_include_compare: 251 | # main_moral_comparison(client, bucket, bucket_name, base_path, check_point, 252 | # data_version, data_split, 253 | # get_gold_moral_comparison_class, 254 | # get_pred_moral_comparison_class) 255 | 256 | main_wild_v11(client, bucket, bucket_name, base_path, check_point, 257 | data_version, data_split, 258 | get_gold_single_input_task_class_wild_v11, 259 | get_pred_single_input_task_class_wild_v11, 260 | get_gold_single_input_task_text_wild_v11, 261 | get_pred_single_input_task_text_wild_v11, "general_test") 262 | 263 | if data_split == "test": 264 | main_wild_v11(client, bucket, bucket_name, base_path, check_point, 265 | data_version, data_split, 266 | get_gold_single_input_task_class_wild_v11, 267 | get_pred_single_input_task_class_wild_v11, 268 | get_gold_single_input_task_text_wild_v11, 269 | get_pred_single_input_task_text_wild_v11, "race_test") 270 | 271 | main_wild_v11(client, bucket, bucket_name, base_path, check_point, 272 | data_version, data_split, 273 | get_gold_single_input_task_class_wild_v11, 274 | get_pred_single_input_task_class_wild_v11, 275 | get_gold_single_input_task_text_wild_v11, 276 | get_pred_single_input_task_text_wild_v11, "gender_test") 277 | 278 | 279 | if __name__ == "__main__": 280 | base_path = "ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/distribution/lr-0.0001_bs-16/" 281 | check_points = None 282 | # main_get_accuracy(base_path, "validation", check_points) 283 | main_get_accuracy(base_path, "test", [1249400]) 284 | 285 | -------------------------------------------------------------------------------- /src/delphi_plus/evalute/select_declare_only.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | def eval_accept(row_accuracies, df_results): 6 | class_targets = df_results["freeform_class_targets"].tolist() 7 | class_preds = df_results["freeform_class_preds"].tolist() 8 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact")) 9 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary")) 10 | 11 | text_class_targets = df_results["moral_acceptability_text_2_class_targets"].tolist() 12 | text_class_preds = df_results["moral_acceptability_text_2_class_preds"].tolist() 13 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary")) 14 | 15 | text_targets = df_results["moral_acceptability_text_targets"].tolist() 16 | text_preds = df_results["moral_acceptability_text_preds"].tolist() 17 | exact_match_accuracy = get_moral_acceptability_text_exact_match_accuracy(text_targets, text_preds) 18 | return row_accuracies 19 | 20 | 21 | def 
eval_agree(row_accuracies, df_results): 22 | class_targets = df_results["yesno_class_targets"].tolist() 23 | class_preds = df_results["yesno_class_preds"].tolist() 24 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary")) 25 | 26 | text_targets = df_results["moral_agreement_text_targets"].tolist() 27 | text_preds = df_results["moral_agreement_text_preds"].tolist() 28 | exact_match_accuracy, polarity_align_accuracy = get_moral_agreement_text_accuracy(text_targets, text_preds) 29 | row_accuracies.append(polarity_align_accuracy) 30 | return row_accuracies 31 | 32 | 33 | 34 | def select_check_point(data_split, model_type, pt_model, bs, check_points): 35 | data_version = "v11" 36 | bucket_name = "ai2-tpu-europe-west4" 37 | lr = 0.0001 38 | 39 | client = storage.Client() 40 | bucket = client.get_bucket(bucket_name) 41 | 42 | print("model_type:", model_type) 43 | print("lr:", lr) 44 | print("bs:", bs) 45 | 46 | result_prefix = f"projects/liweij/mosaic-commonsense-morality/results/{data_version}/" \ 47 | f"{pt_model}/{model_type}/lr-{lr}_bs-{bs}/" \ 48 | f"freeform/{data_split}/" 49 | 50 | if check_points == None: 51 | check_points = get_result_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:] 52 | 53 | accuracies = [] 54 | for check_point in check_points: 55 | row_accuracies = [check_point] 56 | 57 | ##################### accept ##################### 58 | df_results = read_result_file(bucket_name, data_version, model_type, 59 | check_point, data_split, "freeform", lr, bs, pt_model) 60 | row_accuracies = eval_accept(row_accuracies, df_results) 61 | 62 | ##################### agree ##################### 63 | df_results = read_result_file(bucket_name, data_version, model_type, 64 | check_point, data_split, "yesno", lr, bs, pt_model) 65 | row_accuracies = eval_agree(row_accuracies, df_results) 66 | 67 | accuracies.append(row_accuracies) 68 | print("-- check point:", check_point, row_accuracies) 69 | 70 | df_to_save = pd.DataFrame(accuracies) 71 | df_to_save.to_csv("temp_result_file_2.csv", index=False) 72 | 73 | 74 | if __name__ == "__main__": 75 | model_type = "declare_only" 76 | pt_model = "unicorn-pt" 77 | bs = 16 78 | check_points = None 79 | select_check_point("validation", model_type, pt_model, bs, check_points) 80 | -------------------------------------------------------------------------------- /src/delphi_plus/evalute/select_distribution.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("script/evaluate") 3 | from evaluate_utils import * 4 | 5 | def eval_accept(row_accuracies, df_results): 6 | class_targets = df_results["moral_acceptability_class_targets"].tolist() 7 | class_preds = df_results["moral_acceptability_class_preds"].tolist() 8 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact")) 9 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary")) 10 | 11 | text_class_targets = df_results["moral_acceptability_text_2_class_targets"].tolist() 12 | text_class_preds = df_results["moral_acceptability_text_2_class_preds"].tolist() 13 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary")) 14 | 15 | text_targets = df_results["moral_acceptability_text_targets"].tolist() 16 | text_preds = df_results["moral_acceptability_text_preds"].tolist() 17 | exact_match_accuracy = get_moral_acceptability_text_exact_match_accuracy(text_targets, text_preds) 18 | return 
row_accuracies 19 | 20 | 21 | def eval_agree(row_accuracies, df_results): 22 | class_targets = df_results["moral_agreement_class_targets"].tolist() 23 | class_preds = df_results["moral_agreement_class_preds"].tolist() 24 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary")) 25 | 26 | text_targets = df_results["moral_agreement_text_targets"].tolist() 27 | text_preds = df_results["moral_agreement_text_preds"].tolist() 28 | exact_match_accuracy, polarity_align_accuracy = get_moral_agreement_text_accuracy(text_targets, text_preds) 29 | row_accuracies.append(polarity_align_accuracy) 30 | return row_accuracies 31 | 32 | 33 | def eval_compare(row_accuracies, df_results): 34 | class_targets = df_results["moral_comparison_class_targets"].tolist() 35 | class_preds = df_results["moral_comparison_class_preds"].tolist() 36 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact")) 37 | return row_accuracies 38 | 39 | 40 | def eval_wild(row_accuracies, df_results): 41 | df_results["wild_class_targets"] = df_results["wild_class_targets"] 42 | df_results["wild_class_preds"] = df_results["wild_class_preds"] 43 | class_targets = df_results["wild_class_targets"].tolist() 44 | class_preds = df_results["wild_class_preds"].tolist() 45 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="exact")) 46 | row_accuracies.append(get_accuracy(class_targets, class_preds, accuracy_type="binary")) 47 | 48 | text_class_targets = df_results["wild_v11_text_2_class_targets"].tolist() 49 | text_class_preds = df_results["wild_v11_text_2_class_preds"].tolist() 50 | 51 | row_accuracies.append(get_accuracy(text_class_targets, text_class_preds, accuracy_type="binary")) 52 | return row_accuracies 53 | 54 | 55 | def select_check_point(data_split, model_type, pt_model, bs, check_points): 56 | data_version = "v11" 57 | bucket_name = "ai2-tpu-europe-west4" 58 | lr = 0.0001 59 | 60 | client = storage.Client() 61 | bucket = client.get_bucket(bucket_name) 62 | 63 | print("model_type:", model_type) 64 | print("lr:", lr) 65 | print("bs:", bs) 66 | 67 | result_prefix = f"projects/liweij/mosaic-commonsense-morality/results/{data_version}/" \ 68 | f"{pt_model}/{model_type}/lr-{lr}_bs-{bs}/" \ 69 | f"moral_acceptability/{data_split}/" 70 | 71 | if check_points == None: 72 | check_points = get_result_check_points(client, bucket_name, result_prefix, after_check_point=-1)[2:] 73 | 74 | accuracies = [] 75 | for check_point in check_points: 76 | row_accuracies = [check_point] 77 | 78 | ##################### accept ##################### 79 | df_results = read_result_file(bucket_name, data_version, model_type, 80 | check_point, data_split, "moral_acceptability", lr, bs, pt_model) 81 | row_accuracies = eval_accept(row_accuracies, df_results) 82 | 83 | ##################### agree ##################### 84 | df_results = read_result_file(bucket_name, data_version, model_type, 85 | check_point, data_split, "moral_agreement", lr, bs, pt_model) 86 | row_accuracies = eval_agree(row_accuracies, df_results) 87 | 88 | ##################### compare ##################### 89 | df_results = read_result_file(bucket_name, data_version, model_type, 90 | check_point, data_split, "moral_comparison", lr, bs, pt_model) 91 | row_accuracies = eval_compare(row_accuracies, df_results) 92 | 93 | ##################### wild general ##################### 94 | df_results = read_result_file(bucket_name, data_version, model_type, 95 | check_point, data_split, "wild", lr, bs, pt_model) 96 | 
df_results = df_results[df_results["wild_class_targets"] < 90] 97 | row_accuracies = eval_wild(row_accuracies, df_results) 98 | 99 | 100 | accuracies.append(row_accuracies) 101 | print("-- check point:", check_point, row_accuracies) 102 | 103 | df_to_save = pd.DataFrame(accuracies) 104 | df_to_save.to_csv("temp_result_file_2.csv", index=False) 105 | 106 | 107 | if __name__ == "__main__": 108 | model_type = "distribution" 109 | pt_model = "unicorn-pt" 110 | bs = 16 111 | check_points = None 112 | select_check_point("validation", model_type, pt_model, bs, check_points) 113 | -------------------------------------------------------------------------------- /src/delphi_plus/train/evaluate.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """ 4 | Evaluate the model checkpoint 5 | """ 6 | 7 | import t5 8 | import os 9 | import sys 10 | import util 11 | import seqio 12 | import click 13 | import logging 14 | import tensorflow.compat.v1 as tf 15 | 16 | print("python", sys.version) 17 | print("t5", t5.__version__) 18 | print("tf", tf.__version__) 19 | print("seqio", seqio.__version__) 20 | 21 | tf.disable_v2_behavior() 22 | 23 | import tasks, mixtures 24 | # N.B. We must import tasks and mixtures here so that they are registered and available for evaluation. 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | @click.command() 29 | @click.argument("mixture", type=str) 30 | @click.argument("results_dir", type=str) 31 | @click.argument("tpu-name", type=str) # The name of the TPU. Defaults to the TPU_NAME environment variable. 32 | @click.argument("tpu-topology", type=str) # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable. 33 | @click.argument("split", type=str) 34 | @click.argument("checkpoint", type=int) 35 | @click.option( 36 | "--model-parallelism", 37 | type=int, 38 | default=8, 39 | help="The degree of model parallelism to use. Defaults to 8.", 40 | ) 41 | 42 | def evaluate( 43 | mixture: str, 44 | results_dir: str, 45 | split: str, 46 | checkpoint: int, 47 | model_parallelism: int, 48 | tpu_name: str, 49 | tpu_topology: str, 50 | ) -> None: 51 | """ 52 | Evaluate the model located at RESULTS_DIR on MIXTURE. 
53 | """ 54 | 55 | print(tpu_name) 56 | print(tpu_topology) 57 | # print(mixture) 58 | 59 | # Initialize arguments 60 | if tpu_topology == "v3-32": 61 | batch_size = 16 62 | model_parallelism = 8 63 | elif tpu_topology == "v3-8": 64 | batch_size = 8 65 | model_parallelism = 8 66 | else: 67 | print("ERROR: tpu_topology invalid") 68 | return 69 | 70 | # Validate arguments 71 | util.validate_path(results_dir) 72 | 73 | checkpoints = util.get_result_check_points(results_dir, split, "ethics_cm_converted_class_only") 74 | 75 | print("-" * 10, "checkpoints todo", "-" * 10) 76 | 77 | if checkpoint == 100: 78 | checkpoints_to_eval = None 79 | elif checkpoint == 0: 80 | checkpoints_to_eval = checkpoints 81 | else: 82 | checkpoints_to_eval = [checkpoint] 83 | print(checkpoints_to_eval) 84 | 85 | # Run evaluation 86 | model = t5.models.MtfModel( 87 | model_dir=results_dir, 88 | tpu=tpu_name, 89 | tpu_topology=tpu_topology, 90 | model_parallelism=model_parallelism, 91 | batch_size=batch_size, 92 | sequence_length={"inputs": 512, "targets": 128}, 93 | learning_rate_schedule=None, 94 | save_checkpoints_steps=5000, 95 | keep_checkpoint_max=None, 96 | iterations_per_loop=100, 97 | ) 98 | 99 | model.eval( 100 | mixture_or_task_name=mixture, 101 | checkpoint_steps=checkpoints_to_eval, 102 | split=split, 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | evaluate() 108 | -------------------------------------------------------------------------------- /src/delphi_plus/train/fine-tune.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """Fine-tune T5 based models.""" 4 | 5 | import t5 6 | import sys 7 | import seqio 8 | import click 9 | import logging 10 | import tensorflow.compat.v1 as tf 11 | 12 | print("python", sys.version) 13 | print("t5", t5.__version__) 14 | print("tf", tf.__version__) 15 | print("seqio", seqio.__version__) 16 | 17 | import util 18 | import warnings 19 | import tasks, mixtures 20 | # We must import tasks and mixtures here so that the tasks and mixtures are registered and available for training. 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | v=tf.compat.v1.logging.FATAL 25 | tf.compat.v1.logging.set_verbosity(v) 26 | tf.disable_v2_behavior() 27 | 28 | config = tf.ConfigProto() 29 | config.gpu_options.allow_growth = True 30 | session = tf.InteractiveSession(config=config) 31 | 32 | warnings.filterwarnings("ignore", category=DeprecationWarning) 33 | 34 | PRETRAINED_MODELS = { 35 | "small": ("gs://t5-data/pretrained_models/small/", -1), 36 | "base": ("gs://t5-data/pretrained_models/base/", -1), 37 | "large": ("gs://t5-data/pretrained_models/large/", -1), 38 | "3B": ("gs://t5-data/pretrained_models/3B/", -1), 39 | "11B": ("gs://t5-data/pretrained_models/11B/", -1), 40 | "unicorn-pt": ("gs://ai2-mosaic-public/projects/rainbow/v1.0/unicorns/lr-2e-3_batch-size-32/", -1), 41 | "v11-delphi-declare": ( 42 | "gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/model/v11/unicorn-pt/declare_only/lr-0.0001_bs-16/", 43 | 1218800), 44 | } 45 | 46 | @click.command() 47 | @click.argument("mixture", type=str) 48 | @click.argument("results_dir", type=str) 49 | @click.argument("tpu-name", type=str) # The name of the TPU. Defaults to the TPU_NAME environment variable. 50 | @click.argument("tpu-topology", type=str) # The topology of the TPU. Defaults to the TPU_TOPOLOGY environment variable. 
51 | @click.argument("pretrained-model", type=str)
52 | @click.option(
53 |     "--split",
54 |     type=str,
55 |     default="train",
56 |     help="The split on which to train. Defaults to 'train'.",
57 | )
58 | @click.option(
59 |     "--n-steps",
60 |     type=int,
61 |     default=600000,
62 |     help="The number of gradient updates. Defaults to 600,000.",
63 | )
64 | @click.option(
65 |     "--save-checkpoints-steps",
66 |     type=int,
67 |     default=5000,
68 |     help=(
69 |         "The number of steps to take before saving a checkpoint. Defaults to"
70 |         " 5000."
71 |     ),
72 | )
73 | @click.option(
74 |     "--n-checkpoints-to-keep",
75 |     type=int,
76 |     default=300,
77 |     help=(
78 |         "The number of checkpoints to keep during fine-tuning. Defaults"
79 |         " to 300."
80 |     ),
81 | )
82 | @click.option(
83 |     "--learning-rate",
84 |     type=float,
85 |     default=2e-4,
86 |     help="The learning rate to use for training. Defaults to 2e-4.",
87 | )
88 | @click.option(
89 |     "--continue_finetune",
90 |     type=bool,
91 |     default=True,
92 |     help="Whether to continue training from an existing checkpoint.",
93 | )
94 | 
95 | def fine_tune(
96 |     mixture: str,
97 |     results_dir: str,
98 |     split: str,
99 |     pretrained_model: str,
100 |     n_steps: int,
101 |     learning_rate: float,
102 |     save_checkpoints_steps: int,
103 |     n_checkpoints_to_keep: int,
104 |     tpu_name: str,
105 |     tpu_topology: str,
106 |     continue_finetune: bool,
107 | ) -> None:
108 |     """
109 |     Fine-tune the model on MIXTURE, writing results to RESULTS_DIR.
110 |     """
111 | 
112 |     # Initialize arguments
113 |     if tpu_topology == "v3-32":
114 |         batch_size = 16
115 |         model_parallelism = 32
116 |     elif tpu_topology == "v3-8":
117 |         batch_size = 8
118 |         model_parallelism = 8
119 |     else:
120 |         print("ERROR: tpu_topology invalid")
121 |         return
122 | 
123 |     pretrained_checkpoint_step = -1
124 | 
125 |     # Get result path given arguments
126 |     result_path = util.get_result_path(results_dir, pretrained_model, mixture, learning_rate, batch_size)
127 | 
128 |     # Validate path
129 |     util.validate_path(results_dir, pretrained_model, PRETRAINED_MODELS)
130 | 
131 |     # Process arguments
132 |     if pretrained_model in PRETRAINED_MODELS:
133 |         pretrained_model, pretrained_checkpoint_step = PRETRAINED_MODELS[pretrained_model]
134 | 
135 |     # If the training stops before finishing and we want to continue from the last checkpoint
136 |     if continue_finetune:
137 |         pretrained_model = result_path
138 | 
139 |     # Print arguments
140 |     util.print_arguments(result_path, results_dir, mixture, split, pretrained_model,
141 |                          pretrained_checkpoint_step, n_steps, batch_size, model_parallelism,
142 |                          save_checkpoints_steps, n_checkpoints_to_keep, learning_rate,
143 |                          tpu_name, tpu_topology, tasks, continue_finetune)
144 | 
145 |     # Run fine-tuning
146 |     model = t5.models.MtfModel(
147 |         model_dir=result_path,
148 |         tpu=tpu_name,
149 |         tpu_topology=tpu_topology,
150 |         model_parallelism=model_parallelism,
151 |         batch_size=batch_size,
152 |         sequence_length={"inputs": 512, "targets": 128},
153 |         learning_rate_schedule=learning_rate,
154 |         save_checkpoints_steps=save_checkpoints_steps,
155 |         keep_checkpoint_max=n_checkpoints_to_keep,
156 |         iterations_per_loop=100,
157 |     )
158 | 
159 |     model.finetune(
160 |         mixture_or_task_name=mixture,
161 |         pretrained_model_dir=pretrained_model,
162 |         pretrained_checkpoint_step=pretrained_checkpoint_step,
163 |         finetune_steps=n_steps,
164 |         split=split,
165 |     )
166 | 
167 | 
168 | if __name__ == "__main__":
169 |     fine_tune()
170 | 
--------------------------------------------------------------------------------
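For reference, a minimal illustrative invocation of this fine-tuning script; the results bucket and TPU name below are placeholders rather than paths from this repository, and the positional arguments follow the click declarations above in the order MIXTURE RESULTS_DIR TPU_NAME TPU_TOPOLOGY PRETRAINED_MODEL:

    python fine-tune.py declare_only gs://<your-bucket>/path/to/results <your-tpu-name> v3-32 unicorn-pt --n-steps 600000 --learning-rate 2e-4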
/src/delphi_plus/train/mixtures.py: -------------------------------------------------------------------------------- 1 | import os 2 | import t5 3 | import tasks 4 | import rates 5 | import seqio 6 | import functools 7 | 8 | import util 9 | 10 | # seqio.MixtureRegistry.add( 11 | # "sbic_commonsense_morality_joint_all_proportional", 12 | # ["sbic_moral_acceptability", 13 | # "sbic_moral_agreement", 14 | # "sbic_moral_comparison"], 15 | # default_rate=rates.MIXING_RATES["proportional"] 16 | # ) 17 | # util.print_mixture_examples("sbic_commonsense_morality_joint_all_proportional") 18 | # 19 | # ################## commonsense norm bank + wild ablations ################## 20 | # proportions = [10, 20, 40, 60, 80, 100] 21 | # for p in proportions: 22 | # seqio.MixtureRegistry.add( 23 | # f"sbic_commonsense_morality_joint_all_proportional_wild_{p}", 24 | # [ f"wild_train_{p}", 25 | # "sbic_moral_acceptability", 26 | # "sbic_moral_agreement", 27 | # "sbic_moral_comparison"], 28 | # default_rate=rates.MIXING_RATES["proportional"] 29 | # ) 30 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_{p}") 31 | # 32 | # 33 | # seqio.MixtureRegistry.add( 34 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100", 35 | # [ f"wild_train_woz_100", 36 | # "sbic_moral_acceptability", 37 | # "sbic_moral_agreement", 38 | # "sbic_moral_comparison"], 39 | # default_rate=rates.MIXING_RATES["proportional"] 40 | # ) 41 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100") 42 | # 43 | # 44 | # seqio.MixtureRegistry.add( 45 | # "sbic_commonsense_morality_joint_all_proportional_wild_woz_100_v1", 46 | # [ f"wild_train_woz_100", 47 | # "sbic_moral_acceptability", 48 | # "sbic_moral_agreement", 49 | # "sbic_moral_comparison"], 50 | # default_rate=rates.MIXING_RATES["proportional"] 51 | # ) 52 | # util.print_mixture_examples(f"sbic_commonsense_morality_joint_all_proportional_wild_woz_100_v1") 53 | # 54 | # 55 | # seqio.MixtureRegistry.add( 56 | # "wild_hard_test", 57 | # [ "race_test", 58 | # "gender_test"], 59 | # default_rate=rates.MIXING_RATES["proportional"] 60 | # ) 61 | # util.print_mixture_examples(f"wild_hard_test") 62 | # 63 | # seqio.MixtureRegistry.add( 64 | # "all", 65 | # [ f"wild_train_100", 66 | # "sbic_moral_acceptability", 67 | # "sbic_moral_agreement", 68 | # "sbic_moral_comparison", 69 | # "race_test", 70 | # "gender_test"], 71 | # default_rate=rates.MIXING_RATES["proportional"] 72 | # ) 73 | # util.print_mixture_examples(f"all") 74 | 75 | 76 | # seqio.MixtureRegistry.add( 77 | # "sbic_single_class_only", 78 | # [ "sbic_moral_acceptability_single_class_only", 79 | # "sbic_moral_agreement_single_class_only"], 80 | # default_rate=rates.MIXING_RATES["proportional"] 81 | # ) 82 | # util.print_mixture_examples(f"sbic_single_class_only") 83 | # 84 | 85 | # dynahate = True 86 | if tasks.dynahate: 87 | for t in tasks.dynahate_task_formats: 88 | seqio.MixtureRegistry.add( 89 | f"dynahate_all_{t}", 90 | [ f"dynahate_round_1_{t}", 91 | f"dynahate_round_2_{t}", 92 | f"dynahate_round_3_{t}", 93 | f"dynahate_round_4_{t}",], 94 | default_rate=rates.MIXING_RATES["proportional"] 95 | ) 96 | 97 | seqio.MixtureRegistry.add( 98 | f"dynahate_all_{t}_100_shot", 99 | [ f"dynahate_round_1_{t}_100_shot", 100 | f"dynahate_round_2_{t}_100_shot", 101 | f"dynahate_round_3_{t}_100_shot", 102 | f"dynahate_round_4_{t}_100_shot",], 103 | default_rate=rates.MIXING_RATES["proportional"] 104 | ) 105 | 106 | 107 | seqio.MixtureRegistry.add( 108 | 
"declare_only", 109 | [f"freeform", 110 | f"yesno"], 111 | default_rate=rates.MIXING_RATES["proportional"] 112 | ) 113 | 114 | seqio.MixtureRegistry.add( 115 | "norm_bank", 116 | ["moral_acceptability", 117 | "moral_agreement", 118 | "moral_comparison", 119 | ], 120 | default_rate=rates.MIXING_RATES["proportional"] 121 | ) 122 | 123 | 124 | # seqio.MixtureRegistry.add( 125 | # "distribution", 126 | # ["moral_acceptability", 127 | # "moral_agreement", 128 | # "moral_comparison", 129 | # "wild", 130 | # ], 131 | # default_rate=rates.MIXING_RATES["proportional"] 132 | # ) 133 | # 134 | # seqio.MixtureRegistry.add( 135 | # "all_distribution_wild", 136 | # ["wild", 137 | # "race_test", 138 | # "gender_test", 139 | # ], 140 | # default_rate=rates.MIXING_RATES["proportional"] 141 | # ) 142 | 143 | 144 | seqio.MixtureRegistry.add( 145 | "maj_vote", 146 | ["moral_acceptability", 147 | "moral_agreement", 148 | "moral_comparison", 149 | "wild", 150 | ], 151 | default_rate=rates.MIXING_RATES["proportional"] 152 | ) 153 | 154 | seqio.MixtureRegistry.add( 155 | "all_maj_vote_wild", 156 | ["wild", 157 | "race_test", 158 | "gender_test", 159 | ], 160 | default_rate=rates.MIXING_RATES["proportional"] 161 | ) 162 | -------------------------------------------------------------------------------- /src/delphi_plus/train/predict.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """Evaluate the model on the rainbow datasets.""" 4 | 5 | import t5 6 | import os 7 | import sys 8 | import seqio 9 | import logging 10 | import click 11 | import util 12 | import pandas as pd 13 | import tensorflow.compat.v1 as tf 14 | 15 | # Improve logging. 16 | from contextlib import contextmanager 17 | 18 | # print("python", sys.version) 19 | # print("t5", t5.__version__) 20 | # print("tf", tf.__version__) 21 | # print("seqio", seqio.__version__) 22 | 23 | tf.disable_v2_behavior() 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | def getSubstringBetweenMarkers(source_string, start_marker, end_marker): 28 | start = source_string.find(start_marker) + len(start_marker) 29 | end = source_string.find(end_marker) 30 | return source_string[start: end] 31 | 32 | 33 | @contextmanager 34 | def tf_verbosity_level(level): 35 | og_level = tf.logging.get_verbosity() 36 | tf.logging.set_verbosity(level) 37 | yield 38 | tf.logging.set_verbosity(og_level) 39 | 40 | 41 | @click.command() 42 | @click.option( 43 | "--batch-size", 44 | type=int, 45 | default=64, 46 | help=( 47 | "The batch size to use for prediction. For efficient prediction on the" 48 | " TPU, choose a multiple of either 8 or 128. Defaults to 64." 49 | ), 50 | ) 51 | @click.option( 52 | "--model-parallelism", 53 | type=int, 54 | default=8, 55 | help="The degree of model parallelism to use. Defaults to 8.", 56 | ) 57 | @click.option( 58 | "--tpu-name", 59 | type=str, 60 | default="de-tpu1", 61 | required=True, 62 | help="The name of the TPU. Defaults to the TPU_NAME environment variable.", 63 | ) 64 | @click.option( 65 | "--tpu-topology", 66 | type=str, 67 | default="v3-32", 68 | required=True, 69 | help=( 70 | "The topology of the TPU. Defaults to the TPU_TOPOLOGY environment" 71 | " variable." 
72 | ), 73 | ) 74 | def predict( 75 | batch_size: int, 76 | model_parallelism: int, 77 | tpu_name: str, 78 | tpu_topology: str, 79 | ) -> None: 80 | """Evaluate the model located at RESULTS_DIR on MIXTURE.""" 81 | 82 | eval_data = "race_topk_batch6to10" 83 | 84 | data_version = "v9" 85 | model_type = "sbic_commonsense_morality_joint_all_proportional" 86 | check_point = 1264700 87 | lr = 0.0001 88 | bs = 16 89 | bucket_name = "ai2-tpu-europe-west4" 90 | models_dir = f"gs://{bucket_name}/projects/liweij/mosaic-commonsense-morality/model/{data_version}/" \ 91 | f"unicorn-pt/{model_type}/lr-{lr}_bs-{bs}" 92 | training_type = model_type.split("_")[-3] 93 | 94 | # Run evaluation. 95 | model = t5.models.MtfModel( 96 | model_dir=models_dir, 97 | tpu=tpu_name, 98 | tpu_topology=tpu_topology, 99 | model_parallelism=model_parallelism, 100 | batch_size=batch_size, 101 | sequence_length={"inputs": 512, "targets": 128}, 102 | learning_rate_schedule=None, 103 | save_checkpoints_steps=5000, 104 | keep_checkpoint_max=None, 105 | iterations_per_loop=100, 106 | ) 107 | 108 | predict_joint_inputs_paths = ["gs://ai2-tpu-europe-west4/projects/liweij/mosaic-commonsense-morality/" \ 109 | f"data/qualitative_eval/{training_type}/" + eval_data + "_qualitative_eval.tsv"] 110 | predict_joint_outputs_paths = [ 111 | models_dir.replace("model", "preds") + "/raw/" + eval_data + "_qualitative_eval.tsv"] 112 | 113 | for i in range(len(predict_joint_inputs_paths)): 114 | predict_joint_inputs_path = predict_joint_inputs_paths[i] 115 | predict_joint_outputs_path = predict_joint_outputs_paths[i] 116 | 117 | # Ignore any logging so that we only see the model's answers to the questions. 118 | with tf_verbosity_level('ERROR'): 119 | model.batch_size = 8 # Min size for small model on v2-8 with parallelism 1. 120 | model.predict( 121 | input_file=predict_joint_inputs_path, 122 | output_file=predict_joint_outputs_path, 123 | # Select the most probable output token at each step. 124 | temperature=0, 125 | checkpoint_steps=check_point, 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | predict() 131 | -------------------------------------------------------------------------------- /src/delphi_plus/train/rates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixing rates 3 | """ 4 | 5 | import seqio 6 | 7 | def equal_rate(task: seqio.Task): 8 | """Mix the datasets in equal amounts. 9 | 10 | Parameters 11 | ---------- 12 | task : t5.data.Task 13 | The task. 14 | 15 | Returns 16 | ------- 17 | float 18 | The constant: ``1.0``. 19 | """ 20 | return 1.0 21 | 22 | 23 | def proportional_rate(task: seqio.Task): 24 | """Mix the datasets proportionally. 25 | 26 | Parameters 27 | ---------- 28 | task : t5.data.Task 29 | The task. 30 | 31 | Returns 32 | ------- 33 | float 34 | The number of examples in the task's training set. 
35 | """ 36 | return float(task.num_input_examples("train")) 37 | 38 | 39 | # constants 40 | 41 | MIXING_RATES = { 42 | "equal": equal_rate, 43 | "proportional": proportional_rate, 44 | } 45 | """A dictionary mapping mixing rates' names to their implementations.""" 46 | -------------------------------------------------------------------------------- /src/delphi_plus/train/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Util functions for fine-tuning and evaluating models 3 | """ 4 | import seqio 5 | import pandas as pd 6 | from google.cloud import storage 7 | import tensorflow_datasets as tfds 8 | import tensorflow as tf 9 | 10 | 11 | def create_folder(client, bucket, destination_folder_name): 12 | """ 13 | Create a folder in Google Cloud Storage if such folder doesn't exist already 14 | """ 15 | if not storage.Blob(bucket=bucket, name=destination_folder_name).exists(client): 16 | blob = bucket.blob(destination_folder_name) 17 | blob.upload_from_string('') 18 | print('Created: {}'.format(destination_folder_name)) 19 | else: 20 | print('Exists: {}'.format(destination_folder_name)) 21 | 22 | 23 | def print_task_examples(task_name, split="validation", num_ex=1): 24 | """ 25 | Print examples from tasks 26 | """ 27 | print("#" * 20, task_name, "#" * 20) 28 | task = seqio.TaskRegistry.get(task_name) 29 | ds = task.get_dataset(split=split, sequence_length={"inputs": 512, "targets": 128}) 30 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))): 31 | print(i, ex) 32 | print("test", task.num_input_examples("test")) 33 | print("train", task.num_input_examples("train")) 34 | print("validation", task.num_input_examples("validation")) 35 | 36 | 37 | def print_mixture_examples(mixture_name, split="validation", num_ex=1): 38 | """ 39 | Print examples from mixtures 40 | """ 41 | print("#" * 20, mixture_name, "#" * 20) 42 | mixture = seqio.MixtureRegistry.get(mixture_name) 43 | ds = mixture.get_dataset(split=split, 44 | sequence_length={"inputs": 512, "targets": 128}) 45 | 46 | for i, ex in enumerate(tfds.as_numpy(ds.take(num_ex))): 47 | print(i, ex) 48 | print("test", mixture.num_input_examples("test")) 49 | print("train", mixture.num_input_examples("train")) 50 | print("validation", mixture.num_input_examples("validation")) 51 | 52 | 53 | def get_num_elements_csv(file_name): 54 | """ 55 | Get the total number of elements in a given csv/tsv file 56 | """ 57 | df = pd.read_csv(file_name, delimiter="\t") 58 | return df.shape[0] 59 | 60 | 61 | def get_num_elements_split(split_paths): 62 | """ 63 | Get the number of elements in each split of a dataset 64 | """ 65 | num_elements_split = {} 66 | for split, path in split_paths.items(): 67 | num_elements_split[split] = get_num_elements_csv(path) 68 | return num_elements_split 69 | 70 | 71 | def get_result_check_points(result_prefix, split, eval_data_type, after_check_point=-1): 72 | """ 73 | Get a list of model checkpoints that haven't generated on the designated data split yet 74 | """ 75 | client = storage.Client() 76 | bucket_name = "ai2-tpu-europe-west4" 77 | result_prefix = result_prefix.split(bucket_name + "/")[-1] + "/" 78 | 79 | check_points = [] 80 | done_check_points = [] 81 | for blob in client.list_blobs(bucket_name, prefix=result_prefix): 82 | blob_name = str(blob).split("/")[-1] 83 | if ".meta" in blob_name: 84 | check_point = int(blob_name.split(".meta")[0].split("-")[-1]) 85 | if check_point > after_check_point: 86 | check_points.append(check_point) 87 | 88 | print("-" * 10, "checkpoints 
all", "-" * 10) 89 | print(check_points) 90 | 91 | for blob in client.list_blobs(bucket_name, prefix=result_prefix + f"{split}_eval/"): 92 | blob_name = str(blob).split("/")[-1] 93 | if "_predictions" in blob_name and eval_data_type in blob_name and "_predictions_clean" not in blob_name: 94 | check_point_done = int(blob_name.split("_predictions")[0].split("_")[-1]) 95 | # check_point_done = int(blob_name.split("_")[0].split("_")[-1]) 96 | if check_point_done in check_points: 97 | done_check_points.append(check_point_done) 98 | check_points.remove(check_point_done) 99 | 100 | print("-" * 10, "checkpoints done", "-" * 10) 101 | print(done_check_points) 102 | return check_points 103 | 104 | 105 | def validate_path(results_dir, pretrained_model=None, PRETRAINED_MODELS=None): 106 | """ 107 | Validate result path 108 | """ 109 | if PRETRAINED_MODELS != None: 110 | if not results_dir.startswith("gs://"): 111 | raise ValueError(f"RESULTS_DIR ({results_dir}) must be a GCS path.") 112 | 113 | if pretrained_model.startswith("gs://"): 114 | if not tf.io.gfile.exists(pretrained_model): 115 | raise IOError( 116 | f"--pretrained-model ({pretrained_model}) does not exist." 117 | ) 118 | else: 119 | if pretrained_model not in PRETRAINED_MODELS: 120 | raise ValueError( 121 | f"--pretrained-model ({pretrained_model}) not recognized. It" 122 | f" must either be a GCS path or one of" 123 | f' {", ".join(PRETRAINED_MODELS.keys())}.') 124 | else: 125 | if not results_dir.startswith("gs://"): 126 | raise ValueError(f"RESULTS_DIR ({results_dir}) must be a GCS path.") 127 | elif not tf.io.gfile.exists(results_dir): 128 | raise IOError(f"RESULTS_DIR ({results_dir}) doesn't exist.") 129 | 130 | 131 | def print_arguments(result_path, results_dir, mixture, split, pretrained_model, 132 | pretrained_checkpoint_step, n_steps, batch_size, model_parallelism, 133 | save_checkpoints_steps, n_checkpoints_to_keep, learning_rate, 134 | tpu_name, tpu_topology, tasks, continue_finetune): 135 | print("=" * 10, "results_dir") 136 | print(results_dir) 137 | 138 | print("=" * 10, "mixture") 139 | print(mixture) 140 | 141 | print("=" * 10, "split") 142 | print(split) 143 | 144 | print("=" * 10, "pretrained_model") 145 | print(pretrained_model) 146 | 147 | print("=" * 10, "pretrained_checkpoint_step") 148 | print(pretrained_checkpoint_step) 149 | 150 | print("=" * 10, "n_steps") 151 | print(n_steps) 152 | 153 | print("=" * 10, "batch_size") 154 | print(batch_size) 155 | 156 | print("=" * 10, "model_parallelism") 157 | print(model_parallelism) 158 | 159 | print("=" * 10, "save_checkpoints_steps") 160 | print(save_checkpoints_steps) 161 | 162 | print("=" * 10, "n_checkpoints_to_keep") 163 | print(n_checkpoints_to_keep) 164 | 165 | print("=" * 10, "learning_rate") 166 | print(learning_rate) 167 | 168 | print("=" * 10, "tpu_name") 169 | print(tpu_name) 170 | 171 | print("=" * 10, "tpu_topology") 172 | print(tpu_topology) 173 | 174 | print("=" * 10, "result_path") 175 | print(result_path) 176 | 177 | print("=" * 10, "data_version") 178 | print(tasks.data_version) 179 | 180 | print("=" * 10, "continue_finetune") 181 | print(continue_finetune) 182 | 183 | 184 | def get_result_path( 185 | results_dir: str, 186 | pretrained_model: str, 187 | mixture: str, 188 | learning_rate: float, 189 | batch_size: int 190 | ) -> str: 191 | """ 192 | Get a result path given arguments 193 | """ 194 | result_path = results_dir + \ 195 | "/" + pretrained_model + \ 196 | "/" + mixture + \ 197 | f"/lr-{learning_rate}_bs-{batch_size}" 198 | return result_path 199 | 
200 | --------------------------------------------------------------------------------