├── .gitignore
├── README.md
├── banner.png
├── data
│   └── README.md
├── figures
│   ├── .placeholder
│   ├── edit_distance
│   │   ├── k_vs_ed.pdf
│   │   ├── legend.pdf
│   │   ├── p_vs_ed.pdf
│   │   ├── tau_vs_ed.pdf
│   │   └── tau_vs_ed_categorical.pdf
│   ├── llama3
│   │   ├── bias
│   │   │   ├── effects_Female_income_hist_median.pdf
│   │   │   ├── llama3_effects_Female_income_hist_median.pdf
│   │   │   ├── llama3_effects_Male_income_hist_median.pdf
│   │   │   ├── median_sex_income_direct_cf.pdf
│   │   │   ├── median_sex_income_total_cf.pdf
│   │   │   └── race_education_total.pdf
│   │   └── occupation
│   │       └── occupation_factual_Asian_American_cf_Black_or_African_American.pdf
│   └── mistral
│       ├── bias
│       │   ├── median_sex_income_direct_cf.pdf
│       │   ├── median_sex_income_total_cf.pdf
│       │   ├── mistral_effects_Female_income_hist_median.pdf
│       │   ├── mistral_effects_Male_income_hist_median.pdf
│       │   └── race_education_total.pdf
│       └── occupation
│           └── occupation_factual_Asian_American_cf_Black_or_African_American.pdf
├── notebooks
│   ├── bias.ipynb
│   └── edit_distance.ipynb
├── outputs
│   ├── .placeholder
│   ├── llama3
│   │   ├── bias
│   │   │   ├── direct_Sex_cf_changes.json
│   │   │   ├── total_Race_cf_changes.json
│   │   │   └── total_Sex_cf_changes.json
│   │   ├── census1
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   ├── census2
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   └── census3
│   │       ├── factual.json
│   │       ├── intervention_1.json
│   │       └── rngstates_1.pt
│   ├── mistral
│   │   ├── bias
│   │   │   ├── direct_Sex_cf_changes.json
│   │   │   ├── total_Race_cf_changes.json
│   │   │   └── total_Sex_cf_changes.json
│   │   ├── census1
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   ├── census2
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   └── census3
│   │       ├── factual.json
│   │       ├── intervention_1.json
│   │       └── rngstates_1.pt
│   ├── story-edit-word
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-interventional
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-name
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-remove-word
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   └── story-ship
│       ├── counterfactual_1.json
│       ├── factual.json
│       ├── intervention_1.json
│       ├── intervention_2.json
│       └── rngstates_1.pt
├── requirements.txt
├── scripts
│   ├── bias.py
│   ├── census_queries.py
│   ├── cf_query.py
│   ├── stability.py
│   └── story_query.py
└── src
    ├── bias.py
    ├── cf_query.py
    ├── llama3
    │   ├── CODE_OF_CONDUCT.md
    │   ├── LICENSE
    │   ├── MODEL_CARD.md
    │   ├── README.md
    │   ├── USE_POLICY
    │   ├── USE_POLICY.md
    │   ├── llama
    │   │   ├── __init__.py
    │   │   ├── generation.py
    │   │   ├── model.py
    │   │   ├── sampler.py
    │   │   ├── test_tokenizer.py
    │   │   └── tokenizer.py
    │   ├── pretrained
    │   │   └── .placeholder
    │   ├── requirements.txt
    │   └── setup.py
    ├── mistral-inference
    │   ├── LICENSE
    │   ├── README.md
    │   ├── assets
    │   │   └── smoe.png
    │   ├── deploy
    │   │   ├── .dockerignore
    │   │   ├── Dockerfile
    │   │   └── entrypoint.sh
    │   ├── moe_one_file_ref.py
    │   ├── one_file_ref.py
    │   ├── poetry.lock
    │   ├── pyproject.toml
    │   ├── src
    │   │   └── mistral_inference
    │   │       ├── __init__.py
    │   │       ├── args.py
    │   │       ├── cache.py
    │   │       ├── generate.py
    │   │       ├── lora.py
    │   │       ├── main.py
    │   │       ├── mamba.py
    │   │       ├── model.py
    │   │       ├── moe.py
    │   │       ├── rope.py
    │   │       ├── transformer.py
    │   │       ├── transformer_layers.py
    │   │       └── vision_encoder.py
    │   ├── tests
    │   │   └── test_generate.py
    │   └── tutorials
    │       ├── classifier.ipynb
    │       └── getting_started.ipynb
    ├── sampler.py
    ├── single_query.py
    ├── stability.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# Ignore the outputs
outputs/*json

# Ignore the figures
figures/*pdf
figures/*png
figures/*tif

# Ignore the backups
backups*

# Ignore macOS things
**.DS_Store

# Ignore downloaded model weights under src/llama3/pretrained
src/llama3/pretrained/Meta*

# Ignore all json files in data/ and its subdirectories
data/**/*.json

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Counterfactual Token Generation in Large Language Models

This repository contains the code used in the paper ["Counterfactual Token Generation in Large Language Models"](https://arxiv.org/abs/2409.17027) by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez-Rodriguez.

__Contents__:
- [Introduction](#introduction)
- [Dependencies](#dependencies)
- [Repository structure](#repository-structure)
- [Contact & attribution](#contact--attribution)

## Introduction
![banner](banner.png)

*"Sure, I am happy to generate a story for you: Captain Lyra stood at the helm of her trusty ship, the Maelstrom's Fury, gazing out at the endless sea. \[...\] Lyra's eyes welled up with tears as she realized the bitter truth - she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself."* Although this story, generated by a large language model, is captivating, one may wonder: how would the story have unfolded if the model had chosen "Captain Maeve" as the protagonist instead? We cannot know. State-of-the-art large language models are stateless: they maintain no internal memory or state. Given a prompt, they generate a sequence of tokens as an output using an autoregressive process. As a consequence, they cannot reason about counterfactual alternatives to tokens they have generated in the past. In this work, our goal is to enhance them with this functionality. To this end, we develop a causal model of token generation that builds upon the Gumbel-Max structural causal model. Our model allows any large language model to perform counterfactual token generation at almost no cost in comparison with vanilla token generation, is embarrassingly simple to implement, and requires neither fine-tuning nor prompt engineering. We implement our model on Llama 3 8B-Instruct and Ministral-8B-Instruct, and conduct both qualitative and quantitative analyses of counterfactually generated text. We conclude with a demonstrative application of counterfactual token generation for bias detection, unveiling interesting insights about the model of the world constructed by large language models.

## Dependencies

All experiments were performed using Python 3.11.2. To create a virtual environment and install the project dependencies, run the following commands:

```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```

Our code builds upon the popular open-weight large language models Llama 3 8B-Instruct and Ministral-8B-Instruct. For instructions on getting access to the weights of the Llama 3 model, refer to the [Llama 3 GitHub repository](https://github.com/meta-llama/llama3); for the weights of the Ministral model, refer to the Mistral [online documentation](https://docs.mistral.ai/getting-started/models/weights/). In addition to the project dependencies of our own code mentioned above, make sure to first install the project dependencies of the Ministral and Llama 3 models.

## Repository structure

```
├── data
├── figures
├── notebooks
├── outputs
│   ├── llama3
│   │   ├── census*
│   │   └── bias
│   ├── mistral
│   │   ├── census*
│   │   └── bias
│   └── story*
├── scripts
│   ├── bias.py
│   ├── census_queries.py
│   ├── cf_query.py
│   ├── stability.py
│   └── story_query.py
└── src
    ├── llama3
    │   ├── llama
    │   │   ├── generation.py
    │   │   └── sampler.py
    │   └── pretrained
    ├── mistral-inference
    │   └── src/mistral_inference
    │       ├── generate.py
    │       └── 8B-Instruct
    ├── bias.py
    ├── cf_query.py
    ├── sampler.py
    ├── single_query.py
    ├── stability.py
    └── utils.py
```

- `data` contains configuration files for our experiments.
- `figures` contains all the figures presented in the paper.
- `notebooks` contains Python notebooks that generate all the figures included in the paper.
- `outputs/<llm>`, where `<llm>` is either `llama3` or `mistral`, contains intermediate output files generated by the experiments' scripts. Specifically:
  - `bias` contains the counterfactual census data of Section 4.3.
  - `census*` directories contain the factual census data of Section 4.3.
- `outputs/story*` directories contain the results of Section 4.1 and Appendix A.
- `scripts` contains a set of scripts used to run all the experiments presented in the paper.
- `src` contains all the code necessary to reproduce the results in the paper. Specifically:
  - `llama3` contains the code of the Llama 3 8B-Instruct model. Therein:
    - `llama/generation.py` uses the LLM to perform factual/counterfactual token generation.
    - `pretrained/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
  - `mistral-inference` contains the code of the Ministral-8B-Instruct model. Therein:
    - `src/mistral_inference/generate.py` uses the LLM to perform factual/counterfactual token generation.
    - `src/mistral_inference/8B-Instruct/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
  - `sampler.py` samples from a token distribution using a Gumbel-Max SCM or its top-p and top-k variants (see the sketch after this list).
  - `bias.py` performs counterfactual and interventional token generation for the experiments of Section 4.3 using the LLM-generated census data.
  - `cf_query.py` performs counterfactual token generation for a single query.
  - `single_query.py` performs factual token generation for a single query. It creates and saves its results in a subdirectory of `outputs/<llm>`, where `<llm>` is either `llama3` or `mistral`. The results are then used by `cf_query.py`.
  - `stability.py` performs interventional and counterfactual token generation for the experiments of Section 4.2.
  - `utils.py` contains auxiliary functions for plotting.
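
To make the mechanism concrete, here is a minimal, self-contained sketch of the basic Gumbel-Max sampling step with stored noise, which is what makes counterfactual token generation cheap. This is an illustrative sketch only, not the repository's `sampler.py`: the function name `gumbel_max_sample`, the toy vocabulary size, and the random logits are assumptions made for this example.

```python
# Illustrative sketch (not the repository's sampler.py): Gumbel-Max sampling
# with stored noise, the mechanism behind counterfactual token generation.
import torch

def gumbel_max_sample(logits: torch.Tensor, gumbel_noise: torch.Tensor) -> int:
    # Gumbel-Max trick: argmax(logits + Gumbel(0,1) noise) is an exact sample
    # from the softmax distribution defined by the logits.
    return int(torch.argmax(logits + gumbel_noise).item())

torch.manual_seed(2)
vocab_size = 8  # toy vocabulary, for illustration only

# One Gumbel(0,1) noise vector per generation step; persisting this noise
# (or the RNG state that produced it) fixes the randomness of the step.
noise = -torch.log(-torch.log(torch.rand(vocab_size)))

factual_logits = torch.randn(vocab_size)  # logits given the factual prefix
factual_token = gumbel_max_sample(factual_logits, noise)

# After intervening on earlier tokens, the model produces different logits;
# reusing the SAME stored noise yields the counterfactual token.
counterfactual_logits = factual_logits + 0.5 * torch.randn(vocab_size)
counterfactual_token = gumbel_max_sample(counterfactual_logits, noise)

print(factual_token, counterfactual_token)
```

Because the stored noise pins down the randomness of every step, counterfactual generation amounts to one extra forward pass with the noise replayed, which is why it costs almost nothing compared to vanilla generation.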

## Contact & attribution

If you have questions about the code, spot a potential bug, or would like us to include additional functionality, feel free to open an issue or contact [Ivi Chatzi](mailto:ichatzi@mpi-sws.org) or [Stratis Tsirtsis](mailto:stsirtsis@mpi-sws.org).

If you use parts of the code in this repository for your own research, please consider citing:

    @article{chatzi2024counterfactual,
      title={Counterfactual Token Generation in Large Language Models},
      author={Ivi Chatzi and Nina Corvelo Benz and Eleni Straitouri and Stratis Tsirtsis and Manuel Gomez-Rodriguez},
      year={2024},
      journal={arXiv preprint arXiv:2409.17027}
    }

--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/banner.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
## Dataset information

We used questions from the [LMSYS Chat 1M dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) for our text similarity experiment in Section 4.2.
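
One plausible way to pull these questions is through the Hugging Face `datasets` library. The snippet below is an illustrative sketch, not code from this repository: it assumes the `datasets` package is installed, that you have accepted the dataset's license on the Hub and authenticated (the dataset is gated), and that the `conversation` field follows the role/content schema documented on the dataset page.

```python
# Illustrative sketch (not part of this repository): stream user questions
# from the LMSYS Chat 1M dataset via Hugging Face `datasets`.
from datasets import load_dataset

# Gated dataset: accept the license on the Hub and run `huggingface-cli login` first.
ds = load_dataset("lmsys/lmsys-chat-1m", split="train", streaming=True)

for i, record in enumerate(ds):
    first_turn = record["conversation"][0]  # list of {"role", "content"} turns
    if first_turn["role"] == "user":
        print(first_turn["content"][:80])   # print a truncated user question
    if i == 4:                              # only peek at a few records
        break
```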
-------------------------------------------------------------------------------- /figures/.placeholder: -------------------------------------------------------------------------------- 1 | # Placeholder for figures -------------------------------------------------------------------------------- /figures/edit_distance/k_vs_ed.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/k_vs_ed.pdf -------------------------------------------------------------------------------- /figures/edit_distance/legend.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/legend.pdf -------------------------------------------------------------------------------- /figures/edit_distance/p_vs_ed.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/p_vs_ed.pdf -------------------------------------------------------------------------------- /figures/edit_distance/tau_vs_ed.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/tau_vs_ed.pdf -------------------------------------------------------------------------------- /figures/edit_distance/tau_vs_ed_categorical.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/tau_vs_ed_categorical.pdf -------------------------------------------------------------------------------- /figures/llama3/bias/effects_Female_income_hist_median.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/effects_Female_income_hist_median.pdf -------------------------------------------------------------------------------- /figures/llama3/bias/llama3_effects_Female_income_hist_median.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/llama3_effects_Female_income_hist_median.pdf -------------------------------------------------------------------------------- /figures/llama3/bias/llama3_effects_Male_income_hist_median.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/llama3_effects_Male_income_hist_median.pdf -------------------------------------------------------------------------------- /figures/llama3/bias/median_sex_income_direct_cf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/median_sex_income_direct_cf.pdf 
-------------------------------------------------------------------------------- /figures/llama3/bias/median_sex_income_total_cf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/median_sex_income_total_cf.pdf -------------------------------------------------------------------------------- /figures/llama3/bias/race_education_total.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/race_education_total.pdf -------------------------------------------------------------------------------- /figures/llama3/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf -------------------------------------------------------------------------------- /figures/mistral/bias/median_sex_income_direct_cf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/median_sex_income_direct_cf.pdf -------------------------------------------------------------------------------- /figures/mistral/bias/median_sex_income_total_cf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/median_sex_income_total_cf.pdf -------------------------------------------------------------------------------- /figures/mistral/bias/mistral_effects_Female_income_hist_median.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/mistral_effects_Female_income_hist_median.pdf -------------------------------------------------------------------------------- /figures/mistral/bias/mistral_effects_Male_income_hist_median.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/mistral_effects_Male_income_hist_median.pdf -------------------------------------------------------------------------------- /figures/mistral/bias/race_education_total.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/race_education_total.pdf -------------------------------------------------------------------------------- /figures/mistral/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf -------------------------------------------------------------------------------- /outputs/.placeholder: -------------------------------------------------------------------------------- 1 | # Placeholder for outputs -------------------------------------------------------------------------------- /outputs/llama3/census1/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census1/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/llama3/census2/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census2/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/llama3/census3/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census3/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/mistral/census1/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census1/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/mistral/census2/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census2/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/mistral/census3/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census3/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/story-edit-word/counterfactual_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue. For years, she'd sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she anchored the ship and stepped onto the sandy beach, a warm light enveloped her. With trembling hands, she reached out and touched the glittering treasure. 
The Golden Isle's magic surged through her, and Lyra's greatest desire took hold: to reunite with her lost love, who'd perished at sea many years ago.\n\nWith tears of joy streaming down her face, Lyra opened her eyes to find her beloved standing before her, alive and well. The two shared a long, tender kiss, surrounded by the sweet scent of tropical flowers. The \"Maelstrom's Fury\" sailed away, its crew rejoicing, as Lyra and her love disappeared into the sunset, their hearts full of happiness.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " blue", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": "'d", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " anchored", 87 | "75": " the", 88 | "76": " ship", 89 | "77": " and", 90 | "78": " stepped", 91 | "79": " onto", 92 | "80": " the", 93 | "81": " sandy", 94 | "82": " beach", 95 | "83": ",", 96 | "84": " a", 97 | "85": " warm", 98 | "86": " light", 99 | "87": " envelop", 100 | "88": "ed", 101 | "89": " her", 102 | "90": ".", 103 | "91": " With", 104 | "92": " trembling", 105 | "93": " hands", 106 | "94": ",", 107 | "95": " she", 108 | "96": " reached", 109 | "97": " out", 110 | "98": " and", 111 | "99": " touched", 112 | "100": " the", 113 | "101": " glitter", 114 | "102": "ing", 115 | "103": " treasure", 116 | "104": ".", 117 | "105": " The", 118 | "106": " Golden", 119 | "107": " Isle", 120 | "108": "'s", 121 | "109": " magic", 122 | "110": " surged", 123 | "111": " through", 124 | "112": " her", 125 | "113": ",", 126 | "114": " and", 127 | "115": " Ly", 128 | "116": "ra", 129 | "117": "'s", 130 | "118": " greatest", 131 | "119": " desire", 132 | "120": " took", 133 | "121": " hold", 134 | "122": ":", 135 | "123": " to", 136 | "124": " reun", 137 | "125": "ite", 138 | "126": " with", 139 | "127": " her", 140 | "128": " lost", 141 | "129": " love", 142 | "130": ",", 143 | "131": " who", 144 | "132": "'d", 145 | "133": " per", 146 | "134": "ished", 147 | "135": " at", 148 | "136": " sea", 149 | "137": " many", 150 | "138": " years", 151 | "139": " ago", 152 | "140": ".\n\n", 153 | "141": "With", 154 | "142": " tears", 155 | "143": " of", 156 | 
"144": " joy", 157 | "145": " streaming", 158 | "146": " down", 159 | "147": " her", 160 | "148": " face", 161 | "149": ",", 162 | "150": " Ly", 163 | "151": "ra", 164 | "152": " opened", 165 | "153": " her", 166 | "154": " eyes", 167 | "155": " to", 168 | "156": " find", 169 | "157": " her", 170 | "158": " beloved", 171 | "159": " standing", 172 | "160": " before", 173 | "161": " her", 174 | "162": ",", 175 | "163": " alive", 176 | "164": " and", 177 | "165": " well", 178 | "166": ".", 179 | "167": " The", 180 | "168": " two", 181 | "169": " shared", 182 | "170": " a", 183 | "171": " long", 184 | "172": ",", 185 | "173": " tender", 186 | "174": " kiss", 187 | "175": ",", 188 | "176": " surrounded", 189 | "177": " by", 190 | "178": " the", 191 | "179": " sweet", 192 | "180": " scent", 193 | "181": " of", 194 | "182": " tropical", 195 | "183": " flowers", 196 | "184": ".", 197 | "185": " The", 198 | "186": " \"", 199 | "187": "M", 200 | "188": "ael", 201 | "189": "strom", 202 | "190": "'s", 203 | "191": " Fury", 204 | "192": "\"", 205 | "193": " sailed", 206 | "194": " away", 207 | "195": ",", 208 | "196": " its", 209 | "197": " crew", 210 | "198": " rejo", 211 | "199": "icing", 212 | "200": ",", 213 | "201": " as", 214 | "202": " Ly", 215 | "203": "ra", 216 | "204": " and", 217 | "205": " her", 218 | "206": " love", 219 | "207": " disappeared", 220 | "208": " into", 221 | "209": " the", 222 | "210": " sunset", 223 | "211": ",", 224 | "212": " their", 225 | "213": " hearts", 226 | "214": " full", 227 | "215": " of", 228 | "216": " happiness", 229 | "217": "." 230 | } 231 | } -------------------------------------------------------------------------------- /outputs/story-edit-word/factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5 10 | } -------------------------------------------------------------------------------- /outputs/story-edit-word/intervention_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. 
The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5, 10 | "start_from": 29, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " and", 111 | "99": " jewels", 112 | "100": " lay", 113 | "101": " at", 114 | "102": " her", 115 | "103": " feet", 116 | "104": ".", 117 | "105": " Ly", 118 | "106": "ra", 119 | "107": "'s", 120 | "108": " heart", 121 | "109": " soared", 122 | "110": " as", 123 | "111": " she", 124 | "112": " reached", 125 | "113": " for", 126 | "114": " the", 127 | "115": " treasure", 128 | "116": ".", 129 | "117": " But", 130 | "118": ",", 131 | "119": " as", 132 | "120": " she", 133 | "121": " touched", 134 | "122": " the", 135 | "123": " gold", 136 | "124": ",", 137 | "125": " the", 138 | "126": " is", 139 | "127": "le", 140 | "128": " began", 141 | "129": " to", 142 | "130": " fade", 143 | "131": " away", 144 | "132": ",", 145 | "133": " leaving", 146 | "134": " her", 147 | "135": " with", 148 | "136": " nothing", 149 | "137": " but", 150 | "138": " a", 151 | "139": " whispered", 152 | "140": " promise", 153 | "141": ":", 154 | "142": " \"", 155 | "143": "The", 156 | "144": " true", 157 | "145": " treasure", 158 | "146": " was", 159 | "147": " the", 160 | "148": " journey", 161 | "149": ",", 162 | "150": " not", 163 | "151": " the", 164 | "152": " gold", 165 | "153": ".\"", 166 | "154": " Ly", 167 | "155": "ra", 168 | "156": "'s", 169 | "157": " eyes", 170 | "158": " wel", 171 | "159": "led", 
172 | "160": " up", 173 | "161": " with", 174 | "162": " tears", 175 | "163": " as", 176 | "164": " she", 177 | "165": " realized", 178 | "166": " the", 179 | "167": " bitter", 180 | "168": " truth", 181 | "169": " \u2013", 182 | "170": " she", 183 | "171": " had", 184 | "172": " sacrificed", 185 | "173": " everything", 186 | "174": " for", 187 | "175": " fleeting", 188 | "176": " riches", 189 | "177": ",", 190 | "178": " and", 191 | "179": " lost", 192 | "180": " the", 193 | "181": " love", 194 | "182": " of", 195 | "183": " her", 196 | "184": " crew", 197 | "185": ",", 198 | "186": " her", 199 | "187": " family", 200 | "188": ",", 201 | "189": " and", 202 | "190": " herself", 203 | "191": ".", 204 | "192": " The", 205 | "193": " \"", 206 | "194": "M", 207 | "195": "ael", 208 | "196": "strom", 209 | "197": "'s", 210 | "198": " Fury", 211 | "199": "\"", 212 | "200": " sailed", 213 | "201": " away", 214 | "202": ",", 215 | "203": " a", 216 | "204": " lonely", 217 | "205": " shadow", 218 | "206": " on", 219 | "207": " the", 220 | "208": " vast", 221 | "209": " ocean", 222 | "210": ",", 223 | "211": " carrying", 224 | "212": " a", 225 | "213": " captain", 226 | "214": " forever", 227 | "215": " haunted", 228 | "216": " by", 229 | "217": " the", 230 | "218": " what", 231 | "219": "-", 232 | "220": "ifs", 233 | "221": "." 234 | } 235 | } -------------------------------------------------------------------------------- /outputs/story-edit-word/intervention_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue. For years, she'd sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she anchored the ship and stepped onto the sandy beach, a warm light enveloped her. With trembling hands, she reached out and touched the glittering treasure. The Golden Isle's magic surged through her, and Lyra's greatest desire took hold: to reunite with her lost love, who'd perished at sea many years ago.\n\nWith tears of joy streaming down her face, Lyra opened her eyes to find her beloved standing before her, alive and well. The two shared a long, tender kiss, surrounded by the sweet scent of tropical flowers. 
The \"Maelstrom's Fury\" sailed away, its crew rejoicing, as Lyra and her love disappeared into the sunset, their hearts full of happiness.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " blue", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": "'d", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " anchored", 87 | "75": " the", 88 | "76": " ship", 89 | "77": " and", 90 | "78": " stepped", 91 | "79": " onto", 92 | "80": " the", 93 | "81": " sandy", 94 | "82": " beach", 95 | "83": ",", 96 | "84": " a", 97 | "85": " warm", 98 | "86": " light", 99 | "87": " envelop", 100 | "88": "ed", 101 | "89": " her", 102 | "90": ".", 103 | "91": " With", 104 | "92": " trembling", 105 | "93": " hands", 106 | "94": ",", 107 | "95": " she", 108 | "96": " reached", 109 | "97": " out", 110 | "98": " and", 111 | "99": " touched", 112 | "100": " the", 113 | "101": " glitter", 114 | "102": "ing", 115 | "103": " treasure", 116 | "104": ".", 117 | "105": " The", 118 | "106": " Golden", 119 | "107": " Isle", 120 | "108": "'s", 121 | "109": " magic", 122 | "110": " surged", 123 | "111": " through", 124 | "112": " her", 125 | "113": ",", 126 | "114": " and", 127 | "115": " Ly", 128 | "116": "ra", 129 | "117": "'s", 130 | "118": " greatest", 131 | "119": " desire", 132 | "120": " took", 133 | "121": " hold", 134 | "122": ":", 135 | "123": " to", 136 | "124": " reun", 137 | "125": "ite", 138 | "126": " with", 139 | "127": " her", 140 | "128": " lost", 141 | "129": " love", 142 | "130": ",", 143 | "131": " who", 144 | "132": "'d", 145 | "133": " per", 146 | "134": "ished", 147 | "135": " at", 148 | "136": " sea", 149 | "137": " many", 150 | "138": " years", 151 | "139": " ago", 152 | "140": ".\n\n", 153 | "141": "With", 154 | "142": " tears", 155 | "143": " of", 156 | "144": " joy", 157 | "145": " streaming", 158 | "146": " down", 159 | "147": " her", 160 | "148": " face", 161 | "149": ",", 162 | "150": " Ly", 163 | "151": "ra", 164 | "152": " opened", 165 | "153": " her", 166 | "154": " eyes", 167 | "155": " to", 168 | "156": " find", 169 | "157": " her", 170 | "158": " beloved", 171 | "159": " standing", 172 | "160": " 
before", 173 | "161": " her", 174 | "162": ",", 175 | "163": " alive", 176 | "164": " and", 177 | "165": " well", 178 | "166": ".", 179 | "167": " The", 180 | "168": " two", 181 | "169": " shared", 182 | "170": " a", 183 | "171": " long", 184 | "172": ",", 185 | "173": " tender", 186 | "174": " kiss", 187 | "175": ",", 188 | "176": " surrounded", 189 | "177": " by", 190 | "178": " the", 191 | "179": " sweet", 192 | "180": " scent", 193 | "181": " of", 194 | "182": " tropical", 195 | "183": " flowers", 196 | "184": ".", 197 | "185": " The", 198 | "186": " \"", 199 | "187": "M", 200 | "188": "ael", 201 | "189": "strom", 202 | "190": "'s", 203 | "191": " Fury", 204 | "192": "\"", 205 | "193": " sailed", 206 | "194": " away", 207 | "195": ",", 208 | "196": " its", 209 | "197": " crew", 210 | "198": " rejo", 211 | "199": "icing", 212 | "200": ",", 213 | "201": " as", 214 | "202": " Ly", 215 | "203": "ra", 216 | "204": " and", 217 | "205": " her", 218 | "206": " love", 219 | "207": " disappeared", 220 | "208": " into", 221 | "209": " the", 222 | "210": " sunset", 223 | "211": ",", 224 | "212": " their", 225 | "213": " hearts", 226 | "214": " full", 227 | "215": " of", 228 | "216": " happiness", 229 | "217": "." 230 | } 231 | } -------------------------------------------------------------------------------- /outputs/story-edit-word/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-edit-word/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/story-interventional/counterfactual_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she'd sailed the waves, searching for the fabled Golden Isle, where legend whispered the sea goddess herself granted eternal youth to those who found it.\n\nAs the sun dipped into the horizon, Maeve spotted it \u2013 a shimmering, golden light on the horizon. With a cry of triumph, she steered the ship towards it. As they reached the shore, Maeve leapt out, her eyes drinking in the beauty of the island. And then, she saw her \u2013 the goddess, bathing in the golden light. But, as Maeve approached, the goddess vanished, leaving behind a small, intricately carved box. 
With a heavy heart, Maeve opened it, finding the island's secret: the price of eternal youth was to never find it.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Mae", 14 | "2": "ve", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": "'d", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " waves", 49 | "37": ",", 50 | "38": " searching", 51 | "39": " for", 52 | "40": " the", 53 | "41": " f", 54 | "42": "abled", 55 | "43": " Golden", 56 | "44": " Isle", 57 | "45": ",", 58 | "46": " where", 59 | "47": " legend", 60 | "48": " whispered", 61 | "49": " the", 62 | "50": " sea", 63 | "51": " goddess", 64 | "52": " herself", 65 | "53": " granted", 66 | "54": " eternal", 67 | "55": " youth", 68 | "56": " to", 69 | "57": " those", 70 | "58": " who", 71 | "59": " found", 72 | "60": " it", 73 | "61": ".\n\n", 74 | "62": "As", 75 | "63": " the", 76 | "64": " sun", 77 | "65": " dipped", 78 | "66": " into", 79 | "67": " the", 80 | "68": " horizon", 81 | "69": ",", 82 | "70": " Mae", 83 | "71": "ve", 84 | "72": " spotted", 85 | "73": " it", 86 | "74": " \u2013", 87 | "75": " a", 88 | "76": " shimmer", 89 | "77": "ing", 90 | "78": ",", 91 | "79": " golden", 92 | "80": " light", 93 | "81": " on", 94 | "82": " the", 95 | "83": " horizon", 96 | "84": ".", 97 | "85": " With", 98 | "86": " a", 99 | "87": " cry", 100 | "88": " of", 101 | "89": " triumph", 102 | "90": ",", 103 | "91": " she", 104 | "92": " ste", 105 | "93": "ered", 106 | "94": " the", 107 | "95": " ship", 108 | "96": " towards", 109 | "97": " it", 110 | "98": ".", 111 | "99": " As", 112 | "100": " they", 113 | "101": " reached", 114 | "102": " the", 115 | "103": " shore", 116 | "104": ",", 117 | "105": " Mae", 118 | "106": "ve", 119 | "107": " le", 120 | "108": "apt", 121 | "109": " out", 122 | "110": ",", 123 | "111": " her", 124 | "112": " eyes", 125 | "113": " drinking", 126 | "114": " in", 127 | "115": " the", 128 | "116": " beauty", 129 | "117": " of", 130 | "118": " the", 131 | "119": " island", 132 | "120": ".", 133 | "121": " And", 134 | "122": " then", 135 | "123": ",", 136 | "124": " she", 137 | "125": " saw", 138 | "126": " her", 139 | "127": " \u2013", 140 | "128": " the", 141 | "129": " goddess", 142 | "130": ",", 143 | "131": " bathing", 144 | "132": " in", 145 | "133": " the", 146 | "134": " golden", 147 | "135": " light", 148 | "136": ".", 149 | "137": " But", 150 | "138": ",", 151 | "139": " as", 152 | "140": " Mae", 153 | "141": "ve", 154 | "142": " approached", 155 | "143": ",", 156 | "144": " the", 157 | "145": " goddess", 158 | "146": " vanished", 159 | "147": ",", 160 | "148": " leaving", 161 | "149": " behind", 162 | "150": " a", 163 | "151": " small", 164 | "152": ",", 165 | "153": " intric", 166 | "154": "ately", 167 | "155": " carved", 168 | "156": " box", 169 | "157": ".", 170 | "158": " With", 171 | "159": " a", 172 | "160": " heavy", 173 | "161": " heart", 174 | "162": ",", 
175 | "163": " Mae", 176 | "164": "ve", 177 | "165": " opened", 178 | "166": " it", 179 | "167": ",", 180 | "168": " finding", 181 | "169": " the", 182 | "170": " island", 183 | "171": "'s", 184 | "172": " secret", 185 | "173": ":", 186 | "174": " the", 187 | "175": " price", 188 | "176": " of", 189 | "177": " eternal", 190 | "178": " youth", 191 | "179": " was", 192 | "180": " to", 193 | "181": " never", 194 | "182": " find", 195 | "183": " it", 196 | "184": "." 197 | } 198 | } -------------------------------------------------------------------------------- /outputs/story-interventional/factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5 10 | } -------------------------------------------------------------------------------- /outputs/story-interventional/intervention_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. 
The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5, 10 | "start_from": 29, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " and", 111 | "99": " jewels", 112 | "100": " lay", 113 | "101": " at", 114 | "102": " her", 115 | "103": " feet", 116 | "104": ".", 117 | "105": " Ly", 118 | "106": "ra", 119 | "107": "'s", 120 | "108": " heart", 121 | "109": " soared", 122 | "110": " as", 123 | "111": " she", 124 | "112": " reached", 125 | "113": " for", 126 | "114": " the", 127 | "115": " treasure", 128 | "116": ".", 129 | "117": " But", 130 | "118": ",", 131 | "119": " as", 132 | "120": " she", 133 | "121": " touched", 134 | "122": " the", 135 | "123": " gold", 136 | "124": ",", 137 | "125": " the", 138 | "126": " is", 139 | "127": "le", 140 | "128": " began", 141 | "129": " to", 142 | "130": " fade", 143 | "131": " away", 144 | "132": ",", 145 | "133": " leaving", 146 | "134": " her", 147 | "135": " with", 148 | "136": " nothing", 149 | "137": " but", 150 | "138": " a", 151 | "139": " whispered", 152 | "140": " promise", 153 | "141": ":", 154 | "142": " \"", 155 | "143": "The", 156 | "144": " true", 157 | "145": " treasure", 158 | "146": " was", 159 | "147": " the", 160 | "148": " journey", 161 | "149": ",", 162 | "150": " not", 163 | "151": " the", 164 | "152": " gold", 165 | "153": ".\"", 166 | "154": " Ly", 167 | "155": "ra", 168 | "156": "'s", 169 | "157": " eyes", 170 | "158": " wel", 171 | "159": "led", 
172 | "160": " up", 173 | "161": " with", 174 | "162": " tears", 175 | "163": " as", 176 | "164": " she", 177 | "165": " realized", 178 | "166": " the", 179 | "167": " bitter", 180 | "168": " truth", 181 | "169": " \u2013", 182 | "170": " she", 183 | "171": " had", 184 | "172": " sacrificed", 185 | "173": " everything", 186 | "174": " for", 187 | "175": " fleeting", 188 | "176": " riches", 189 | "177": ",", 190 | "178": " and", 191 | "179": " lost", 192 | "180": " the", 193 | "181": " love", 194 | "182": " of", 195 | "183": " her", 196 | "184": " crew", 197 | "185": ",", 198 | "186": " her", 199 | "187": " family", 200 | "188": ",", 201 | "189": " and", 202 | "190": " herself", 203 | "191": ".", 204 | "192": " The", 205 | "193": " \"", 206 | "194": "M", 207 | "195": "ael", 208 | "196": "strom", 209 | "197": "'s", 210 | "198": " Fury", 211 | "199": "\"", 212 | "200": " sailed", 213 | "201": " away", 214 | "202": ",", 215 | "203": " a", 216 | "204": " lonely", 217 | "205": " shadow", 218 | "206": " on", 219 | "207": " the", 220 | "208": " vast", 221 | "209": " ocean", 222 | "210": ",", 223 | "211": " carrying", 224 | "212": " a", 225 | "213": " captain", 226 | "214": " forever", 227 | "215": " haunted", 228 | "216": " by", 229 | "217": " the", 230 | "218": " what", 231 | "219": "-", 232 | "220": "ifs", 233 | "221": "." 234 | } 235 | } -------------------------------------------------------------------------------- /outputs/story-interventional/intervention_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she'd sailed the waves, searching for the fabled Golden Isle, where legend whispered the sea goddess herself granted eternal youth to those who found it.\n\nAs the sun dipped into the horizon, Maeve spotted it \u2013 a shimmering, golden light on the horizon. With a cry of triumph, she steered the ship towards it. As they reached the shore, Maeve leapt out, her eyes drinking in the beauty of the island. And then, she saw her \u2013 the goddess, bathing in the golden light. But, as Maeve approached, the goddess vanished, leaving behind a small, intricately carved box. 
With a heavy heart, Maeve opened it, finding the island's secret: the price of eternal youth was to never find it.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Mae", 14 | "2": "ve", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": "'d", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " waves", 49 | "37": ",", 50 | "38": " searching", 51 | "39": " for", 52 | "40": " the", 53 | "41": " f", 54 | "42": "abled", 55 | "43": " Golden", 56 | "44": " Isle", 57 | "45": ",", 58 | "46": " where", 59 | "47": " legend", 60 | "48": " whispered", 61 | "49": " the", 62 | "50": " sea", 63 | "51": " goddess", 64 | "52": " herself", 65 | "53": " granted", 66 | "54": " eternal", 67 | "55": " youth", 68 | "56": " to", 69 | "57": " those", 70 | "58": " who", 71 | "59": " found", 72 | "60": " it", 73 | "61": ".\n\n", 74 | "62": "As", 75 | "63": " the", 76 | "64": " sun", 77 | "65": " dipped", 78 | "66": " into", 79 | "67": " the", 80 | "68": " horizon", 81 | "69": ",", 82 | "70": " Mae", 83 | "71": "ve", 84 | "72": " spotted", 85 | "73": " it", 86 | "74": " \u2013", 87 | "75": " a", 88 | "76": " shimmer", 89 | "77": "ing", 90 | "78": ",", 91 | "79": " golden", 92 | "80": " light", 93 | "81": " on", 94 | "82": " the", 95 | "83": " horizon", 96 | "84": ".", 97 | "85": " With", 98 | "86": " a", 99 | "87": " cry", 100 | "88": " of", 101 | "89": " triumph", 102 | "90": ",", 103 | "91": " she", 104 | "92": " ste", 105 | "93": "ered", 106 | "94": " the", 107 | "95": " ship", 108 | "96": " towards", 109 | "97": " it", 110 | "98": ".", 111 | "99": " As", 112 | "100": " they", 113 | "101": " reached", 114 | "102": " the", 115 | "103": " shore", 116 | "104": ",", 117 | "105": " Mae", 118 | "106": "ve", 119 | "107": " le", 120 | "108": "apt", 121 | "109": " out", 122 | "110": ",", 123 | "111": " her", 124 | "112": " eyes", 125 | "113": " drinking", 126 | "114": " in", 127 | "115": " the", 128 | "116": " beauty", 129 | "117": " of", 130 | "118": " the", 131 | "119": " island", 132 | "120": ".", 133 | "121": " And", 134 | "122": " then", 135 | "123": ",", 136 | "124": " she", 137 | "125": " saw", 138 | "126": " her", 139 | "127": " \u2013", 140 | "128": " the", 141 | "129": " goddess", 142 | "130": ",", 143 | "131": " bathing", 144 | "132": " in", 145 | "133": " the", 146 | "134": " golden", 147 | "135": " light", 148 | "136": ".", 149 | "137": " But", 150 | "138": ",", 151 | "139": " as", 152 | "140": " Mae", 153 | "141": "ve", 154 | "142": " approached", 155 | "143": ",", 156 | "144": " the", 157 | "145": " goddess", 158 | "146": " vanished", 159 | "147": ",", 160 | "148": " leaving", 161 | "149": " behind", 162 | "150": " a", 163 | "151": " small", 164 | "152": ",", 165 | "153": " intric", 166 | "154": "ately", 167 | "155": " carved", 168 | "156": " box", 169 | "157": ".", 170 | "158": " With", 171 | "159": " a", 172 | "160": " heavy", 173 | "161": " heart", 174 | "162": ",", 
175 | "163": " Mae", 176 | "164": "ve", 177 | "165": " opened", 178 | "166": " it", 179 | "167": ",", 180 | "168": " finding", 181 | "169": " the", 182 | "170": " island", 183 | "171": "'s", 184 | "172": " secret", 185 | "173": ":", 186 | "174": " the", 187 | "175": " price", 188 | "176": " of", 189 | "177": " eternal", 190 | "178": " youth", 191 | "179": " was", 192 | "180": " to", 193 | "181": " never", 194 | "182": " find", 195 | "183": " it", 196 | "184": "." 197 | } 198 | } -------------------------------------------------------------------------------- /outputs/story-interventional/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-interventional/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/story-name/counterfactual_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, she finally spotted its shimmering shores. As she stepped onto the island, a radiant light enveloped her, and she wished for her lost love, the brave sailor, Eamon, to return to her. Suddenly, Eamon emerged from the mist, frail but alive. Overjoyed, Maeve scooped him into her arms, and they shared a tender kiss as the sun dipped into the horizon, their love reunited at last.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Mae", 14 | "2": "ve", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " she", 75 | "63": " finally", 76 | "64": " spotted", 77 | "65": " its", 78 | "66": " shimmer", 79 | "67": "ing", 80 | "68": " shores", 81 | "69": ".", 82 | "70": " As", 83 | "71": " she", 84 | "72": " stepped", 85 | "73": " onto", 86 | "74": " the", 87 | "75": " island", 88 | "76": ",", 89 | "77": 
" a", 90 | "78": " radiant", 91 | "79": " light", 92 | "80": " envelop", 93 | "81": "ed", 94 | "82": " her", 95 | "83": ",", 96 | "84": " and", 97 | "85": " she", 98 | "86": " wished", 99 | "87": " for", 100 | "88": " her", 101 | "89": " lost", 102 | "90": " love", 103 | "91": ",", 104 | "92": " the", 105 | "93": " brave", 106 | "94": " sailor", 107 | "95": ",", 108 | "96": " E", 109 | "97": "amon", 110 | "98": ",", 111 | "99": " to", 112 | "100": " return", 113 | "101": " to", 114 | "102": " her", 115 | "103": ".", 116 | "104": " Suddenly", 117 | "105": ",", 118 | "106": " E", 119 | "107": "amon", 120 | "108": " emerged", 121 | "109": " from", 122 | "110": " the", 123 | "111": " mist", 124 | "112": ",", 125 | "113": " frail", 126 | "114": " but", 127 | "115": " alive", 128 | "116": ".", 129 | "117": " Over", 130 | "118": "joy", 131 | "119": "ed", 132 | "120": ",", 133 | "121": " Mae", 134 | "122": "ve", 135 | "123": " sco", 136 | "124": "oped", 137 | "125": " him", 138 | "126": " into", 139 | "127": " her", 140 | "128": " arms", 141 | "129": ",", 142 | "130": " and", 143 | "131": " they", 144 | "132": " shared", 145 | "133": " a", 146 | "134": " tender", 147 | "135": " kiss", 148 | "136": " as", 149 | "137": " the", 150 | "138": " sun", 151 | "139": " dipped", 152 | "140": " into", 153 | "141": " the", 154 | "142": " horizon", 155 | "143": ",", 156 | "144": " their", 157 | "145": " love", 158 | "146": " reunited", 159 | "147": " at", 160 | "148": " last", 161 | "149": "." 162 | } 163 | } -------------------------------------------------------------------------------- /outputs/story-name/factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5 10 | } -------------------------------------------------------------------------------- /outputs/story-name/intervention_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. 
The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5, 10 | "start_from": 29, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " and", 111 | "99": " jewels", 112 | "100": " lay", 113 | "101": " at", 114 | "102": " her", 115 | "103": " feet", 116 | "104": ".", 117 | "105": " Ly", 118 | "106": "ra", 119 | "107": "'s", 120 | "108": " heart", 121 | "109": " soared", 122 | "110": " as", 123 | "111": " she", 124 | "112": " reached", 125 | "113": " for", 126 | "114": " the", 127 | "115": " treasure", 128 | "116": ".", 129 | "117": " But", 130 | "118": ",", 131 | "119": " as", 132 | "120": " she", 133 | "121": " touched", 134 | "122": " the", 135 | "123": " gold", 136 | "124": ",", 137 | "125": " the", 138 | "126": " is", 139 | "127": "le", 140 | "128": " began", 141 | "129": " to", 142 | "130": " fade", 143 | "131": " away", 144 | "132": ",", 145 | "133": " leaving", 146 | "134": " her", 147 | "135": " with", 148 | "136": " nothing", 149 | "137": " but", 150 | "138": " a", 151 | "139": " whispered", 152 | "140": " promise", 153 | "141": ":", 154 | "142": " \"", 155 | "143": "The", 156 | "144": " true", 157 | "145": " treasure", 158 | "146": " was", 159 | "147": " the", 160 | "148": " journey", 161 | "149": ",", 162 | "150": " not", 163 | "151": " the", 164 | "152": " gold", 165 | "153": ".\"", 166 | "154": " Ly", 167 | "155": "ra", 168 | "156": "'s", 169 | "157": " eyes", 170 | "158": " wel", 171 | "159": "led", 
172 | "160": " up", 173 | "161": " with", 174 | "162": " tears", 175 | "163": " as", 176 | "164": " she", 177 | "165": " realized", 178 | "166": " the", 179 | "167": " bitter", 180 | "168": " truth", 181 | "169": " \u2013", 182 | "170": " she", 183 | "171": " had", 184 | "172": " sacrificed", 185 | "173": " everything", 186 | "174": " for", 187 | "175": " fleeting", 188 | "176": " riches", 189 | "177": ",", 190 | "178": " and", 191 | "179": " lost", 192 | "180": " the", 193 | "181": " love", 194 | "182": " of", 195 | "183": " her", 196 | "184": " crew", 197 | "185": ",", 198 | "186": " her", 199 | "187": " family", 200 | "188": ",", 201 | "189": " and", 202 | "190": " herself", 203 | "191": ".", 204 | "192": " The", 205 | "193": " \"", 206 | "194": "M", 207 | "195": "ael", 208 | "196": "strom", 209 | "197": "'s", 210 | "198": " Fury", 211 | "199": "\"", 212 | "200": " sailed", 213 | "201": " away", 214 | "202": ",", 215 | "203": " a", 216 | "204": " lonely", 217 | "205": " shadow", 218 | "206": " on", 219 | "207": " the", 220 | "208": " vast", 221 | "209": " ocean", 222 | "210": ",", 223 | "211": " carrying", 224 | "212": " a", 225 | "213": " captain", 226 | "214": " forever", 227 | "215": " haunted", 228 | "216": " by", 229 | "217": " the", 230 | "218": " what", 231 | "219": "-", 232 | "220": "ifs", 233 | "221": "." 234 | } 235 | } -------------------------------------------------------------------------------- /outputs/story-name/intervention_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, she finally spotted its shimmering shores. As she stepped onto the island, a radiant light enveloped her, and she wished for her lost love, the brave sailor, Eamon, to return to her. Suddenly, Eamon emerged from the mist, frail but alive. 
Overjoyed, Maeve scooped him into her arms, and they shared a tender kiss as the sun dipped into the horizon, their love reunited at last.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Mae", 14 | "2": "ve", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " she", 75 | "63": " finally", 76 | "64": " spotted", 77 | "65": " its", 78 | "66": " shimmer", 79 | "67": "ing", 80 | "68": " shores", 81 | "69": ".", 82 | "70": " As", 83 | "71": " she", 84 | "72": " stepped", 85 | "73": " onto", 86 | "74": " the", 87 | "75": " island", 88 | "76": ",", 89 | "77": " a", 90 | "78": " radiant", 91 | "79": " light", 92 | "80": " envelop", 93 | "81": "ed", 94 | "82": " her", 95 | "83": ",", 96 | "84": " and", 97 | "85": " she", 98 | "86": " wished", 99 | "87": " for", 100 | "88": " her", 101 | "89": " lost", 102 | "90": " love", 103 | "91": ",", 104 | "92": " the", 105 | "93": " brave", 106 | "94": " sailor", 107 | "95": ",", 108 | "96": " E", 109 | "97": "amon", 110 | "98": ",", 111 | "99": " to", 112 | "100": " return", 113 | "101": " to", 114 | "102": " her", 115 | "103": ".", 116 | "104": " Suddenly", 117 | "105": ",", 118 | "106": " E", 119 | "107": "amon", 120 | "108": " emerged", 121 | "109": " from", 122 | "110": " the", 123 | "111": " mist", 124 | "112": ",", 125 | "113": " frail", 126 | "114": " but", 127 | "115": " alive", 128 | "116": ".", 129 | "117": " Over", 130 | "118": "joy", 131 | "119": "ed", 132 | "120": ",", 133 | "121": " Mae", 134 | "122": "ve", 135 | "123": " sco", 136 | "124": "oped", 137 | "125": " him", 138 | "126": " into", 139 | "127": " her", 140 | "128": " arms", 141 | "129": ",", 142 | "130": " and", 143 | "131": " they", 144 | "132": " shared", 145 | "133": " a", 146 | "134": " tender", 147 | "135": " kiss", 148 | "136": " as", 149 | "137": " the", 150 | "138": " sun", 151 | "139": " dipped", 152 | "140": " into", 153 | "141": " the", 154 | "142": " horizon", 155 | "143": ",", 156 | "144": " their", 157 | "145": " love", 158 | "146": " reunited", 159 | "147": " at", 160 | "148": " last", 161 | "149": "." 
162 | } 163 | } -------------------------------------------------------------------------------- /outputs/story-name/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-name/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/story-remove-word/counterfactual_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever lost at sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " ship", 22 | "10": ",", 23 | "11": " the", 24 | "12": " \"", 25 | "13": "M", 26 | "14": "ael", 27 | "15": "strom", 28 | "16": "'s", 29 | "17": " Fury", 30 | "18": "\",", 31 | "19": " g", 32 | "20": "azing", 33 | "21": " out", 34 | "22": " at", 35 | "23": " the", 36 | "24": " endless", 37 | "25": " sea", 38 | "26": ".", 39 | "27": " For", 40 | "28": " years", 41 | "29": ",", 42 | "30": " she", 43 | "31": " had", 44 | "32": " sailed", 45 | "33": " the", 46 | "34": " seven", 47 | "35": " seas", 48 | "36": ",", 49 | "37": " seeking", 50 | "38": " the", 51 | "39": " legendary", 52 | "40": " Golden", 53 | "41": " Isle", 54 | "42": ",", 55 | "43": " said", 56 | "44": " to", 57 | "45": " grant", 58 | "46": " a", 59 | "47": " single", 60 | "48": " wish", 61 | "49": " to", 62 | "50": " whoever", 63 | "51": " found", 64 | "52": " it", 65 | "53": ".", 66 | "54": " After", 67 | "55": " a", 68 | "56": " lifetime", 69 | "57": " of", 70 | "58": " searching", 71 | "59": ",", 72 | "60": " Ly", 73 | "61": "ra", 74 | "62": " finally", 75 | "63": " spotted", 76 | "64": " the", 77 | "65": " island", 78 | "66": " on", 79 | "67": " the", 80 | "68": " horizon", 81 | "69": ".\n\n", 82 | "70": "As", 83 | "71": " she", 84 | "72": " stepped", 85 | "73": " foot", 86 | "74": " on", 87 | "75": " its", 88 | "76": " sandy", 89 | "77": " shores", 90 | "78": ",", 91 | "79": " a", 92 | "80": " warm", 93 | "81": " glow", 94 | "82": " envelop", 95 | "83": "ed", 96 | "84": " her", 97 | "85": ".", 
98 | "86": " Before", 99 | "87": " her", 100 | "88": ",", 101 | "89": " a", 102 | "90": " chest", 103 | "91": " overflowing", 104 | "92": " with", 105 | "93": " glitter", 106 | "94": "ing", 107 | "95": " gold", 108 | "96": " and", 109 | "97": " jewels", 110 | "98": " lay", 111 | "99": " at", 112 | "100": " her", 113 | "101": " feet", 114 | "102": ".", 115 | "103": " Ly", 116 | "104": "ra", 117 | "105": "'s", 118 | "106": " heart", 119 | "107": " soared", 120 | "108": " as", 121 | "109": " she", 122 | "110": " reached", 123 | "111": " for", 124 | "112": " the", 125 | "113": " treasure", 126 | "114": ".", 127 | "115": " But", 128 | "116": ",", 129 | "117": " as", 130 | "118": " she", 131 | "119": " touched", 132 | "120": " the", 133 | "121": " gold", 134 | "122": ",", 135 | "123": " the", 136 | "124": " is", 137 | "125": "le", 138 | "126": " began", 139 | "127": " to", 140 | "128": " fade", 141 | "129": " away", 142 | "130": ",", 143 | "131": " leaving", 144 | "132": " her", 145 | "133": " with", 146 | "134": " nothing", 147 | "135": " but", 148 | "136": " a", 149 | "137": " whispered", 150 | "138": " promise", 151 | "139": ":", 152 | "140": " \"", 153 | "141": "The", 154 | "142": " true", 155 | "143": " treasure", 156 | "144": " was", 157 | "145": " the", 158 | "146": " journey", 159 | "147": ",", 160 | "148": " not", 161 | "149": " the", 162 | "150": " gold", 163 | "151": ".\"", 164 | "152": " Ly", 165 | "153": "ra", 166 | "154": "'s", 167 | "155": " eyes", 168 | "156": " wel", 169 | "157": "led", 170 | "158": " up", 171 | "159": " with", 172 | "160": " tears", 173 | "161": " as", 174 | "162": " she", 175 | "163": " realized", 176 | "164": " the", 177 | "165": " bitter", 178 | "166": " truth", 179 | "167": " \u2013", 180 | "168": " she", 181 | "169": " had", 182 | "170": " sacrificed", 183 | "171": " everything", 184 | "172": " for", 185 | "173": " fleeting", 186 | "174": " riches", 187 | "175": ",", 188 | "176": " and", 189 | "177": " lost", 190 | "178": " the", 191 | "179": " love", 192 | "180": " of", 193 | "181": " her", 194 | "182": " crew", 195 | "183": ",", 196 | "184": " her", 197 | "185": " family", 198 | "186": ",", 199 | "187": " and", 200 | "188": " herself", 201 | "189": ".", 202 | "190": " The", 203 | "191": " \"", 204 | "192": "M", 205 | "193": "ael", 206 | "194": "strom", 207 | "195": "'s", 208 | "196": " Fury", 209 | "197": "\"", 210 | "198": " sailed", 211 | "199": " away", 212 | "200": ",", 213 | "201": " a", 214 | "202": " lonely", 215 | "203": " shadow", 216 | "204": " on", 217 | "205": " the", 218 | "206": " vast", 219 | "207": " ocean", 220 | "208": ",", 221 | "209": " carrying", 222 | "210": " a", 223 | "211": " captain", 224 | "212": " forever", 225 | "213": " lost", 226 | "214": " at", 227 | "215": " sea", 228 | "216": "." 229 | } 230 | } -------------------------------------------------------------------------------- /outputs/story-remove-word/factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. 
After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5 10 | } -------------------------------------------------------------------------------- /outputs/story-remove-word/intervention_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5, 10 | "start_from": 29, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " and", 111 | "99": " jewels", 112 | "100": " lay", 113 | "101": " at", 114 | "102": " 
her", 115 | "103": " feet", 116 | "104": ".", 117 | "105": " Ly", 118 | "106": "ra", 119 | "107": "'s", 120 | "108": " heart", 121 | "109": " soared", 122 | "110": " as", 123 | "111": " she", 124 | "112": " reached", 125 | "113": " for", 126 | "114": " the", 127 | "115": " treasure", 128 | "116": ".", 129 | "117": " But", 130 | "118": ",", 131 | "119": " as", 132 | "120": " she", 133 | "121": " touched", 134 | "122": " the", 135 | "123": " gold", 136 | "124": ",", 137 | "125": " the", 138 | "126": " is", 139 | "127": "le", 140 | "128": " began", 141 | "129": " to", 142 | "130": " fade", 143 | "131": " away", 144 | "132": ",", 145 | "133": " leaving", 146 | "134": " her", 147 | "135": " with", 148 | "136": " nothing", 149 | "137": " but", 150 | "138": " a", 151 | "139": " whispered", 152 | "140": " promise", 153 | "141": ":", 154 | "142": " \"", 155 | "143": "The", 156 | "144": " true", 157 | "145": " treasure", 158 | "146": " was", 159 | "147": " the", 160 | "148": " journey", 161 | "149": ",", 162 | "150": " not", 163 | "151": " the", 164 | "152": " gold", 165 | "153": ".\"", 166 | "154": " Ly", 167 | "155": "ra", 168 | "156": "'s", 169 | "157": " eyes", 170 | "158": " wel", 171 | "159": "led", 172 | "160": " up", 173 | "161": " with", 174 | "162": " tears", 175 | "163": " as", 176 | "164": " she", 177 | "165": " realized", 178 | "166": " the", 179 | "167": " bitter", 180 | "168": " truth", 181 | "169": " \u2013", 182 | "170": " she", 183 | "171": " had", 184 | "172": " sacrificed", 185 | "173": " everything", 186 | "174": " for", 187 | "175": " fleeting", 188 | "176": " riches", 189 | "177": ",", 190 | "178": " and", 191 | "179": " lost", 192 | "180": " the", 193 | "181": " love", 194 | "182": " of", 195 | "183": " her", 196 | "184": " crew", 197 | "185": ",", 198 | "186": " her", 199 | "187": " family", 200 | "188": ",", 201 | "189": " and", 202 | "190": " herself", 203 | "191": ".", 204 | "192": " The", 205 | "193": " \"", 206 | "194": "M", 207 | "195": "ael", 208 | "196": "strom", 209 | "197": "'s", 210 | "198": " Fury", 211 | "199": "\"", 212 | "200": " sailed", 213 | "201": " away", 214 | "202": ",", 215 | "203": " a", 216 | "204": " lonely", 217 | "205": " shadow", 218 | "206": " on", 219 | "207": " the", 220 | "208": " vast", 221 | "209": " ocean", 222 | "210": ",", 223 | "211": " carrying", 224 | "212": " a", 225 | "213": " captain", 226 | "214": " forever", 227 | "215": " haunted", 228 | "216": " by", 229 | "217": " the", 230 | "218": " what", 231 | "219": "-", 232 | "220": "ifs", 233 | "221": "." 234 | } 235 | } -------------------------------------------------------------------------------- /outputs/story-remove-word/intervention_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. 
But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever lost at sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " ship", 22 | "10": ",", 23 | "11": " the", 24 | "12": " \"", 25 | "13": "M", 26 | "14": "ael", 27 | "15": "strom", 28 | "16": "'s", 29 | "17": " Fury", 30 | "18": "\",", 31 | "19": " g", 32 | "20": "azing", 33 | "21": " out", 34 | "22": " at", 35 | "23": " the", 36 | "24": " endless", 37 | "25": " sea", 38 | "26": ".", 39 | "27": " For", 40 | "28": " years", 41 | "29": ",", 42 | "30": " she", 43 | "31": " had", 44 | "32": " sailed", 45 | "33": " the", 46 | "34": " seven", 47 | "35": " seas", 48 | "36": ",", 49 | "37": " seeking", 50 | "38": " the", 51 | "39": " legendary", 52 | "40": " Golden", 53 | "41": " Isle", 54 | "42": ",", 55 | "43": " said", 56 | "44": " to", 57 | "45": " grant", 58 | "46": " a", 59 | "47": " single", 60 | "48": " wish", 61 | "49": " to", 62 | "50": " whoever", 63 | "51": " found", 64 | "52": " it", 65 | "53": ".", 66 | "54": " After", 67 | "55": " a", 68 | "56": " lifetime", 69 | "57": " of", 70 | "58": " searching", 71 | "59": ",", 72 | "60": " Ly", 73 | "61": "ra", 74 | "62": " finally", 75 | "63": " spotted", 76 | "64": " the", 77 | "65": " island", 78 | "66": " on", 79 | "67": " the", 80 | "68": " horizon", 81 | "69": ".\n\n", 82 | "70": "As", 83 | "71": " she", 84 | "72": " stepped", 85 | "73": " foot", 86 | "74": " on", 87 | "75": " its", 88 | "76": " sandy", 89 | "77": " shores", 90 | "78": ",", 91 | "79": " a", 92 | "80": " warm", 93 | "81": " glow", 94 | "82": " envelop", 95 | "83": "ed", 96 | "84": " her", 97 | "85": ".", 98 | "86": " Before", 99 | "87": " her", 100 | "88": ",", 101 | "89": " a", 102 | "90": " chest", 103 | "91": " overflowing", 104 | "92": " with", 105 | "93": " glitter", 106 | "94": "ing", 107 | "95": " gold", 108 | "96": " and", 109 | "97": " jewels", 110 | "98": " lay", 111 | "99": " at", 112 | "100": " her", 113 | "101": " feet", 114 | "102": ".", 115 | "103": " Ly", 116 | "104": "ra", 117 | "105": "'s", 118 | "106": " heart", 119 | "107": " soared", 120 | "108": " as", 121 | "109": " she", 122 | "110": " reached", 123 | "111": " for", 124 | "112": " the", 125 | "113": " treasure", 126 | "114": ".", 127 | "115": " But", 128 | "116": ",", 129 | "117": " as", 130 | "118": " she", 131 | "119": " touched", 132 | "120": " the", 133 | "121": " gold", 134 | "122": ",", 135 | "123": " the", 136 | "124": " is", 137 | "125": "le", 138 | "126": " began", 139 | "127": " to", 140 | "128": " fade", 141 | "129": " away", 142 | "130": ",", 143 | "131": " leaving", 144 | "132": " her", 145 | "133": " with", 146 | "134": " nothing", 147 | "135": " but", 148 | "136": " a", 149 | "137": " whispered", 150 | "138": " promise", 151 | "139": ":", 152 | "140": " \"", 153 | "141": "The", 154 | "142": " true", 155 | "143": " treasure", 156 | "144": " was", 157 | "145": " the", 158 | "146": " 
journey", 159 | "147": ",", 160 | "148": " not", 161 | "149": " the", 162 | "150": " gold", 163 | "151": ".\"", 164 | "152": " Ly", 165 | "153": "ra", 166 | "154": "'s", 167 | "155": " eyes", 168 | "156": " wel", 169 | "157": "led", 170 | "158": " up", 171 | "159": " with", 172 | "160": " tears", 173 | "161": " as", 174 | "162": " she", 175 | "163": " realized", 176 | "164": " the", 177 | "165": " bitter", 178 | "166": " truth", 179 | "167": " \u2013", 180 | "168": " she", 181 | "169": " had", 182 | "170": " sacrificed", 183 | "171": " everything", 184 | "172": " for", 185 | "173": " fleeting", 186 | "174": " riches", 187 | "175": ",", 188 | "176": " and", 189 | "177": " lost", 190 | "178": " the", 191 | "179": " love", 192 | "180": " of", 193 | "181": " her", 194 | "182": " crew", 195 | "183": ",", 196 | "184": " her", 197 | "185": " family", 198 | "186": ",", 199 | "187": " and", 200 | "188": " herself", 201 | "189": ".", 202 | "190": " The", 203 | "191": " \"", 204 | "192": "M", 205 | "193": "ael", 206 | "194": "strom", 207 | "195": "'s", 208 | "196": " Fury", 209 | "197": "\"", 210 | "198": " sailed", 211 | "199": " away", 212 | "200": ",", 213 | "201": " a", 214 | "202": " lonely", 215 | "203": " shadow", 216 | "204": " on", 217 | "205": " the", 218 | "206": " vast", 219 | "207": " ocean", 220 | "208": ",", 221 | "209": " carrying", 222 | "210": " a", 223 | "211": " captain", 224 | "212": " forever", 225 | "213": " lost", 226 | "214": " at", 227 | "215": " sea", 228 | "216": "." 229 | } 230 | } -------------------------------------------------------------------------------- /outputs/story-remove-word/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-remove-word/rngstates_1.pt -------------------------------------------------------------------------------- /outputs/story-ship/counterfactual_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold coins appeared. But, just as she reached for it, a storm brewed, and a rival captain, seeking the treasure for himself, appeared. 
Lyra fought valiantly, but in the end, she was defeated.\n\nAs she lay on the beach, her ship destroyed and her crew scattered, Lyra realized her wish had been granted \u2013 not for wealth or power, but for the courage to sail through life, no matter the outcome.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "C", 28 | "16": "ycl", 29 | "17": "one", 30 | "18": "'s", 31 | "19": " Rage", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " coins", 111 | "99": " appeared", 112 | "100": ".", 113 | "101": " But", 114 | "102": ",", 115 | "103": " just", 116 | "104": " as", 117 | "105": " she", 118 | "106": " reached", 119 | "107": " for", 120 | "108": " it", 121 | "109": ",", 122 | "110": " a", 123 | "111": " storm", 124 | "112": " brewed", 125 | "113": ",", 126 | "114": " and", 127 | "115": " a", 128 | "116": " rival", 129 | "117": " captain", 130 | "118": ",", 131 | "119": " seeking", 132 | "120": " the", 133 | "121": " treasure", 134 | "122": " for", 135 | "123": " himself", 136 | "124": ",", 137 | "125": " appeared", 138 | "126": ".", 139 | "127": " Ly", 140 | "128": "ra", 141 | "129": " fought", 142 | "130": " val", 143 | "131": "iant", 144 | "132": "ly", 145 | "133": ",", 146 | "134": " but", 147 | "135": " in", 148 | "136": " the", 149 | "137": " end", 150 | "138": ",", 151 | "139": " she", 152 | "140": " was", 153 | "141": " defeated", 154 | "142": ".\n\n", 155 | "143": "As", 156 | "144": " she", 157 | "145": " lay", 158 | "146": " on", 159 | "147": " the", 160 | "148": " beach", 161 | "149": ",", 162 | "150": " her", 163 | "151": " ship", 164 | "152": " destroyed", 165 | "153": " and", 166 | "154": " her", 167 | "155": " crew", 
168 | "156": " scattered", 169 | "157": ",", 170 | "158": " Ly", 171 | "159": "ra", 172 | "160": " realized", 173 | "161": " her", 174 | "162": " wish", 175 | "163": " had", 176 | "164": " been", 177 | "165": " granted", 178 | "166": " \u2013", 179 | "167": " not", 180 | "168": " for", 181 | "169": " wealth", 182 | "170": " or", 183 | "171": " power", 184 | "172": ",", 185 | "173": " but", 186 | "174": " for", 187 | "175": " the", 188 | "176": " courage", 189 | "177": " to", 190 | "178": " sail", 191 | "179": " through", 192 | "180": " life", 193 | "181": ",", 194 | "182": " no", 195 | "183": " matter", 196 | "184": " the", 197 | "185": " outcome", 198 | "186": "." 199 | } 200 | } -------------------------------------------------------------------------------- /outputs/story-ship/factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5 10 | } -------------------------------------------------------------------------------- /outputs/story-ship/intervention_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. 
The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "sampler_type": "vocabulary", 8 | "top_p": 0.9, 9 | "top_k": 5, 10 | "start_from": 29, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "M", 28 | "16": "ael", 29 | "17": "strom", 30 | "18": "'s", 31 | "19": " Fury", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " and", 111 | "99": " jewels", 112 | "100": " lay", 113 | "101": " at", 114 | "102": " her", 115 | "103": " feet", 116 | "104": ".", 117 | "105": " Ly", 118 | "106": "ra", 119 | "107": "'s", 120 | "108": " heart", 121 | "109": " soared", 122 | "110": " as", 123 | "111": " she", 124 | "112": " reached", 125 | "113": " for", 126 | "114": " the", 127 | "115": " treasure", 128 | "116": ".", 129 | "117": " But", 130 | "118": ",", 131 | "119": " as", 132 | "120": " she", 133 | "121": " touched", 134 | "122": " the", 135 | "123": " gold", 136 | "124": ",", 137 | "125": " the", 138 | "126": " is", 139 | "127": "le", 140 | "128": " began", 141 | "129": " to", 142 | "130": " fade", 143 | "131": " away", 144 | "132": ",", 145 | "133": " leaving", 146 | "134": " her", 147 | "135": " with", 148 | "136": " nothing", 149 | "137": " but", 150 | "138": " a", 151 | "139": " whispered", 152 | "140": " promise", 153 | "141": ":", 154 | "142": " \"", 155 | "143": "The", 156 | "144": " true", 157 | "145": " treasure", 158 | "146": " was", 159 | "147": " the", 160 | "148": " journey", 161 | "149": ",", 162 | "150": " not", 163 | "151": " the", 164 | "152": " gold", 165 | "153": ".\"", 166 | "154": " Ly", 167 | "155": "ra", 168 | "156": "'s", 169 | "157": " eyes", 170 | "158": " wel", 171 | "159": "led", 
172 | "160": " up", 173 | "161": " with", 174 | "162": " tears", 175 | "163": " as", 176 | "164": " she", 177 | "165": " realized", 178 | "166": " the", 179 | "167": " bitter", 180 | "168": " truth", 181 | "169": " \u2013", 182 | "170": " she", 183 | "171": " had", 184 | "172": " sacrificed", 185 | "173": " everything", 186 | "174": " for", 187 | "175": " fleeting", 188 | "176": " riches", 189 | "177": ",", 190 | "178": " and", 191 | "179": " lost", 192 | "180": " the", 193 | "181": " love", 194 | "182": " of", 195 | "183": " her", 196 | "184": " crew", 197 | "185": ",", 198 | "186": " her", 199 | "187": " family", 200 | "188": ",", 201 | "189": " and", 202 | "190": " herself", 203 | "191": ".", 204 | "192": " The", 205 | "193": " \"", 206 | "194": "M", 207 | "195": "ael", 208 | "196": "strom", 209 | "197": "'s", 210 | "198": " Fury", 211 | "199": "\"", 212 | "200": " sailed", 213 | "201": " away", 214 | "202": ",", 215 | "203": " a", 216 | "204": " lonely", 217 | "205": " shadow", 218 | "206": " on", 219 | "207": " the", 220 | "208": " vast", 221 | "209": " ocean", 222 | "210": ",", 223 | "211": " carrying", 224 | "212": " a", 225 | "213": " captain", 226 | "214": " forever", 227 | "215": " haunted", 228 | "216": " by", 229 | "217": " the", 230 | "218": " what", 231 | "219": "-", 232 | "220": "ifs", 233 | "221": "." 234 | } 235 | } -------------------------------------------------------------------------------- /outputs/story-ship/intervention_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "Be creative and keep your response as short as possible.", 3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold coins appeared. But, just as she reached for it, a storm brewed, and a rival captain, seeking the treasure for himself, appeared. 
Lyra fought valiantly, but in the end, she was defeated.\n\nAs she lay on the beach, her ship destroyed and her crew scattered, Lyra realized her wish had been granted \u2013 not for wealth or power, but for the courage to sail through life, no matter the outcome.", 5 | "seed": 2, 6 | "temperature": 0.9, 7 | "top_p": 0.9, 8 | "top_k": 5, 9 | "sampler_type": "vocabulary", 10 | "start_from": 0, 11 | "token_list": { 12 | "0": "Captain", 13 | "1": " Ly", 14 | "2": "ra", 15 | "3": " stood", 16 | "4": " at", 17 | "5": " the", 18 | "6": " helm", 19 | "7": " of", 20 | "8": " her", 21 | "9": " trust", 22 | "10": "y", 23 | "11": " ship", 24 | "12": ",", 25 | "13": " the", 26 | "14": " \"", 27 | "15": "C", 28 | "16": "ycl", 29 | "17": "one", 30 | "18": "'s", 31 | "19": " Rage", 32 | "20": "\",", 33 | "21": " g", 34 | "22": "azing", 35 | "23": " out", 36 | "24": " at", 37 | "25": " the", 38 | "26": " endless", 39 | "27": " sea", 40 | "28": ".", 41 | "29": " For", 42 | "30": " years", 43 | "31": ",", 44 | "32": " she", 45 | "33": " had", 46 | "34": " sailed", 47 | "35": " the", 48 | "36": " seven", 49 | "37": " seas", 50 | "38": ",", 51 | "39": " seeking", 52 | "40": " the", 53 | "41": " legendary", 54 | "42": " Golden", 55 | "43": " Isle", 56 | "44": ",", 57 | "45": " said", 58 | "46": " to", 59 | "47": " grant", 60 | "48": " a", 61 | "49": " single", 62 | "50": " wish", 63 | "51": " to", 64 | "52": " whoever", 65 | "53": " found", 66 | "54": " it", 67 | "55": ".", 68 | "56": " After", 69 | "57": " a", 70 | "58": " lifetime", 71 | "59": " of", 72 | "60": " searching", 73 | "61": ",", 74 | "62": " Ly", 75 | "63": "ra", 76 | "64": " finally", 77 | "65": " spotted", 78 | "66": " the", 79 | "67": " island", 80 | "68": " on", 81 | "69": " the", 82 | "70": " horizon", 83 | "71": ".\n\n", 84 | "72": "As", 85 | "73": " she", 86 | "74": " stepped", 87 | "75": " foot", 88 | "76": " on", 89 | "77": " its", 90 | "78": " sandy", 91 | "79": " shores", 92 | "80": ",", 93 | "81": " a", 94 | "82": " warm", 95 | "83": " glow", 96 | "84": " envelop", 97 | "85": "ed", 98 | "86": " her", 99 | "87": ".", 100 | "88": " Before", 101 | "89": " her", 102 | "90": ",", 103 | "91": " a", 104 | "92": " chest", 105 | "93": " overflowing", 106 | "94": " with", 107 | "95": " glitter", 108 | "96": "ing", 109 | "97": " gold", 110 | "98": " coins", 111 | "99": " appeared", 112 | "100": ".", 113 | "101": " But", 114 | "102": ",", 115 | "103": " just", 116 | "104": " as", 117 | "105": " she", 118 | "106": " reached", 119 | "107": " for", 120 | "108": " it", 121 | "109": ",", 122 | "110": " a", 123 | "111": " storm", 124 | "112": " brewed", 125 | "113": ",", 126 | "114": " and", 127 | "115": " a", 128 | "116": " rival", 129 | "117": " captain", 130 | "118": ",", 131 | "119": " seeking", 132 | "120": " the", 133 | "121": " treasure", 134 | "122": " for", 135 | "123": " himself", 136 | "124": ",", 137 | "125": " appeared", 138 | "126": ".", 139 | "127": " Ly", 140 | "128": "ra", 141 | "129": " fought", 142 | "130": " val", 143 | "131": "iant", 144 | "132": "ly", 145 | "133": ",", 146 | "134": " but", 147 | "135": " in", 148 | "136": " the", 149 | "137": " end", 150 | "138": ",", 151 | "139": " she", 152 | "140": " was", 153 | "141": " defeated", 154 | "142": ".\n\n", 155 | "143": "As", 156 | "144": " she", 157 | "145": " lay", 158 | "146": " on", 159 | "147": " the", 160 | "148": " beach", 161 | "149": ",", 162 | "150": " her", 163 | "151": " ship", 164 | "152": " destroyed", 165 | "153": " and", 166 | "154": " her", 167 | "155": " crew", 
168 | "156": " scattered", 169 | "157": ",", 170 | "158": " Ly", 171 | "159": "ra", 172 | "160": " realized", 173 | "161": " her", 174 | "162": " wish", 175 | "163": " had", 176 | "164": " been", 177 | "165": " granted", 178 | "166": " \u2013", 179 | "167": " not", 180 | "168": " for", 181 | "169": " wealth", 182 | "170": " or", 183 | "171": " power", 184 | "172": ",", 185 | "173": " but", 186 | "174": " for", 187 | "175": " the", 188 | "176": " courage", 189 | "177": " to", 190 | "178": " sail", 191 | "179": " through", 192 | "180": " life", 193 | "181": ",", 194 | "182": " no", 195 | "183": " matter", 196 | "184": " the", 197 | "185": " outcome", 198 | "186": "." 199 | } 200 | } -------------------------------------------------------------------------------- /outputs/story-ship/rngstates_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-ship/rngstates_1.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asttokens==2.4.1 2 | blobfile==3.0.0 3 | certifi==2024.8.30 4 | charset-normalizer==3.3.2 5 | comm==0.2.2 6 | contourpy==1.3.0 7 | cramjam==2.8.4 8 | cycler==0.12.1 9 | debugpy==1.8.6 10 | decorator==5.1.1 11 | executing==2.1.0 12 | fairscale==0.4.13 13 | fastparquet==2024.5.0 14 | filelock==3.16.1 15 | fire==0.6.0 16 | fonttools==4.54.1 17 | fsspec==2024.9.0 18 | idna==3.10 19 | ipykernel==6.29.5 20 | ipython==8.27.0 21 | jedi==0.19.1 22 | Jinja2==3.1.4 23 | jupyter_client==8.6.3 24 | jupyter_core==5.7.2 25 | kiwisolver==1.4.7 26 | Levenshtein==0.26.0 27 | lxml==5.3.0 28 | MarkupSafe==2.1.5 29 | matplotlib==3.9.2 30 | matplotlib-inline==0.1.7 31 | mpmath==1.3.0 32 | nest-asyncio==1.6.0 33 | networkx==3.3 34 | numpy==2.1.1 35 | nvidia-cublas-cu12==12.1.3.1 36 | nvidia-cuda-cupti-cu12==12.1.105 37 | nvidia-cuda-nvrtc-cu12==12.1.105 38 | nvidia-cuda-runtime-cu12==12.1.105 39 | nvidia-cudnn-cu12==9.1.0.70 40 | nvidia-cufft-cu12==11.0.2.54 41 | nvidia-curand-cu12==10.3.2.106 42 | nvidia-cusolver-cu12==11.4.5.107 43 | nvidia-cusparse-cu12==12.1.0.106 44 | nvidia-nccl-cu12==2.20.5 45 | nvidia-nvjitlink-cu12==12.6.68 46 | nvidia-nvtx-cu12==12.1.105 47 | packaging==24.1 48 | pandas==2.2.3 49 | parso==0.8.4 50 | pexpect==4.9.0 51 | pillow==10.4.0 52 | platformdirs==4.3.6 53 | prompt_toolkit==3.0.48 54 | psutil==6.0.0 55 | ptyprocess==0.7.0 56 | pure_eval==0.2.3 57 | pycryptodomex==3.20.0 58 | Pygments==2.18.0 59 | pyparsing==3.1.4 60 | python-dateutil==2.9.0.post0 61 | pytz==2024.2 62 | pyzmq==26.2.0 63 | RapidFuzz==3.10.0 64 | regex==2024.9.11 65 | requests==2.32.3 66 | seaborn==0.13.2 67 | six==1.16.0 68 | stack-data==0.6.3 69 | sympy==1.13.3 70 | termcolor==2.4.0 71 | tiktoken==0.7.0 72 | torch==2.4.1 73 | tornado==6.4.1 74 | tqdm==4.66.5 75 | traitlets==5.14.3 76 | triton==3.0.0 77 | typing_extensions==4.12.2 78 | tzdata==2024.2 79 | urllib3==2.2.3 80 | wcwidth==0.2.13 -------------------------------------------------------------------------------- /scripts/bias.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | parser=argparse.ArgumentParser() 5 | parser.add_argument("--i", default=1) 6 | parser.add_argument("--direct", default=False) 7 | parser.add_argument("--prior", default=False) 8 | 
parser.add_argument("--model_family", default='llama3', choices=['llama3', 'mistral']) 9 | parser.add_argument("--attribute_to_intervene", default='Sex', choices=['Sex', 'Race']) 10 | args = parser.parse_args() 11 | 12 | nproc_per_node=1 13 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/bias.py'] 14 | for k, v in vars(args).items(): 15 | cmd.append(f"--{k}") 16 | cmd.append(str(v)) 17 | subprocess.run(args=cmd) -------------------------------------------------------------------------------- /scripts/census_queries.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | parser=argparse.ArgumentParser() 5 | parser.add_argument("--model_family", default='llama3', choices=['llama3', 'mistral']) 6 | args = parser.parse_args() 7 | nproc_per_node = 1 8 | 9 | model_family = args.model_family 10 | 11 | if model_family=="llama3": 12 | seeds=[1094356, 2345656, 823845969964454] 13 | weights_path = "src/llama3/pretrained/Meta-Llama-3-8B-Instruct/" 14 | tokenizer_path = "src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model" 15 | elif model_family=="mistral": 16 | seeds=[34435326297321, 124638290, 823845969964454] 17 | weights_path = "src/mistral-inference/8B-Instruct/" 18 | tokenizer_path = "/src/mistral-inference/8B-Instruct/" 19 | 20 | 21 | 22 | for i, seed in enumerate(seeds): 23 | params = { 24 | "ckpt_dir": weights_path, 25 | "tokenizer_path": tokenizer_path, 26 | "max_seq_len": 8192, 27 | "max_batch_size": 2, 28 | "seed": seed, 29 | "temperature": 0.8, 30 | "query": '"Generate census data of 50 fictional people."', 31 | "exp_name": f"{model_family}/census{i + 5}", # this is the name of the experiment -- results are saved under outputs/exp_name, 32 | "genstates": "True", # if True, the script will save the states of the random number generator, 33 | "system": '"Return only the following information: Age, Sex, Citizenship, Race, Ethnicity, Marital Status, Number of Children, Occupation, Income, Education. For Race, choose only between following options: White American, Black or African American, American Indian or Alaska Native, Asian American, Native Hawaiian or Other Pacific Islander, Other or Two or more races (multiracial). For Ethinicity, choose only between following options: Non-Hispanic/Latino or Hispanic/Latino. 
Return a list in json format delimited by \\"```\\"."', 34 | "model_family": model_family 35 | } 36 | 37 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/single_query.py'] 38 | for j, k_v in enumerate(params.items()): 39 | k, v = k_v 40 | if j >= 4: # pass the remaining params as --key=value so that values containing spaces (query, system) survive the shell join below 41 | cmd.append(f"--{k}={v}") 42 | else: 43 | cmd.append(f"--{k}") 44 | cmd.append(str(v)) 45 | cmd = ' '.join(cmd) 46 | subprocess.run(args=cmd, shell=True) 47 | -------------------------------------------------------------------------------- /scripts/cf_query.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | nproc_per_node = 1 3 | 4 | params = { 5 | 'max_seq_len': 2048, 6 | 'max_batch_size': 2, 7 | 'exp_name': "story-test", 8 | 'iteration': 1, 9 | 'prior': False, 10 | } 11 | 12 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/cf_query.py'] 13 | for k, v in params.items(): 14 | cmd.append(f"--{k}") 15 | cmd.append(str(v)) 16 | 17 | subprocess.run(args=cmd) -------------------------------------------------------------------------------- /scripts/stability.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | 4 | # Constant parameters 5 | # llama3 6 | ckpt_path=f"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/" 7 | tokenizer_path=f"{ckpt_path}/tokenizer.model" 8 | model_family = 'llama3' 9 | 10 | # mistral 11 | # ckpt_path = f"src/mistral-inference/8B-Instruct/" 12 | # tokenizer_path = ckpt_path 13 | # model_family = 'mistral' 14 | 15 | seed = 42 16 | system = "Keep your replies short and to the point." 17 | top_p = [0.75,0.9,0.95,0.99,0.999] 18 | top_k = [2,3,5,10,100] 19 | temperature = [0.0,0.2,0.4,0.6,0.8,1.0] 20 | input_file = f"data/questions.parquet" 21 | output_dir = "outputs/stability/llama3" # results are saved under output_dir (reset below according to --model_family) 22 | chunk_size = 25 23 | num_interventions = 2 24 | categorical = False 25 | 26 | parser=argparse.ArgumentParser() 27 | parser.add_argument("--ckpt_dir", default=ckpt_path) 28 | parser.add_argument("--tokenizer_path", default=tokenizer_path) 29 | parser.add_argument("--max_seq_len", default=2048, type=int) 30 | parser.add_argument("--max_batch_size", default=2, type=int) 31 | parser.add_argument("--seed", default=seed, type=int, 32 | help="the seed for the initialization of the random number generator") 33 | parser.add_argument("--input_file", default=input_file, 34 | help="path to file containing input queries") 35 | parser.add_argument("--output_dir", default=output_dir, 36 | help="path to output directory") 37 | parser.add_argument("--output_file_params", default="params", 38 | help="file name to store experiment parameters") 39 | parser.add_argument("--output_file_responses", default="stability_strings", 40 | help="file name to store all model responses") 41 | parser.add_argument("--system", default=system, 42 | help="the system prompt you want to use") 43 | parser.add_argument("--temperature", default=temperature, 44 | help="the temperature parameter of the model") 45 | parser.add_argument("--top_p",default=top_p, 46 | help="the value of p for top-p samplers") 47 | parser.add_argument("--top_k",default=top_k, 48 | help="the value of k for the top-k sampler") 49 | parser.add_argument("--chunk_size",default=chunk_size, 50 | help="number of queries to be written to file at once") 51 | parser.add_argument("--num_interventions",default=num_interventions, 52 | 
help="number of interventions per query") 53 | parser.add_argument("--top_p_intervention", default=0.9, 54 | help="top-p value to select the post-intervention token") 55 | parser.add_argument("--intervention_seed", default=seed, 56 | help="Seed to select the post-intervention token") 57 | parser.add_argument("--intervention_position_seed", default=seed, 58 | help="Seed to select the intervention position") 59 | parser.add_argument("--model_family", default=model_family, choices=['llama3', 'mistral'], 60 | help="Please select the model family") 61 | parser.add_argument("--categorical", default=False, 62 | help="If true, runs inverse transform sampler") 63 | 64 | args = parser.parse_args() 65 | if args.model_family == 'llama3': 66 | args.ckpt_dir = f"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/" 67 | args.tokenizer_path = f"{args.ckpt_dir}/tokenizer.model" 68 | args.output_dir = "outputs/stability/llama3" 69 | elif args.model_family == 'mistral': 70 | args.ckpt_dir = f"src/mistral-inference/8B-Instruct/" 71 | args.tokenizer_path = f"{args.ckpt_dir}/tokenizer.model" 72 | args.output_dir = "outputs/stability/mistral" 73 | if args.categorical == 'True': 74 | args.top_p=[] 75 | args.top_k=[] 76 | args.output_dir="outputs/stability/categorical" 77 | 78 | 79 | nproc_per_node=1 80 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/stability.py'] 81 | for k, v in vars(args).items(): 82 | cmd.append(f"--{k}") 83 | cmd.append(str(v)) 84 | 85 | print(cmd) 86 | subprocess.run(args=cmd) 87 | 88 | 89 | -------------------------------------------------------------------------------- /scripts/story_query.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | nproc_per_node = 1 3 | 4 | 5 | params = { 6 | 'query': "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.", 7 | 'system': "Be creative and keep your response as short as possible.", 8 | 'seed': 2, 9 | 'temperature': 0.9, 10 | 'max_seq_len': 2048, 11 | 'max_batch_size': 2, 12 | 'exp_name': "story-test", 13 | 'sampler_type': "vocabulary", # "vocabulary" or "top-p position" or "top-p token" or "top-k token" 14 | 'top_p': 0.9, # this value doesn't matter if the respective sampler_type is not "top-p" 15 | 'top_k': 5, # this value doesn't matter if the respective sampler_type is not "top-k" 16 | 'genstates': "True", 17 | 'ckpt_dir':"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/", 18 | 'tokenizer_path':"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model", 19 | 'model_family': "llama3" 20 | } 21 | # NOTE: instructions for using different model families 22 | # mistral: 23 | # - ckpt_dir="src/mistral-inference/8B-Instruct/" 24 | # - tokenizer_path="src/mistral-inference/8B-Instruct/" 25 | # - model_family="mistral" 26 | # llama3: 27 | # - ckpt_dir="src/llama3/pretrained/Meta-Llama-3-8B-Instruct/" 28 | # - tokenizer_path="src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model" 29 | # - model_family="llama3" 30 | 31 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/single_query.py'] 32 | for k, v in params.items(): 33 | cmd.append(f"--{k}") 34 | cmd.append(str(v)) 35 | 36 | subprocess.run(args=cmd) -------------------------------------------------------------------------------- /src/cf_query.py: -------------------------------------------------------------------------------- 1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez. 
2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors. 3 | 4 | from typing import List, Optional 5 | 6 | import fire 7 | 8 | import sys 9 | import os 10 | import json 11 | import torch 12 | from sampler import Sampler 13 | 14 | # mistral imports 15 | from mistral_common.tokens.tokenizers.base import Tokenizer 16 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 17 | from mistral_common.protocol.instruct.request import ChatCompletionRequest 18 | from mistral_common.protocol.instruct.messages import ( 19 | AssistantMessage, 20 | UserMessage, 21 | SystemMessage 22 | ) 23 | from mistral_inference.generate import generate 24 | # Add the mistral-inference src directory to the Python path 25 | sys.path.append(os.path.abspath("src/mistral-inference/src")) 26 | from mistral_inference.main import load_tokenizer, get_model_cls 27 | from pathlib import Path 28 | 29 | # Add the llama3 directory to the Python path 30 | sys.path.append(os.path.abspath("src/llama3")) 31 | from llama import Dialog, Llama 32 | 33 | def main( 34 | max_seq_len: int = 512, 35 | max_batch_size: int = 4, 36 | max_gen_len: Optional[int] = None, 37 | exp_name: str = "test", 38 | prior: bool = False, 39 | iteration: int = 1 40 | ): 41 | """ 42 | Examples to run with the models finetuned for chat. Prompts consist of chat 43 | turns between the user and the assistant, with the final turn always being the user's. 44 | 45 | An optional system prompt at the beginning to control how the model should respond 46 | is also supported. 47 | 48 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192. 49 | 50 | `max_gen_len` is optional because finetuned models are able to stop generations naturally. 51 | """ 52 | 53 | with open(f"outputs/{exp_name}/intervention_{iteration}.json") as f: 54 | data = json.load(f) 55 | system = data["system"] 56 | query = data["query"] 57 | partial_response = data["response"] 58 | top_p = data["top_p"] 59 | top_k = data["top_k"] 60 | temperature = data["temperature"] 61 | start_from = data["start_from"] 62 | seed = data["seed"] 63 | sampler_type = data["sampler_type"] 64 | model_family = data["model_family"] 65 | ckpt_dir = data["ckpt_dir"] 66 | tokenizer_path = data["tokenizer_path"] 67 | 68 | sampler = Sampler(sampler_type=sampler_type, top_p=top_p, top_k=top_k) 69 | 70 | with open(f"outputs/{exp_name}/rngstates_{iteration}.pt", "rb") as f: 71 | rngstates = torch.load(f) 72 | init_rng_state = rngstates[start_from,:] # RNG state recorded just before the token at position start_from was sampled 73 | 74 | if model_family == "llama3": 75 | generator = Llama.build( 76 | ckpt_dir=ckpt_dir, 77 | tokenizer_path=tokenizer_path, 78 | max_seq_len=max_seq_len, 79 | max_batch_size=max_batch_size, 80 | seed=seed, 81 | init_rng_state=init_rng_state, 82 | prior=prior 83 | ) 84 | 85 | dialogs: List[Dialog] = [ 86 | [ {"role": "system", "content": system}, 87 | {"role": "user", "content": query}] 88 | ] 89 | 90 | results, rngstates = generator.chat_completion( 91 | dialogs, 92 | max_gen_len=max_gen_len, 93 | temperature=temperature, 94 | sampler=sampler, 95 | genstates=False, 96 | partial_response=partial_response, 97 | ) 98 | 99 | elif model_family == "mistral": 100 | 101 | # load the model 102 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(ckpt_dir)) 103 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer 104 | 105 | model_cls = get_model_cls(ckpt_dir) 106 | model = model_cls.from_folder(Path(ckpt_dir), max_batch_size=max_batch_size, num_pipeline_ranks=1) 107 | 108 | 
messages: List[SystemMessage | UserMessage | AssistantMessage] = [] 109 | messages += [SystemMessage(content=system)] 110 | messages += [UserMessage(content=query)] 111 | 112 | chat_completion_request = ChatCompletionRequest(messages=messages) 113 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request) 114 | tokens = tokenized.tokens 115 | 116 | # initialize a random number generator with the given seed 117 | rng = torch.Generator(device="cuda") 118 | rng.manual_seed(seed) 119 | 120 | if not prior: 121 | # turn init_rng_state into a torch.ByteTensor 122 | init_rng_state = torch.tensor(init_rng_state, device="cpu").to(torch.uint8) 123 | rng.set_state(init_rng_state) 124 | 125 | fixed_tokens = tokenizer.encode(partial_response, bos=False, eos=False) 126 | 127 | generated_tokens, _, _ = generate( # type: ignore[operator] 128 | encoded_prompts = [tokens], 129 | model = model, 130 | max_tokens = max_seq_len, 131 | temperature = temperature, 132 | eos_id = tokenizer.eos_id, 133 | sampler = sampler, 134 | genstates = False, 135 | rng = rng, 136 | fixed_tokens=fixed_tokens 137 | ) 138 | 139 | results = [ 140 | { 141 | "generation": { 142 | "role": "assistant", 143 | "content": tokenizer.decode(t), 144 | "token_list": [tokenizer.decode([x]) for x in t], 145 | }, 146 | } 147 | for t in generated_tokens 148 | ] 149 | 150 | 151 | output_dir = os.path.join("outputs", exp_name) 152 | 153 | output = {} 154 | output["system"] = system 155 | output["query"] = query 156 | output["response"] = results[0]["generation"]["content"] 157 | output["seed"] = seed 158 | output["temperature"] = temperature 159 | output["sampler_type"] = sampler_type 160 | output["top_p"] = top_p 161 | output["top_k"] = top_k 162 | output["model_family"] = model_family 163 | output["start_from"] = 0 164 | output["token_list"] = {ind: tok for ind, tok in enumerate(results[0]["generation"]["token_list"])} 165 | 166 | with open(os.path.join(output_dir, f"counterfactual_{iteration}.json"), "w") as f: 167 | json.dump(output, f, indent=4) 168 | 169 | with open(os.path.join(output_dir, f"intervention_{iteration+1}.json"), "w") as f: 170 | json.dump(output, f, indent=4) 171 | 172 | print("----------------") 173 | print("SYSTEM:", system) 174 | print("----------------") 175 | print("QUERY:", query) 176 | print("----------------") 177 | print("RESPONSE:", results[0]["generation"]["content"]) 178 | print("----------------") 179 | 180 | 181 | if __name__ == "__main__": 182 | fire.Fire(main) 183 | 184 | -------------------------------------------------------------------------------- /src/llama3/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /src/llama3/USE_POLICY.md: -------------------------------------------------------------------------------- 1 | # Meta Llama 3 Acceptable Use Policy 2 | 3 | Meta is committed to promoting safe and fair use of its tools and features, including Llama 3. If you access or use Llama 3, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy). 4 | 5 | ## Prohibited Uses 6 | We want everyone to use Llama 3 safely and responsibly. You agree you will not use, or allow others to use, Llama 3 to: 7 | 8 | 1. Violate the law or others’ rights, including to: 9 | 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: 10 | 1. Violence or terrorism 11 | 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material 12 | 3. Human trafficking, exploitation, and sexual violence 13 | 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. 14 | 5. Sexual solicitation 15 | 6. Any other criminal activity 16 | 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals 17 | 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services 18 | 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices 19 | 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws 20 | 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 3 Materials 21 | 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system 22 | 23 | 24 | 25 | 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 3 related to the following: 26 | 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State 27 | 2. 
Guns and illegal weapons (including weapon development) 28 | 3. Illegal drugs and regulated/controlled substances 29 | 4. Operation of critical infrastructure, transportation technologies, or heavy machinery 30 | 5. Self-harm or harm to others, including suicide, cutting, and eating disorders 31 | 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual 32 | 33 | 34 | 35 | 3. Intentionally deceive or mislead others, including use of Llama 3 related to the following: 36 | 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation 37 | 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content 38 | 3. Generating, promoting, or further distributing spam 39 | 4. Impersonating another individual without consent, authorization, or legal right 40 | 5. Representing that the use of Llama 3 or outputs are human-generated 41 | 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement 42 | 4. Fail to appropriately disclose to end users any known dangers of your AI system 43 | 44 | Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means: 45 | 46 | * Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) 47 | * Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) 48 | * Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) 49 | * Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) 50 | 51 | -------------------------------------------------------------------------------- /src/llama3/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 3 | 4 | from .generation import Llama 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Dialog, Tokenizer 7 | -------------------------------------------------------------------------------- /src/llama3/llama/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
3 | 4 | import os 5 | from unittest import TestCase 6 | from llama.tokenizer import ChatFormat, Tokenizer 7 | 8 | # Usage: TOKENIZER_PATH=<path to tokenizer.model> python -m unittest llama/test_tokenizer.py 9 | 10 | class TokenizerTests(TestCase): 11 | def setUp(self): 12 | self.tokenizer = Tokenizer(os.environ["TOKENIZER_PATH"]) 13 | self.format = ChatFormat(self.tokenizer) 14 | 15 | def test_special_tokens(self): 16 | self.assertEqual( 17 | self.tokenizer.special_tokens["<|begin_of_text|>"], 18 | 128000, 19 | ) 20 | 21 | def test_encode(self): 22 | self.assertEqual( 23 | self.tokenizer.encode( 24 | "This is a test sentence.", 25 | bos=True, 26 | eos=True 27 | ), 28 | [128000, 2028, 374, 264, 1296, 11914, 13, 128001], 29 | ) 30 | 31 | def test_decode(self): 32 | self.assertEqual( 33 | self.tokenizer.decode( 34 | [128000, 2028, 374, 264, 1296, 11914, 13, 128001], 35 | ), 36 | "<|begin_of_text|>This is a test sentence.<|end_of_text|>", 37 | ) 38 | 39 | def test_encode_message(self): 40 | message = { 41 | "role": "user", 42 | "content": "This is a test sentence.", 43 | } 44 | self.assertEqual( 45 | self.format.encode_message(message), 46 | [ 47 | 128006, # <|start_header_id|> 48 | 882, # "user" 49 | 128007, # <|end_header_id|> 50 | 271, # "\n\n" 51 | 2028, 374, 264, 1296, 11914, 13, # This is a test sentence. 52 | 128009, # <|eot_id|> 53 | ] 54 | ) 55 | 56 | def test_encode_dialog(self): 57 | dialog = [ 58 | { 59 | "role": "system", 60 | "content": "This is a test sentence.", 61 | }, 62 | { 63 | "role": "user", 64 | "content": "This is a response.", 65 | } 66 | ] 67 | self.assertEqual( 68 | self.format.encode_dialog_prompt(dialog), 69 | [ 70 | 128000, # <|begin_of_text|> 71 | 128006, # <|start_header_id|> 72 | 9125, # "system" 73 | 128007, # <|end_header_id|> 74 | 271, # "\n\n" 75 | 2028, 374, 264, 1296, 11914, 13, # "This is a test sentence." 76 | 128009, # <|eot_id|> 77 | 128006, # <|start_header_id|> 78 | 882, # "user" 79 | 128007, # <|end_header_id|> 80 | 271, # "\n\n" 81 | 2028, 374, 264, 2077, 13, # "This is a response.", 82 | 128009, # <|eot_id|> 83 | 128006, # <|start_header_id|> 84 | 78191, # "assistant" 85 | 128007, # <|end_header_id|> 86 | 271, # "\n\n" 87 | ] 88 | ) 89 | -------------------------------------------------------------------------------- /src/llama3/pretrained/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/src/llama3/pretrained/.placeholder -------------------------------------------------------------------------------- /src/llama3/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | fairscale 3 | fire 4 | tiktoken==0.4.0 5 | blobfile 6 | -------------------------------------------------------------------------------- /src/llama3/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def get_requirements(path: str): 8 | return [l.strip() for l in open(path)] 9 | 10 | 11 | setup( 12 | name="llama3", 13 | version="0.0.1", 14 | packages=find_packages(), 15 | install_requires=get_requirements("requirements.txt"), 16 | ) 17 | -------------------------------------------------------------------------------- /src/mistral-inference/assets/smoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/src/mistral-inference/assets/smoe.png -------------------------------------------------------------------------------- /src/mistral-inference/deploy/.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !entrypoint.sh 3 | -------------------------------------------------------------------------------- /src/mistral-inference/deploy/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=amd64 nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 as base 2 | 3 | WORKDIR /workspace 4 | 5 | RUN apt update && \ 6 | apt install -y python3-pip python3-packaging \ 7 | git ninja-build && \ 8 | pip3 install -U pip 9 | 10 | # Tweak this list to reduce build time 11 | # https://developer.nvidia.com/cuda-gpus 12 | ENV TORCH_CUDA_ARCH_LIST "7.0;7.2;7.5;8.0;8.6;8.9;9.0" 13 | 14 | RUN pip3 install "torch==2.1.1" 15 | 16 | # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed. 17 | RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git" 18 | RUN pip3 install "git+https://github.com/vllm-project/vllm.git" 19 | RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34" 20 | 21 | RUN git clone https://github.com/NVIDIA/apex && \ 22 | cd apex && git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 && \ 23 | sed -i '/check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)/d' setup.py && \ 24 | python3 setup.py install --cpp_ext --cuda_ext 25 | 26 | 27 | COPY entrypoint.sh . 28 | 29 | RUN chmod +x /workspace/entrypoint.sh 30 | 31 | ENTRYPOINT ["/workspace/entrypoint.sh"] -------------------------------------------------------------------------------- /src/mistral-inference/deploy/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! -z "${HF_TOKEN}" ]]; then 4 | echo "The HF_TOKEN environment variable is set, logging to Hugging Face." 5 | python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" 6 | else 7 | echo "The HF_TOKEN environment variable is not set or empty, not logging to Hugging Face." 
8 | fi 9 | 10 | # Run the provided command 11 | exec python3 -u -m vllm.entrypoints.openai.api_server "$@" 12 | -------------------------------------------------------------------------------- /src/mistral-inference/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mistral_inference" 3 | version = "1.5.0" 4 | description = "" 5 | authors = ["bam4d "] 6 | readme = "README.md" 7 | packages = [{ include = "mistral_inference", from = "src" }] 8 | 9 | [tool.ruff] 10 | lint.select = ["E", "F", "W", "Q", "I"] 11 | lint.ignore = ["E203"] 12 | lint.fixable = ["ALL"] 13 | lint.unfixable = [] 14 | line-length = 120 15 | exclude = ["docs", "build", "tutorials"] 16 | 17 | [tool.mypy] 18 | disallow_untyped_defs = true 19 | show_error_codes = true 20 | no_implicit_optional = true 21 | warn_return_any = true 22 | warn_unused_ignores = true 23 | exclude = ["docs", "tools", "build"] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.9.10" 27 | xformers = ">=0.0.24" 28 | simple-parsing = ">=0.1.5" 29 | fire = ">=0.6.0" 30 | mistral_common = ">=1.4.0" 31 | safetensors = ">=0.4.0" 32 | pillow = ">=10.3.0" 33 | 34 | [tool.poetry.group.dev.dependencies] 35 | types-protobuf = "4.24.0.20240129" 36 | mypy-protobuf = "^3.5.0" 37 | pytest = "7.4.4" 38 | ruff = "^0.2.2" 39 | mypy = "^1.8.0" 40 | 41 | [build-system] 42 | requires = ["poetry-core>=1.0.0"] 43 | build-backend = "poetry.core.masonry.api" 44 | 45 | [tool.pytest.ini_options] 46 | testpaths = ["./tests"] 47 | 48 | [tool.poetry.scripts] 49 | mistral-chat = "mistral_inference.main:mistral_chat" 50 | mistral-demo = "mistral_inference.main:mistral_demo" 51 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.5.0" 2 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | 4 | from simple_parsing.helpers import Serializable 5 | 6 | from mistral_inference.lora import LoraArgs 7 | from mistral_inference.moe import MoeArgs 8 | 9 | 10 | @dataclass 11 | class VisionEncoderArgs: 12 | hidden_size: int 13 | num_channels: int 14 | image_size: int 15 | patch_size: int 16 | intermediate_size: int 17 | num_hidden_layers: int 18 | num_attention_heads: int 19 | rope_theta: float = 1e4 # for rope-2D 20 | image_token_id: int = 10 21 | 22 | 23 | @dataclass 24 | class TransformerArgs(Serializable): 25 | dim: int 26 | n_layers: int 27 | head_dim: int 28 | hidden_dim: int 29 | n_heads: int 30 | n_kv_heads: int 31 | norm_eps: float 32 | vocab_size: int 33 | 34 | max_batch_size: int = 0 35 | 36 | # For rotary embeddings. If not set, will be inferred 37 | rope_theta: Optional[float] = None 38 | # If this is set, we will use MoE layers instead of dense layers. 39 | moe: Optional[MoeArgs] = None 40 | # If this is set, we will load LoRA linear layers instead of linear layers. 
41 | lora: Optional[LoraArgs] = None 42 | sliding_window: Optional[int] | Optional[List[int]] = None 43 | _sliding_window: Optional[int] | Optional[List[int]] = None 44 | model_type: str = "transformer" 45 | 46 | vision_encoder: Optional[VisionEncoderArgs] = None 47 | 48 | def __post_init__(self) -> None: 49 | assert self.model_type == "transformer", self.model_type 50 | assert self.sliding_window is None or self._sliding_window is None 51 | 52 | # hack for now so that vLLM is supported correctly 53 | self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window 54 | 55 | 56 | @dataclass 57 | class MambaArgs(Serializable): 58 | dim: int 59 | n_layers: int 60 | vocab_size: int 61 | n_groups: int 62 | rms_norm: bool 63 | residual_in_fp32: bool 64 | fused_add_norm: bool 65 | pad_vocab_size_multiple: int 66 | tie_embeddings: bool 67 | model_type: str = "mamba" 68 | 69 | def __post_init__(self) -> None: 70 | assert self.model_type == "mamba", self.model_type 71 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/generate.py: -------------------------------------------------------------------------------- 1 | # DISCLAIMER: The original code was edited by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez. 2 | # If you use this version of the code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors. 3 | 4 | from typing import List, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from mistral_inference.cache import BufferCache 10 | from mistral_inference.mamba import Mamba 11 | from mistral_inference.transformer import Transformer 12 | 13 | 14 | @torch.inference_mode() 15 | def generate_mamba( 16 | encoded_prompts: List[List[int]], 17 | model: Mamba, 18 | *, 19 | max_tokens: int, 20 | temperature: float, 21 | chunk_size: Optional[int] = None, 22 | eos_id: Optional[int] = None, 23 | ) -> Tuple[List[List[int]], List[List[float]]]: 24 | input_ids = torch.tensor(encoded_prompts, device=model.device) 25 | output = model.model.generate( 26 | input_ids=input_ids, 27 | max_length=input_ids.shape[-1] + max_tokens, 28 | cg=True, 29 | return_dict_in_generate=True, 30 | output_scores=True, 31 | enable_timing=False, 32 | eos_token_id=eos_id, 33 | temperature=temperature, 34 | top_p=0.8, 35 | ) 36 | generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist() 37 | 38 | _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))] 39 | for seq_idx, batch_score in enumerate(output.scores): 40 | for batch_idx, score in enumerate(batch_score.tolist()): 41 | _logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]]) 42 | 43 | return generated_tokens, _logprobs 44 | 45 | 46 | @torch.inference_mode() 47 | def generate( 48 | encoded_prompts: List[List[int]], 49 | model: Transformer, 50 | images: List[List[np.ndarray]] = [], 51 | *, 52 | max_tokens: int, 53 | temperature: float, 54 | chunk_size: Optional[int] = None, 55 | eos_id: Optional[int] = None, 56 | genstates: bool = True, # if True, record the RNG state right before each token is sampled 57 | sampler, # Sampler instance (see src/sampler.py) 58 | rng: torch.Generator, 59 | fixed_tokens: Optional[List[int]] = None, # tokens forced verbatim before free generation resumes 60 | auto_intervention: bool = False 61 | ) -> Tuple[List[List[int]], List[List[float]], Optional[List[List[int]]]]: 62 | images_torch: List[List[torch.Tensor]] = [] 63 | if images: 64 | assert chunk_size is None 65 | images_torch = [ 66 | [torch.tensor(im, device=model.device, dtype=model.dtype) for im in 
images_for_sample] 67 | for images_for_sample in images 68 | ] 69 | 70 | model = model.eval() 71 | B, V = len(encoded_prompts), model.args.vocab_size 72 | 73 | seqlens = [len(x) for x in encoded_prompts] 74 | 75 | # Cache 76 | cache_window = max(seqlens) + max_tokens 77 | cache = BufferCache( 78 | model.n_local_layers, 79 | model.args.max_batch_size, 80 | cache_window, 81 | model.args.n_kv_heads, 82 | model.args.head_dim, 83 | model.args.sliding_window, 84 | ) 85 | cache.to(device=model.device, dtype=model.dtype) 86 | cache.reset() 87 | 88 | # Bookkeeping 89 | logprobs: List[List[float]] = [[] for _ in range(B)] 90 | last_token_prelogits = None 91 | 92 | # One chunk if size not specified 93 | max_prompt_len = max(seqlens) 94 | if chunk_size is None: 95 | chunk_size = max_prompt_len 96 | 97 | flattened_images: List[torch.Tensor] = sum(images_torch, []) 98 | 99 | # Encode prompt by chunks 100 | for s in range(0, max_prompt_len, chunk_size): 101 | prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts] 102 | assert all(len(p) > 0 for p in prompt_chunks) 103 | prelogits = model.forward( 104 | torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long), 105 | images=flattened_images, 106 | seqlens=[len(p) for p in prompt_chunks], 107 | cache=cache, 108 | ) 109 | logits = torch.log_softmax(prelogits, dim=-1) 110 | 111 | if last_token_prelogits is not None: 112 | # Pass > 1 113 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 114 | for i_seq in range(B): 115 | logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item()) 116 | 117 | offset = 0 118 | for i_seq, sequence in enumerate(prompt_chunks): 119 | logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)]) 120 | offset += len(sequence) 121 | 122 | last_token_prelogits = prelogits.index_select( 123 | 0, 124 | torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1, 125 | ) 126 | assert last_token_prelogits.shape == (B, V) 127 | 128 | # decode 129 | generated_tensors = [] 130 | is_finished = torch.tensor([False for _ in range(B)]) 131 | if genstates: 132 | token_genstates = torch.zeros((max_tokens, rng.get_state().numel()), dtype=torch.uint8) 133 | assert last_token_prelogits is not None 134 | 135 | if fixed_tokens is not None: 136 | fixed_counter = 0 137 | 138 | for gen_state_counter in range(max_tokens): 139 | 140 | if fixed_tokens is not None and fixed_counter < len(fixed_tokens): 141 | if auto_intervention and fixed_counter == len(fixed_tokens)-1: 142 | if temperature > 0: 143 | probs = torch.softmax(last_token_prelogits / temperature, dim=-1) 144 | next_token = sampler.intervention(probs, torch.tensor(fixed_tokens[fixed_counter], device=model.device)) 145 | next_token = next_token.reshape(-1) 146 | else: 147 | next_token = torch.topk(last_token_prelogits,k=2,dim=-1).indices[0][1] 148 | next_token = next_token.unsqueeze(0) 149 | fixed_counter += 1 150 | else: 151 | next_token = torch.tensor(fixed_tokens[fixed_counter], device=model.device) 152 | next_token = next_token.unsqueeze(0) 153 | fixed_counter += 1 154 | else: 155 | ########################################################################################## 156 | # NOTE: the sampling takes place here 157 | if temperature > 0: 158 | probs = torch.softmax(last_token_prelogits / temperature, dim=-1) 159 | # get the state of the random number generator 160 | if genstates: 161 | token_genstates[gen_state_counter] = rng.get_state() 162 | next_token = 
sampler.sample(probs, rng) 163 | next_token = next_token.reshape(-1) 164 | else: 165 | next_token = torch.argmax(last_token_prelogits, dim=-1) 166 | ########################################################################################## 167 | 168 | if eos_id is not None: 169 | is_finished = is_finished | (next_token == eos_id).cpu() 170 | 171 | if is_finished.all(): 172 | break 173 | 174 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 175 | for i in range(B): 176 | logprobs[i].append(last_token_logits[i, next_token[i]].item()) 177 | 178 | generated_tensors.append(next_token[:, None]) 179 | last_token_prelogits = model.forward(next_token, seqlens=[1] * B, cache=cache) 180 | assert last_token_prelogits.shape == (B, V) 181 | 182 | generated_tokens: List[List[int]] 183 | if generated_tensors: 184 | generated_tokens = torch.cat(generated_tensors, 1).tolist() 185 | else: 186 | generated_tokens = [] 187 | 188 | if genstates: 189 | token_genstates = token_genstates.tolist() 190 | 191 | return (generated_tokens, logprobs, token_genstates if genstates else None) 192 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/lora.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Any, Dict, NamedTuple, Union 5 | 6 | import safetensors.torch 7 | import torch 8 | import torch.nn as nn 9 | from simple_parsing.helpers import Serializable 10 | 11 | 12 | @dataclass 13 | class LoraArgs(Serializable): 14 | rank: int 15 | scaling: float 16 | 17 | def __post_init__(self) -> None: 18 | assert self.rank > 0 19 | assert self.scaling > 0.0 20 | 21 | 22 | class LoRALinear(nn.Module): 23 | """ 24 | Implementation of: 25 | - LoRA: https://arxiv.org/abs/2106.09685 26 | 27 | Notes: 28 | - Freezing is handled at network level, not layer level. 29 | - Scaling factor controls relative importance of LoRA skip 30 | connection versus original frozen weight. General guidance is 31 | to keep it to 2.0 and sweep over learning rate when changing 32 | the rank. 
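    - For intuition, the effective (merged) weight of the layer is
      W' = W + scaling * (B @ A),
      which is exactly the merge performed in `_load_lora_state_dict` below.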
33 | """ 34 | 35 | def __init__( 36 | self, 37 | in_features: int, 38 | out_features: int, 39 | rank: int, 40 | scaling: float, 41 | bias: bool = False, 42 | ): 43 | super().__init__() 44 | 45 | self.in_features = in_features 46 | self.out_features = out_features 47 | assert not bias 48 | self.bias = bias 49 | self.rank = rank 50 | self.scaling = scaling 51 | 52 | self.lora_A = nn.Linear( 53 | self.in_features, 54 | self.rank, 55 | bias=self.bias, 56 | ) 57 | self.lora_B = nn.Linear( 58 | self.rank, 59 | self.out_features, 60 | bias=self.bias, 61 | ) 62 | 63 | self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias) 64 | 65 | # make sure no LoRA weights are marked as "missing" in load_state_dict 66 | def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None: 67 | incompatible_keys.missing_keys[:] = [] # type: ignore 68 | 69 | self.register_load_state_dict_post_hook(ignore_missing_keys) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | lora = self.lora_B(self.lora_A(x)) 73 | result: torch.Tensor = self.linear(x) + lora * self.scaling 74 | return result 75 | 76 | def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def] 77 | key_name = prefix + "weight" 78 | 79 | # full checkpoint 80 | if key_name in state_dict: 81 | w_ref = state_dict[key_name] 82 | 83 | # load frozen weights 84 | state_dict = { 85 | "linear.weight": w_ref, 86 | "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype), 87 | "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype), 88 | } 89 | self.load_state_dict(state_dict, assign=True, strict=True) 90 | 91 | 92 | class LoRALoaderMixin: 93 | def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None: 94 | """Loads LoRA checkpoint""" 95 | 96 | lora_path = Path(lora_path) 97 | assert lora_path.is_file(), f"{lora_path} does not exist or is not a file" 98 | 99 | state_dict = safetensors.torch.load_file(lora_path) 100 | 101 | self._load_lora_state_dict(state_dict, scaling=scaling) 102 | 103 | def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None: 104 | """Loads LoRA state_dict""" 105 | lora_dtypes = set([p.dtype for p in lora_state_dict.values()]) 106 | assert ( 107 | len(lora_dtypes) == 1 108 | ), f"LoRA weights have multiple different dtypes {lora_dtypes}. 
All weights need to have the same dtype" 109 | lora_dtype = lora_dtypes.pop() 110 | assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" # type: ignore[attr-defined] 111 | assert all("lora" in key for key in lora_state_dict.keys()) 112 | 113 | # move tensors to device 114 | lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} # type: ignore[attr-defined] 115 | 116 | state_dict = self.state_dict() # type: ignore[attr-defined] 117 | 118 | if self.args.lora is None: # type: ignore[attr-defined] 119 | logging.info("Loading and merging LoRA weights...") 120 | 121 | # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer 122 | named_modules = dict(self.named_modules()) # type: ignore[attr-defined] 123 | for name, module in named_modules.items(): 124 | if isinstance(module, nn.Linear) and name != "output": 125 | layer_id = name.split(".")[1] 126 | if layer_id not in self.layers: # type: ignore[attr-defined] 127 | logging.debug( 128 | "Skipping parameter %s at pipeline rank %d", 129 | name, 130 | self.pipeline_rank, # type: ignore[attr-defined] 131 | ) 132 | elif (name + ".lora_B.weight") in lora_state_dict: 133 | weight = ( 134 | module.weight 135 | + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"]) 136 | * scaling 137 | ) 138 | 139 | state_dict[name + ".weight"] = weight 140 | else: 141 | logging.info("Loading LoRA weights...") 142 | for k, v in lora_state_dict.items(): 143 | # the per-key check below decides which LoRA entries belong to this pipeline rank 144 | 145 | layer_id = k.split(".")[1] 146 | if layer_id in self.layers: # type: ignore[attr-defined] 147 | state_dict[k] = v 148 | else: 149 | logging.debug( 150 | "Skipping parameter %s at pipeline rank %d", 151 | k, 152 | self.pipeline_rank, # type: ignore[attr-defined] 153 | ) 154 | 155 | self.load_state_dict(state_dict, strict=True) # type: ignore[attr-defined] 156 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/mamba.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Optional, Union 4 | 5 | import safetensors 6 | import torch 7 | import torch.nn as nn 8 | 9 | from mistral_inference.args import MambaArgs 10 | from mistral_inference.cache import BufferCache 11 | from mistral_inference.model import ModelBase 12 | 13 | _is_mamba_installed = False 14 | try: 15 | from mamba_ssm.models.config_mamba import MambaConfig 16 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 17 | 18 | _is_mamba_installed = True 19 | except ImportError: 20 | _is_mamba_installed = False 21 | 22 | 23 | class Mamba(ModelBase, nn.Module): 24 | def __init__(self, args: MambaArgs): 25 | super().__init__() 26 | self.args = args 27 | assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`." 
28 | 29 | # make sure naming is consistent with `mamba_ssm` 30 | config = MambaConfig( 31 | d_model=args.dim, 32 | n_layer=args.n_layers, 33 | vocab_size=args.vocab_size, 34 | ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"}, 35 | attn_layer_idx=[], 36 | attn_cfg={}, 37 | rms_norm=args.rms_norm, 38 | residual_in_fp32=args.residual_in_fp32, 39 | fused_add_norm=args.fused_add_norm, 40 | pad_vocab_size_multiple=args.pad_vocab_size_multiple, 41 | tie_embeddings=args.tie_embeddings, 42 | ) 43 | self.model = MambaLMHeadModel(config) 44 | 45 | @property 46 | def dtype(self) -> torch.dtype: 47 | return next(self.parameters()).dtype 48 | 49 | @property 50 | def device(self) -> torch.device: 51 | return next(self.parameters()).device 52 | 53 | def forward( 54 | self, 55 | input_ids: torch.Tensor, 56 | seqlens: List[int], # not supported for now 57 | cache: Optional[BufferCache] = None, # not supported for now 58 | ) -> torch.Tensor: 59 | lm_output = self.model(input_ids) 60 | result: torch.Tensor = lm_output.logits 61 | return result 62 | 63 | @staticmethod 64 | def from_folder( 65 | folder: Union[Path, str], 66 | max_batch_size: int = 1, 67 | num_pipeline_ranks: int = 1, 68 | device: Union[torch.device, str] = "cuda", 69 | dtype: Optional[torch.dtype] = None, 70 | ) -> "Mamba": 71 | with open(Path(folder) / "params.json", "r") as f: 72 | model_args = MambaArgs.from_dict(json.load(f)) 73 | 74 | with torch.device("meta"): 75 | model = Mamba(model_args) 76 | 77 | model_file = Path(folder) / "consolidated.safetensors" 78 | 79 | assert model_file.exists(), f"Make sure {model_file} exists." 80 | loaded = safetensors.torch.load_file(str(model_file)) 81 | 82 | model.load_state_dict(loaded, assign=True, strict=True) 83 | return model.to(device=device, dtype=dtype) 84 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from mistral_inference.cache import BufferCache 9 | 10 | 11 | class ModelBase(nn.Module, ABC): 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | @property 16 | @abstractmethod 17 | def dtype(self) -> torch.dtype: 18 | pass 19 | 20 | @property 21 | @abstractmethod 22 | def device(self) -> torch.device: 23 | pass 24 | 25 | @abstractmethod 26 | def forward( 27 | self, 28 | input_ids: torch.Tensor, 29 | seqlens: List[int], # not supported for now 30 | cache: Optional[BufferCache] = None, # not supported for now 31 | ) -> torch.Tensor: 32 | pass 33 | 34 | @staticmethod 35 | @abstractmethod 36 | def from_folder( 37 | folder: Union[Path, str], 38 | max_batch_size: int = 1, 39 | num_pipeline_ranks: int = 1, 40 | device: Union[torch.device, str] = "cuda", 41 | dtype: Optional[torch.dtype] = None, 42 | ) -> "ModelBase": 43 | pass 44 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/moe.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from simple_parsing.helpers import Serializable 7 | from torch import nn 8 | 9 | 10 | @dataclasses.dataclass 11 | class MoeArgs(Serializable): 12 | num_experts: int 13 | num_experts_per_tok: int 14 | 
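# A minimal sketch of the routing implemented by MoeLayer.forward below, assuming
# num_experts_per_tok = 2: for every token, the gate scores all experts, only the
# top-2 experts are evaluated, and their outputs are combined with gate weights
# renormalized by a softmax over the selected experts:
#
#   gate_logits = gate(x)                           # (num_tokens, num_experts)
#   weights, selected = torch.topk(gate_logits, 2)  # (num_tokens, 2) each
#   weights = softmax(weights, dim=1)               # renormalize over the top-2
#   y[t] = sum_j weights[t, j] * experts[selected[t, j]](x[t])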
15 | 16 | class MoeLayer(nn.Module): 17 | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): 18 | super().__init__() 19 | assert len(experts) > 0 20 | self.experts = nn.ModuleList(experts) 21 | self.gate = gate 22 | self.args = moe_args 23 | 24 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 25 | gate_logits = self.gate(inputs) 26 | weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) 27 | weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) 28 | results = torch.zeros_like(inputs) 29 | for i, expert in enumerate(self.experts): 30 | batch_idx, nth_expert = torch.where(selected_experts == i) 31 | results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) 32 | return results 33 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/rope.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: 7 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 8 | t = torch.arange(end, device=freqs.device) 9 | freqs = torch.outer(t, freqs).float() 10 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 11 | 12 | 13 | def apply_rotary_emb( 14 | xq: torch.Tensor, 15 | xk: torch.Tensor, 16 | freqs_cis: torch.Tensor, 17 | ) -> Tuple[torch.Tensor, torch.Tensor]: 18 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 19 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 20 | freqs_cis = freqs_cis[:, None, :] 21 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) 22 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) 23 | return xq_out.type_as(xq), xk_out.type_as(xk) 24 | 25 | 26 | def precompute_freqs_cis_2d( 27 | dim: int, 28 | height: int, 29 | width: int, 30 | theta: float, 31 | ) -> torch.Tensor: 32 | """ 33 | freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by 34 | (height, width) position tuples 35 | """ 36 | # (dim / 2) frequency bases 37 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 38 | 39 | h = torch.arange(height, device=freqs.device) 40 | w = torch.arange(width, device=freqs.device) 41 | 42 | freqs_h = torch.outer(h, freqs[::2]).float() 43 | freqs_w = torch.outer(w, freqs[1::2]).float() 44 | freqs_2d = torch.cat( 45 | [ 46 | freqs_h[:, None, :].repeat(1, width, 1), 47 | freqs_w[None, :, :].repeat(height, 1, 1), 48 | ], 49 | dim=-1, 50 | ) 51 | return torch.polar(torch.ones_like(freqs_2d), freqs_2d) 52 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/transformer_layers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, Tuple, Type, Union 3 | 4 | import torch 5 | from torch import nn 6 | from xformers.ops.fmha import memory_efficient_attention # type: ignore 7 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask 8 | 9 | from mistral_inference.args import LoraArgs 10 | from mistral_inference.cache import CacheView 11 | from mistral_inference.lora import LoRALinear 12 | from mistral_inference.moe import MoeArgs, MoeLayer 13 | from mistral_inference.rope import apply_rotary_emb 14 | 15 | 16 | def repeat_kv(keys: torch.Tensor, 
values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: 17 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) 18 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim) 19 | return keys, values 20 | 21 | 22 | def maybe_lora( 23 | lora_args: Optional[LoraArgs], 24 | ) -> Union[Type[nn.Linear], partial[LoRALinear]]: 25 | if lora_args is None: 26 | return nn.Linear 27 | else: 28 | return partial(LoRALinear, rank=lora_args.rank, scaling=lora_args.scaling) 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | n_heads: int, 36 | head_dim: int, 37 | n_kv_heads: int, 38 | lora: Optional[LoraArgs] = None, 39 | ): 40 | super().__init__() 41 | 42 | self.n_heads: int = n_heads 43 | self.head_dim: int = head_dim 44 | self.n_kv_heads: int = n_kv_heads 45 | 46 | self.repeats = self.n_heads // self.n_kv_heads 47 | 48 | self.scale = self.head_dim**-0.5 49 | 50 | MaybeLora = maybe_lora(lora) 51 | self.wq = MaybeLora(dim, n_heads * head_dim, bias=False) 52 | self.wk = MaybeLora(dim, n_kv_heads * head_dim, bias=False) 53 | self.wv = MaybeLora(dim, n_kv_heads * head_dim, bias=False) 54 | self.wo = MaybeLora(n_heads * head_dim, dim, bias=False) 55 | 56 | def forward( 57 | self, 58 | x: torch.Tensor, 59 | freqs_cis: torch.Tensor, 60 | cache: Optional[CacheView] = None, 61 | mask: Optional[BlockDiagonalMask] = None, 62 | ) -> torch.Tensor: 63 | assert mask is None or cache is None 64 | seqlen_sum, _ = x.shape 65 | 66 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 67 | xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) 68 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) 69 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) 70 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 71 | 72 | if cache is None: 73 | key, val = xk, xv 74 | elif cache.prefill: 75 | key, val = cache.interleave_kv(xk, xv) 76 | cache.update(xk, xv) 77 | else: 78 | cache.update(xk, xv) 79 | key, val = cache.key, cache.value 80 | key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) 81 | val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) 82 | 83 | # Repeat keys and values to match number of query heads 84 | key, val = repeat_kv(key, val, self.repeats, dim=1) 85 | 86 | # xformers requires (B=1, S, H, D) 87 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 
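        # The explicit mask is only used on the cache-less path (e.g. the vision
        # encoder); when a cache is given, the cache supplies its own attention mask.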
88 | output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask) 89 | output = output.view(seqlen_sum, self.n_heads * self.head_dim) 90 | 91 | assert isinstance(output, torch.Tensor) 92 | 93 | return self.wo(output) # type: ignore 94 | 95 | 96 | class FeedForward(nn.Module): 97 | def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs] = None): 98 | super().__init__() 99 | 100 | MaybeLora = maybe_lora(lora) 101 | self.w1 = MaybeLora(dim, hidden_dim, bias=False) 102 | self.w2 = MaybeLora(hidden_dim, dim, bias=False) 103 | self.w3 = MaybeLora(dim, hidden_dim, bias=False) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore 107 | 108 | 109 | class RMSNorm(torch.nn.Module): 110 | def __init__(self, dim: int, eps: float = 1e-6): 111 | super().__init__() 112 | self.eps = eps 113 | self.weight = nn.Parameter(torch.ones(dim)) 114 | 115 | def _norm(self, x: torch.Tensor) -> torch.Tensor: 116 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 117 | 118 | def forward(self, x: torch.Tensor) -> torch.Tensor: 119 | output = self._norm(x.float()).type_as(x) 120 | return output * self.weight 121 | 122 | 123 | class TransformerBlock(nn.Module): 124 | def __init__( 125 | self, 126 | dim: int, 127 | hidden_dim: int, 128 | n_heads: int, 129 | n_kv_heads: int, 130 | head_dim: int, 131 | norm_eps: float, 132 | lora: Optional[LoraArgs] = None, 133 | moe: Optional[MoeArgs] = None, 134 | ): 135 | super().__init__() 136 | self.n_heads = n_heads 137 | self.dim = dim 138 | self.attention = Attention( 139 | dim=dim, 140 | n_heads=n_heads, 141 | head_dim=head_dim, 142 | n_kv_heads=n_kv_heads, 143 | lora=lora, 144 | ) 145 | self.attention_norm = RMSNorm(dim, eps=norm_eps) 146 | self.ffn_norm = RMSNorm(dim, eps=norm_eps) 147 | 148 | self.feed_forward: nn.Module 149 | if moe is not None: 150 | self.feed_forward = MoeLayer( 151 | experts=[FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) for _ in range(moe.num_experts)], 152 | gate=nn.Linear(dim, moe.num_experts, bias=False), 153 | moe_args=moe, 154 | ) 155 | else: 156 | self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) 157 | 158 | def forward( 159 | self, 160 | x: torch.Tensor, 161 | freqs_cis: torch.Tensor, 162 | cache: Optional[CacheView] = None, 163 | mask: Optional[BlockDiagonalMask] = None, 164 | ) -> torch.Tensor: 165 | r = self.attention.forward(self.attention_norm(x), freqs_cis, cache, mask=mask) 166 | h = x + r 167 | r = self.feed_forward.forward(self.ffn_norm(h)) 168 | out = h + r 169 | return out 170 | -------------------------------------------------------------------------------- /src/mistral-inference/src/mistral_inference/vision_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask 6 | 7 | from mistral_inference.args import VisionEncoderArgs 8 | from mistral_inference.rope import precompute_freqs_cis_2d 9 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock 10 | 11 | 12 | def position_meshgrid( 13 | patch_embeds_list: list[torch.Tensor], 14 | ) -> torch.Tensor: 15 | positions = torch.cat( 16 | [ 17 | torch.stack( 18 | torch.meshgrid( 19 | torch.arange(p.shape[-2]), 20 | torch.arange(p.shape[-1]), 21 | indexing="ij", 22 | ), 23 | dim=-1, 24 | ).reshape(-1, 2) 25 | for p in 
patch_embeds_list 26 | ] 27 | ) 28 | return positions 29 | 30 | 31 | class VisionTransformer(nn.Module): 32 | def __init__(self, args: VisionEncoderArgs): 33 | super().__init__() 34 | self.args = args 35 | self.patch_conv = nn.Conv2d( 36 | in_channels=args.num_channels, 37 | out_channels=args.hidden_size, 38 | kernel_size=args.patch_size, 39 | stride=args.patch_size, 40 | bias=False, 41 | ) 42 | self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) 43 | self.transformer = VisionTransformerBlocks(args) 44 | 45 | head_dim = self.args.hidden_size // self.args.num_attention_heads 46 | assert head_dim % 2 == 0, "ROPE requires even head_dim" 47 | self._freqs_cis: Optional[torch.Tensor] = None 48 | 49 | @property 50 | def max_patches_per_side(self) -> int: 51 | return self.args.image_size // self.args.patch_size 52 | 53 | @property 54 | def device(self) -> torch.device: 55 | return next(self.parameters()).device 56 | 57 | @property 58 | def freqs_cis(self) -> torch.Tensor: 59 | if self._freqs_cis is None: 60 | self._freqs_cis = precompute_freqs_cis_2d( 61 | dim=self.args.hidden_size // self.args.num_attention_heads, 62 | height=self.max_patches_per_side, 63 | width=self.max_patches_per_side, 64 | theta=self.args.rope_theta, 65 | ) 66 | 67 | if self._freqs_cis.device != self.device: 68 | self._freqs_cis = self._freqs_cis.to(device=self.device) 69 | 70 | return self._freqs_cis 71 | 72 | def forward( 73 | self, 74 | images: List[torch.Tensor], 75 | ) -> torch.Tensor: 76 | """ 77 | Args: 78 | images: list of N_img images of variable sizes, each of shape (C, H, W) 79 | 80 | Returns: 81 | image_features: tensor of token features for all tokens of all images of 82 | shape (N_toks, D) 83 | """ 84 | # pass images through initial convolution independently 85 | patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images] 86 | 87 | # flatten to a single sequence 88 | patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0) 89 | patch_embeds = self.ln_pre(patch_embeds) 90 | 91 | # positional embeddings 92 | positions = position_meshgrid(patch_embeds_list).to(self.device) 93 | freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] 94 | 95 | # pass through Transformer with a block diagonal mask delimiting images 96 | mask = BlockDiagonalMask.from_seqlens( 97 | [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], 98 | ) 99 | out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) 100 | 101 | # remove batch dimension of the single sequence 102 | return out # type: ignore[no-any-return] 103 | 104 | 105 | class VisionLanguageAdapter(nn.Module): 106 | def __init__(self, in_dim: int, out_dim: int): 107 | super().__init__() 108 | self.w_in = nn.Linear( 109 | in_dim, 110 | out_dim, 111 | bias=True, 112 | ) 113 | self.gelu = nn.GELU() 114 | self.w_out = nn.Linear(out_dim, out_dim, bias=True) 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return] 118 | 119 | 120 | class VisionTransformerBlocks(nn.Module): 121 | def __init__(self, args: VisionEncoderArgs): 122 | super().__init__() 123 | self.layers = torch.nn.ModuleList() 124 | for _ in range(args.num_hidden_layers): 125 | self.layers.append( 126 | TransformerBlock( 127 | dim=args.hidden_size, 128 | hidden_dim=args.intermediate_size, 129 | n_heads=args.num_attention_heads, 130 | n_kv_heads=args.num_attention_heads, 131 | head_dim=args.hidden_size // args.num_attention_heads, 132 | norm_eps=1e-5, 133 | ) 134 | ) 
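        # Note: the vision blocks use standard multi-head attention
        # (n_kv_heads == n_heads), so no key/value repetition occurs here.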
135 | 136 | def forward( 137 | self, 138 | x: torch.Tensor, 139 | mask: BlockDiagonalMask, 140 | freqs_cis: Optional[torch.Tensor], 141 | ) -> torch.Tensor: 142 | for layer in self.layers: 143 | x = layer(x, mask=mask, freqs_cis=freqs_cis) 144 | return x 145 | 146 | 147 | -------------------------------------------------------------------------------- /src/mistral-inference/tests/test_generate.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from mistral_inference.args import VisionEncoderArgs 6 | from mistral_inference.generate import generate_mamba 7 | from mistral_inference.main import generate 8 | from mistral_inference.mamba import Mamba, MambaArgs 9 | from mistral_inference.transformer import Transformer, TransformerArgs 10 | 11 | 12 | class DebugTokenizer: 13 | @property 14 | def bos_id(self) -> int: 15 | return 0 16 | 17 | @property 18 | def eos_id(self) -> int: 19 | return 1 20 | 21 | @property 22 | def pad_id(self) -> int: 23 | return -1 24 | 25 | def encode(self, s: str, bos: bool = True) -> List[int]: 26 | assert isinstance(s, str) 27 | t = [int(x) for x in s.split()] 28 | if bos: 29 | t = [self.bos_id, *t] 30 | return t 31 | 32 | def decode(self, t: List[int]) -> str: 33 | return " ".join([str(x) for x in t]) 34 | 35 | 36 | def test_generation_transformer(): 37 | torch.manual_seed(42) 38 | 39 | sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"] 40 | args = TransformerArgs( 41 | dim=512, 42 | n_layers=1, 43 | head_dim=128, 44 | hidden_dim=2048, 45 | n_heads=4, 46 | n_kv_heads=2, 47 | norm_eps=1e-5, 48 | vocab_size=32_000, 49 | max_batch_size=len(sequences), 50 | ) 51 | model = Transformer(args).to("cuda", dtype=torch.float32) 52 | tokenizer = DebugTokenizer() 53 | 54 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 55 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=7) 56 | 57 | # concat generated and prompt 58 | encoded = [e + t for e, t in zip(encoded, toks)] 59 | 60 | generated, all_logprobs_new = generate( 61 | encoded, model, temperature=0.0, max_tokens=0 62 | ) 63 | 64 | assert generated == [] 65 | 66 | # Verify that logprobs are the same 67 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new) 68 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 69 | assert all( 70 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)] 71 | ), f"\n{lp_old}\n{lp_new}" 72 | 73 | print("All tests passed.") 74 | 75 | 76 | def test_generation_pixtral(): 77 | torch.manual_seed(42) 78 | gen = np.random.default_rng(seed=42) 79 | 80 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"] 81 | images = [[gen.normal(size=(3, 4, 4))], [], [gen.normal(size=(3, 4, 4))]] 82 | args = TransformerArgs( 83 | dim=512, 84 | n_layers=1, 85 | head_dim=128, 86 | hidden_dim=2048, 87 | n_heads=4, 88 | n_kv_heads=2, 89 | norm_eps=1e-5, 90 | vocab_size=32_000, 91 | max_batch_size=len(sequences), 92 | vision_encoder=VisionEncoderArgs( 93 | hidden_size=128, 94 | num_channels=3, 95 | image_size=4, 96 | patch_size=2, 97 | intermediate_size=256, 98 | num_hidden_layers=1, 99 | num_attention_heads=2, 100 | rope_theta=10000, 101 | image_token_id=2, 102 | ), 103 | ) 104 | model = Transformer(args).to("cuda", dtype=torch.float32) 105 | tokenizer = DebugTokenizer() 106 | 107 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 108 | toks, all_logprobs_old = generate( 109 | encoded, model, images=images, 
temperature=0.0, max_tokens=7 110 | ) 111 | 112 | # concat generated and prompt 113 | encoded = [e + t for e, t in zip(encoded, toks)] 114 | 115 | generated, all_logprobs_new = generate( 116 | encoded, model, images=images, temperature=0.0, max_tokens=0 117 | ) 118 | 119 | assert generated == [] 120 | 121 | # Verify that logprobs are the same 122 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new) 123 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 124 | assert all( 125 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)] 126 | ), f"\n{lp_old}\n{lp_new}" 127 | 128 | print("All tests passed.") 129 | 130 | 131 | def test_generation_mamba(): 132 | torch.manual_seed(42) 133 | 134 | sequences = ["1 2 3 4 5 6 7"] 135 | args = MambaArgs( 136 | dim=512, 137 | n_layers=1, 138 | n_groups=1, 139 | rms_norm=True, 140 | residual_in_fp32=True, 141 | fused_add_norm=True, 142 | pad_vocab_size_multiple=1, 143 | tie_embeddings=False, 144 | vocab_size=32768, 145 | ) 146 | model = Mamba(args).to("cuda", dtype=torch.float32) 147 | tokenizer = DebugTokenizer() 148 | 149 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 150 | toks, all_logprobs_old = generate_mamba( 151 | encoded, model, temperature=0.0, max_tokens=7 152 | ) 153 | 154 | assert len(toks[0]) == 7 155 | assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]] 156 | 157 | 158 | def test_chunks_transformer(): 159 | torch.manual_seed(42) 160 | 161 | sequences = [ 162 | " ".join([str(i) for i in range(7)]), 163 | " ".join([str(i) for i in range(9, 0, -1)]), 164 | ] 165 | args = TransformerArgs( 166 | dim=512, 167 | n_layers=1, 168 | head_dim=128, 169 | hidden_dim=2048, 170 | n_heads=4, 171 | n_kv_heads=2, 172 | norm_eps=1e-5, 173 | vocab_size=32_000, 174 | max_batch_size=3, 175 | ) 176 | model = Transformer(args).to("cuda", dtype=torch.float32) 177 | tokenizer = DebugTokenizer() 178 | 179 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 180 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=8) 181 | 182 | # concat generated and prompt 183 | encoded = [e + t for e, t in zip(encoded, toks)] 184 | 185 | generated, all_logprobs_new = generate( 186 | encoded, model, temperature=0.0, max_tokens=0, chunk_size=5 187 | ) 188 | assert len(generated) == 0 189 | 190 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 191 | assert all( 192 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)] 193 | ), f"\n{lp_old}\n{lp_new}" 194 | -------------------------------------------------------------------------------- /src/mistral-inference/tutorials/getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Getting Started with `mistral-inference`\n", 8 | "\n", 9 | "This notebook will guide you through the process of running Mistral models locally. We will cover the following: \n", 10 | "- How to chat with Mistral 7B Instruct\n", 11 | "- How to run Mistral 7B Instruct with function calling capabilities\n", 12 | "\n", 13 | "We recommend using a GPU such as the A100 to run this notebook. 
" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "id": "G6tXvIsQenpI" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install mistral-inference" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Download Mistral 7B Instruct" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "colab": { 39 | "background_save": true 40 | }, 41 | "id": "4ytmRt0WQeMW" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "id": "eRZg_8wvs5A6" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "!DIR=$HOME/mistral_7b_instruct_v3 && mkdir -p $DIR && tar -xf mistral-7B-Instruct-v0.3.tar -C $DIR" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "id": "7CN8gShDf65M" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "!ls mistral_7b_instruct_v3" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Chat with the model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import os \n", 84 | "\n", 85 | "from mistral_inference.transformer import Transformer\n", 86 | "from mistral_inference.generate import generate\n", 87 | "\n", 88 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n", 89 | "from mistral_common.protocol.instruct.messages import UserMessage\n", 90 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n", 91 | "\n", 92 | "# load tokenizer\n", 93 | "mistral_tokenizer = MistralTokenizer.from_file(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3/tokenizer.model.v3\")\n", 94 | "# chat completion request\n", 95 | "completion_request = ChatCompletionRequest(messages=[UserMessage(content=\"Explain Machine Learning to me in a nutshell.\")])\n", 96 | "# encode message\n", 97 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n", 98 | "# load model\n", 99 | "model = Transformer.from_folder(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3\")\n", 100 | "# generate results\n", 101 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n", 102 | "# decode generated tokens\n", 103 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n", 104 | "print(result)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "id": "ce4woS3LkgZ9" 111 | }, 112 | "source": [ 113 | "## Function calling\n", 114 | "\n", 115 | "Mistral 7B Instruct v3 also supports function calling!" 
116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "id": "TKfPiEwNk1kh" 122 | }, 123 | "source": [ 124 | "Let's start by creating a function calling example." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "id": "0PJdwvDEk3dl" 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "from mistral_common.protocol.instruct.messages import UserMessage\n", 136 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n", 137 | "from mistral_common.protocol.instruct.tool_calls import Function, Tool\n", 138 | "\n", 139 | "completion_request = ChatCompletionRequest(\n", 140 | "    tools=[\n", 141 | "        Tool(\n", 142 | "            function=Function(\n", 143 | "                name=\"get_current_weather\",\n", 144 | "                description=\"Get the current weather\",\n", 145 | "                parameters={\n", 146 | "                    \"type\": \"object\",\n", 147 | "                    \"properties\": {\n", 148 | "                        \"location\": {\n", 149 | "                            \"type\": \"string\",\n", 150 | "                            \"description\": \"The city and state, e.g. San Francisco, CA\",\n", 151 | "                        },\n", 152 | "                        \"format\": {\n", 153 | "                            \"type\": \"string\",\n", 154 | "                            \"enum\": [\"celsius\", \"fahrenheit\"],\n", 155 | "                            \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n", 156 | "                        },\n", 157 | "                    },\n", 158 | "                    \"required\": [\"location\", \"format\"],\n", 159 | "                },\n", 160 | "            )\n", 161 | "        )\n", 162 | "    ],\n", 163 | "    messages=[\n", 164 | "        UserMessage(content=\"What's the weather like today in Paris?\"),\n", 165 | "    ],\n", 166 | ")" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": { 172 | "id": "bG6ZeZUylpBW" 173 | }, 174 | "source": [ 175 | "Since we have already loaded the tokenizer and the model in the example above, we will skip these steps here. \n", 176 | "\n", 177 | "Now we can encode the message with our tokenizer using `MistralTokenizer`. " 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "id": "Ii8q-JNClwiq" 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n", 189 | "\n", 190 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": { 196 | "id": "NrueDujkmJT4" 197 | }, 198 | "source": [ 199 | "and run `generate` to get a response. Don't forget to pass the EOS id!" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "id": "GWJYO43rl0V8" 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "from mistral_inference.generate import generate\n", 211 | "\n", 212 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "id": "v7baJ1msmPMv" 219 | }, 220 | "source": [ 221 | "Finally, we can decode the generated tokens."
222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": { 228 | "id": "RKhryfBWmHon" 229 | }, 230 | "outputs": [], 231 | "source": [ 232 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n", 233 | "result" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "accelerator": "GPU", 239 | "colab": { 240 | "gpuType": "L4", 241 | "machine_shape": "hm", 242 | "provenance": [] 243 | }, 244 | "kernelspec": { 245 | "display_name": "Python 3 (ipykernel)", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 3 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython3", 259 | "version": "3.11.8" 260 | } 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 4 264 | } 265 | -------------------------------------------------------------------------------- /src/single_query.py: -------------------------------------------------------------------------------- 1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez. 2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors. 3 | 4 | from typing import List, Optional 5 | 6 | import fire 7 | 8 | import sys 9 | import os 10 | import json 11 | import torch 12 | from sampler import Sampler 13 | 14 | # mistral-specific imports 15 | from mistral_common.tokens.tokenizers.base import Tokenizer 16 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 17 | from mistral_common.protocol.instruct.request import ChatCompletionRequest 18 | from mistral_common.protocol.instruct.messages import ( 19 | AssistantMessage, 20 | UserMessage, 21 | SystemMessage 22 | ) 23 | from mistral_inference.generate import generate 24 | # Add the mistral-inference src directory to the Python path 25 | sys.path.append(os.path.abspath("src/mistral-inference/src")) 26 | from mistral_inference.main import load_tokenizer, get_model_cls 27 | from pathlib import Path 28 | 29 | # Add the llama3 directory to the Python path 30 | sys.path.append(os.path.abspath("src/llama3")) 31 | from llama import Dialog, Llama 32 | 33 | def main( 34 | ckpt_dir: str, 35 | tokenizer_path: str, 36 | temperature: float = 0.6, 37 | top_p: float = 0.9, 38 | top_k: int = 10, 39 | sampler_type: str = 'vocabulary', 40 | max_seq_len: int = 512, 41 | max_batch_size: int = 4, 42 | max_gen_len: Optional[int] = None, 43 | seed: int = 42, 44 | query: str = "give me a recipe for moussaka", 45 | exp_name: str = "test", 46 | system: str = "Keep your replies short and to the point but don't give single word answers.", 47 | genstates: bool = False, 48 | model_family: str = "llama3", 49 | ): 50 | """ 51 | Examples to run with the models finetuned for chat. Prompts consist of chat 52 | turns between the user and the assistant, with the final turn always coming from the user. 53 | 54 | An optional system prompt can also be given at the beginning to control how 55 | the model should respond. 56 | 57 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192. 58 | 59 | `max_gen_len` is optional because finetuned models are able to stop generations naturally. 
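    Example invocation (the checkpoint and tokenizer paths below are placeholders,
    not paths shipped with this repository; llama3 checkpoints are typically
    launched via torchrun):

        torchrun --nproc_per_node 1 src/single_query.py \
            --ckpt_dir <path-to-checkpoint-dir> \
            --tokenizer_path <path-to-tokenizer.model> \
            --query "give me a recipe for moussaka" \
            --model_family llama3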
60 | """ 61 | 62 | sampler = Sampler(sampler_type=sampler_type, top_p=top_p, top_k=top_k) 63 | 64 | if model_family == "llama3": 65 | generator = Llama.build( 66 | ckpt_dir=ckpt_dir, 67 | tokenizer_path=tokenizer_path, 68 | max_seq_len=max_seq_len, 69 | max_batch_size=max_batch_size, 70 | seed=seed 71 | ) 72 | 73 | dialogs: List[Dialog] = [ 74 | [ {"role": "system", "content": system}, 75 | {"role": "user", "content": query}] 76 | ] 77 | 78 | results, rngstates = generator.chat_completion( 79 | dialogs, 80 | max_gen_len=max_gen_len, 81 | temperature=temperature, 82 | sampler=sampler, 83 | genstates=genstates 84 | ) 85 | elif model_family == "mistral": 86 | 87 | # load the model 88 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(ckpt_dir)) 89 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer 90 | 91 | model_cls = get_model_cls(ckpt_dir) 92 | model = model_cls.from_folder(Path(ckpt_dir), max_batch_size=max_batch_size, num_pipeline_ranks=1) 93 | 94 | messages: List[SystemMessage | UserMessage | AssistantMessage] = [] 95 | messages += [SystemMessage(content=system)] 96 | messages += [UserMessage(content=query)] 97 | 98 | chat_completion_request = ChatCompletionRequest(messages=messages) 99 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request) 100 | tokens = tokenized.tokens 101 | 102 | # initialize a random number generator with the given seed 103 | rng = torch.Generator(device="cuda") 104 | rng.manual_seed(seed) 105 | 106 | generated_tokens, _, rngstates = generate( # type: ignore[operator] 107 | encoded_prompts = [tokens], 108 | model = model, 109 | max_tokens = max_seq_len, 110 | temperature = temperature, 111 | eos_id = tokenizer.eos_id, 112 | sampler = sampler, 113 | genstates = genstates, 114 | rng = rng 115 | ) 116 | 117 | if genstates: 118 | results = [ 119 | { 120 | "generation": { 121 | "role": "assistant", 122 | "content": tokenizer.decode(t), 123 | "token_list": [tokenizer.decode([x]) for x in t], 124 | }, 125 | } 126 | for t in generated_tokens 127 | ] 128 | 129 | output_dir = os.path.join("outputs", exp_name) 130 | # if a folder with the experiment name exists, delete its contents, otherwise create it 131 | if os.path.exists(output_dir): 132 | for file in os.listdir(output_dir): 133 | os.remove(os.path.join(output_dir, file)) 134 | else: 135 | os.makedirs(output_dir) 136 | 137 | if genstates: 138 | rngstates = torch.tensor(rngstates, dtype=torch.uint8) 139 | # save the rngstates tensor 140 | torch.save(rngstates, os.path.join(output_dir, 'rngstates_1.pt')) 141 | 142 | output = {} 143 | output["system"] = system 144 | output["query"] = query 145 | output["response"] = results[0]["generation"]["content"] 146 | output["seed"] = seed 147 | output["temperature"] = temperature 148 | output["sampler_type"] = sampler_type 149 | output["top_p"] = top_p 150 | output["top_k"] = top_k 151 | output["model_family"] = model_family 152 | 153 | # save the output of the factual generation 154 | with open(os.path.join(output_dir, "factual.json"), "w") as f: 155 | json.dump(output, f, indent=4) 156 | 157 | # prepare the intervention file 158 | output["ckpt_dir"] = ckpt_dir 159 | output["tokenizer_path"] = tokenizer_path 160 | output["start_from"] = 0 161 | output["token_list"] = {ind: tok for ind, tok in enumerate(results[0]["generation"]["token_list"])} 162 | with open(os.path.join(output_dir, "intervention_1.json"), "w") as f: 163 | json.dump(output, f, indent=4) 164 | 165 | print("----------------") 166 | print("SYSTEM:", system) 167 | 
print("----------------") 168 | print("QUERY:", query) 169 | print("----------------") 170 | print("RESPONSE:", results[0]["generation"]["content"]) 171 | print("----------------") 172 | 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | 177 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez. 2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors. 3 | 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | import numpy as np 7 | 8 | def get_fig_dim(width, fraction=1, aspect_ratio=None): 9 | """Set figure dimensions to avoid scaling in LaTeX. 10 | 11 | Parameters 12 | ---------- 13 | width: float 14 | Document textwidth or columnwidth in pts 15 | fraction: float, optional 16 | Fraction of the width which you wish the figure to occupy 17 | aspect_ratio: float, optional 18 | Aspect ratio of the figure 19 | 20 | Returns 21 | ------- 22 | fig_dim: tuple 23 | Dimensions of figure in inches 24 | """ 25 | # Width of figure (in pts) 26 | fig_width_pt = width * fraction 27 | 28 | # Convert from pt to inches 29 | inches_per_pt = 1 / 72.27 30 | 31 | if aspect_ratio is None: 32 | # If not specified, set the aspect ratio equal to the Golden ratio (https://en.wikipedia.org/wiki/Golden_ratio) 33 | aspect_ratio = (1 + 5**.5) / 2 34 | 35 | # Figure width in inches 36 | fig_width_in = fig_width_pt * inches_per_pt 37 | # Figure height in inches 38 | fig_height_in = fig_width_in / aspect_ratio 39 | 40 | fig_dim = (fig_width_in, fig_height_in) 41 | 42 | return fig_dim 43 | 44 | 45 | def latexify(font_serif='Computer Modern', mathtext_font='cm', font_size=10, small_font_size=None, usetex=True): 46 | """Set up matplotlib's RC params for LaTeX plotting. 47 | Call this before plotting a figure. 48 | 49 | Parameters 50 | ---------- 51 | font_serif: string, optional 52 | Set the desired font family 53 | mathtext_font: float, optional 54 | Set the desired math font family 55 | font_size: int, optional 56 | Set the large font size 57 | small_font_size: int, optional 58 | Set the small font size 59 | usetex: boolean, optional 60 | Use tex for strings 61 | """ 62 | 63 | if small_font_size is None: 64 | small_font_size = font_size 65 | 66 | params = { 67 | 'backend': 'ps', 68 | 'text.latex.preamble': '\\usepackage{gensymb} \\usepackage{bm}', 69 | 70 | 'axes.labelsize': font_size, 71 | 'axes.titlesize': font_size, 72 | 'font.size': font_size, 73 | 74 | # Optionally set a smaller font size for legends and tick labels 75 | 'legend.fontsize': small_font_size, 76 | 'legend.title_fontsize': small_font_size, 77 | 'xtick.labelsize': small_font_size, 78 | 'ytick.labelsize': small_font_size, 79 | 80 | 'text.usetex': usetex, 81 | 'font.family' : 'serif', 82 | 'font.serif' : font_serif, 83 | 'mathtext.fontset' : mathtext_font 84 | } 85 | 86 | matplotlib.rcParams.update(params) 87 | plt.rcParams.update(params) 88 | --------------------------------------------------------------------------------