├── .gitignore
├── README.md
├── banner.png
├── data
│   └── README.md
├── figures
│   ├── .placeholder
│   ├── edit_distance
│   │   ├── k_vs_ed.pdf
│   │   ├── legend.pdf
│   │   ├── p_vs_ed.pdf
│   │   ├── tau_vs_ed.pdf
│   │   └── tau_vs_ed_categorical.pdf
│   ├── llama3
│   │   ├── bias
│   │   │   ├── effects_Female_income_hist_median.pdf
│   │   │   ├── llama3_effects_Female_income_hist_median.pdf
│   │   │   ├── llama3_effects_Male_income_hist_median.pdf
│   │   │   ├── median_sex_income_direct_cf.pdf
│   │   │   ├── median_sex_income_total_cf.pdf
│   │   │   └── race_education_total.pdf
│   │   └── occupation
│   │       └── occupation_factual_Asian_American_cf_Black_or_African_American.pdf
│   └── mistral
│       ├── bias
│       │   ├── median_sex_income_direct_cf.pdf
│       │   ├── median_sex_income_total_cf.pdf
│       │   ├── mistral_effects_Female_income_hist_median.pdf
│       │   ├── mistral_effects_Male_income_hist_median.pdf
│       │   └── race_education_total.pdf
│       └── occupation
│           └── occupation_factual_Asian_American_cf_Black_or_African_American.pdf
├── notebooks
│   ├── bias.ipynb
│   └── edit_distance.ipynb
├── outputs
│   ├── .placeholder
│   ├── llama3
│   │   ├── bias
│   │   │   ├── direct_Sex_cf_changes.json
│   │   │   ├── total_Race_cf_changes.json
│   │   │   └── total_Sex_cf_changes.json
│   │   ├── census1
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   ├── census2
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   └── census3
│   │       ├── factual.json
│   │       ├── intervention_1.json
│   │       └── rngstates_1.pt
│   ├── mistral
│   │   ├── bias
│   │   │   ├── direct_Sex_cf_changes.json
│   │   │   ├── total_Race_cf_changes.json
│   │   │   └── total_Sex_cf_changes.json
│   │   ├── census1
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   ├── census2
│   │   │   ├── factual.json
│   │   │   ├── intervention_1.json
│   │   │   └── rngstates_1.pt
│   │   └── census3
│   │       ├── factual.json
│   │       ├── intervention_1.json
│   │       └── rngstates_1.pt
│   ├── story-edit-word
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-interventional
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-name
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   ├── story-remove-word
│   │   ├── counterfactual_1.json
│   │   ├── factual.json
│   │   ├── intervention_1.json
│   │   ├── intervention_2.json
│   │   └── rngstates_1.pt
│   └── story-ship
│       ├── counterfactual_1.json
│       ├── factual.json
│       ├── intervention_1.json
│       ├── intervention_2.json
│       └── rngstates_1.pt
├── requirements.txt
├── scripts
│   ├── bias.py
│   ├── census_queries.py
│   ├── cf_query.py
│   ├── stability.py
│   └── story_query.py
└── src
    ├── bias.py
    ├── cf_query.py
    ├── llama3
    │   ├── CODE_OF_CONDUCT.md
    │   ├── LICENSE
    │   ├── MODEL_CARD.md
    │   ├── README.md
    │   ├── USE_POLICY.md
    │   ├── llama
    │   │   ├── __init__.py
    │   │   ├── generation.py
    │   │   ├── model.py
    │   │   ├── sampler.py
    │   │   ├── test_tokenizer.py
    │   │   └── tokenizer.py
    │   ├── pretrained
    │   │   └── .placeholder
    │   ├── requirements.txt
    │   └── setup.py
    ├── mistral-inference
    │   ├── LICENSE
    │   ├── README.md
    │   ├── assets
    │   │   └── smoe.png
    │   ├── deploy
    │   │   ├── .dockerignore
    │   │   ├── Dockerfile
    │   │   └── entrypoint.sh
    │   ├── moe_one_file_ref.py
    │   ├── one_file_ref.py
    │   ├── poetry.lock
    │   ├── pyproject.toml
    │   ├── src
    │   │   └── mistral_inference
    │   │       ├── __init__.py
    │   │       ├── args.py
    │   │       ├── cache.py
    │   │       ├── generate.py
    │   │       ├── lora.py
    │   │       ├── main.py
    │   │       ├── mamba.py
    │   │       ├── model.py
    │   │       ├── moe.py
    │   │       ├── rope.py
    │   │       ├── transformer.py
    │   │       ├── transformer_layers.py
    │   │       └── vision_encoder.py
    │   ├── tests
    │   │   └── test_generate.py
    │   └── tutorials
    │       ├── classifier.ipynb
    │       └── getting_started.ipynb
    ├── sampler.py
    ├── single_query.py
    ├── stability.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | # Ignore the outputs
141 | outputs/*json
142 |
143 | # Ignore the figures
144 | figures/*pdf
145 | figures/*png
146 | figures/*tif
147 |
148 | # Ignore the backups
149 | backups*
150 |
151 | # Ignore macOS things
152 | **.DS_Store
153 |
154 | # Ignore the downloaded model weights in src/llama3/pretrained
155 | src/llama3/pretrained/Meta*
156 |
157 | # ignore all json files in data/ and its subdirectories
158 | data/**/*.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Counterfactual Token Generation in Large Language Models
2 |
3 | This repository contains the code used in the paper ["Counterfactual Token Generation in Large Language Models"](https://arxiv.org/abs/2409.17027) by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez-Rodriguez.
4 |
5 | __Contents__:
6 | - [Introduction](#introduction)
7 | - [Dependencies](#dependencies)
8 | - [Repository structure](#repository-structure)
9 | - [Contact & attribution](#contact--attribution)
10 |
11 | ## Introduction
12 |
13 |
14 | ![banner](banner.png)
15 |
16 |
17 | *"Sure, I am happy to generate a story for you: Captain Lyra stood at the helm of her trusty ship, the Maelstrom’s Fury, gazing out at the endless sea. \[...\] Lyra’s eyes welled up with tears as she realized the bitter truth - she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself."* Although this story, generated by a large language model, is captivating, one may wonder—how would the story have unfolded if the model had chosen "Captain Maeve" as the protagonist instead? We cannot know. State-of-the-art large language models are stateless—they maintain no internal memory or state. Given a prompt, they generate a sequence of tokens as an output using an autoregressive process. As a consequence, they cannot reason about counterfactual alternatives to tokens they have generated in the past. In this work, our goal is to enhance them with this functionality. To this end, we develop a causal model of token generation that builds upon the Gumbel-Max structural causal model. Our model allows any large language model to perform counterfactual token generation at almost no cost in comparison with vanilla token generation, it is embarrassingly simple to implement, and it does not require any fine-tuning nor prompt engineering. We implement our model on Llama 3 8B-Instruct and Ministral-8B-Instruct, and conduct both qualitative and quantitative analyses of counterfactually generated text. We conclude with a demonstrative application of counterfactual token generation for bias detection, unveiling interesting insights about the model of the world constructed by large language models.
18 |
19 |
20 | ## Dependencies
21 |
22 | All the experiments were performed using Python 3.11.2. To create a virtual environment and install the project dependencies, run the following commands:
23 |
24 | ```bash
25 | python3 -m venv env
26 | source env/bin/activate
27 | pip install -r requirements.txt
28 | ```
29 |
30 | Our code builds upon the popular open-weight large language models Llama 3 8B-Instruct and Ministral-8B-Instruct. For instructions on getting access to the weights of the Llama 3 model, refer to the [Llama 3 GitHub repository](https://github.com/meta-llama/llama3); for the weights of the Ministral model, refer to the Mistral [online documentation](https://docs.mistral.ai/getting-started/models/weights/). In addition to the dependencies of our own code mentioned above, make sure to also install the project dependencies of the Ministral and Llama 3 models.
31 |
32 | ## Repository structure
33 |
34 | ```
35 | ├── data
36 | ├── figures
37 | ├── notebooks
38 | ├── outputs
39 | │   ├── llama3
40 | │   │   ├── census*
41 | │   │   └── bias
42 | │   ├── mistral
43 | │   │   ├── census*
44 | │   │   └── bias
45 | │   └── story*
46 | ├── scripts
47 | │   ├── bias.py
48 | │   ├── census_queries.py
49 | │   ├── cf_query.py
50 | │   ├── stability.py
51 | │   └── story_query.py
52 | └── src
53 |     ├── llama3
54 |     │   ├── llama
55 |     │   │   ├── generation.py
56 |     │   │   └── sampler.py
57 |     │   └── pretrained
58 |     ├── mistral-inference
59 |     │   └── src/mistral_inference
60 |     │       ├── generate.py
61 |     │       └── 8B-Instruct
62 |     ├── bias.py
63 |     ├── cf_query.py
64 |     ├── sampler.py
65 |     ├── single_query.py
66 |     ├── stability.py
67 |     └── utils.py
68 | ```
69 |
70 | - `data` contains configuration files for our experiments.
71 | - `figures` contains all the figures presented in the paper.
72 | - `notebooks` contains Python notebooks that generate all the figures included in the paper.
73 | - `outputs/<llm>`, where `<llm>` is either `llama3` or `mistral`, contains intermediate output files generated by the experiments' scripts. Specifically:
74 |   - `bias` contains the counterfactual census data of Section 4.3.
75 |   - `census*` directories contain the factual census data of Section 4.3.
76 | - `outputs/story*` directories contain the results of Section 4.1 and Appendix A.
77 | - `scripts` contains a set of scripts used to run all the experiments presented in the paper.
78 | - `src` contains all the code necessary to reproduce the results in the paper. Specifically:
79 |   - `llama3` contains the code of the Llama 3 8B-Instruct model. Therein:
80 |     - `llama/generation.py` uses the LLM to perform factual/counterfactual token generation.
81 |     - `pretrained/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
82 |   - `mistral-inference` contains the code of the Ministral-8B-Instruct model. Therein:
83 |     - `src/mistral_inference/generate.py` uses the LLM to perform factual/counterfactual token generation.
84 |     - `src/mistral_inference/8B-Instruct/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
85 |   - `sampler.py` samples from a token distribution using a Gumbel-Max SCM or its top-p and top-k variants (see the sketch after this list).
86 |   - `bias.py` performs counterfactual and interventional token generation for the experiments of Section 4.3 using the LLM-generated census data.
87 |   - `cf_query.py` performs counterfactual token generation for a single query.
88 |   - `single_query.py` performs factual token generation for a single query. It creates and saves its results in a subdirectory of `outputs/<llm>`, where `<llm>` is either `llama3` or `mistral`. The results are then used by `cf_query.py`.
89 |   - `stability.py` performs interventional and counterfactual token generation for the experiments of Section 4.2.
90 |   - `utils.py` contains auxiliary functions for plotting.
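
For intuition, here is a minimal sketch of the Gumbel-Max sampling idea that `sampler.py` builds on. This is a hypothetical simplification, not the repository's actual API: a token is drawn as the argmax of the model's log-probabilities plus Gumbel noise, and reusing the *same* noise under a changed next-token distribution yields the counterfactual token.

```python
import torch

def gumbel_max_sample(logprobs: torch.Tensor, noise: torch.Tensor) -> int:
    # Gumbel-Max trick: argmax(log p + g), with g ~ Gumbel(0, 1), is an exact
    # sample from the categorical distribution p.
    return int(torch.argmax(logprobs + noise))

torch.manual_seed(2)
vocab_size = 8  # toy vocabulary; real models have tens of thousands of tokens

# Gumbel(0, 1) noise shared between the factual and counterfactual worlds.
noise = -torch.log(-torch.log(torch.rand(vocab_size)))

# Hypothetical next-token distributions before and after an intervention.
logp_factual = torch.log_softmax(torch.randn(vocab_size), dim=-1)
logp_counterfactual = torch.log_softmax(torch.randn(vocab_size), dim=-1)

factual_token = gumbel_max_sample(logp_factual, noise)
cf_token = gumbel_max_sample(logp_counterfactual, noise)  # same noise, new p
```

In the output files below, the saved RNG states (`rngstates_1.pt`) and each intervention's `start_from` field presumably play this role, allowing `cf_query.py` to replay generation with the factual noise from the intervened position onwards.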
91 |
92 |
93 | ## Contact & attribution
94 |
95 | If you have questions about the code, identify potential bugs, or would like us to include additional functionality, feel free to open an issue or contact [Ivi Chatzi](mailto:ichatzi@mpi-sws.org) or [Stratis Tsirtsis](mailto:stsirtsis@mpi-sws.org).
96 |
97 | If you use parts of the code in this repository for your own research, please consider citing:
98 |
99 | @article{chatzi2024counterfactual,
100 | title={Counterfactual Token Generation in Large Language Models},
101 | author={Ivi Chatzi and Nina Corvelo Benz and Eleni Straitouri and Stratis Tsirtsis and Manuel Gomez-Rodriguez},
102 | year={2024},
103 | journal={arXiv preprint arXiv:2409.17027}
104 | }
105 |
106 |
--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/banner.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Dataset information
2 |
3 | We used questions from the [LMSYS Chat 1M dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) for our text similarity experiment in Section 4.2.
--------------------------------------------------------------------------------
/figures/.placeholder:
--------------------------------------------------------------------------------
1 | # Placeholder for figures
--------------------------------------------------------------------------------
/figures/edit_distance/k_vs_ed.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/k_vs_ed.pdf
--------------------------------------------------------------------------------
/figures/edit_distance/legend.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/legend.pdf
--------------------------------------------------------------------------------
/figures/edit_distance/p_vs_ed.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/p_vs_ed.pdf
--------------------------------------------------------------------------------
/figures/edit_distance/tau_vs_ed.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/tau_vs_ed.pdf
--------------------------------------------------------------------------------
/figures/edit_distance/tau_vs_ed_categorical.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/edit_distance/tau_vs_ed_categorical.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/effects_Female_income_hist_median.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/effects_Female_income_hist_median.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/llama3_effects_Female_income_hist_median.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/llama3_effects_Female_income_hist_median.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/llama3_effects_Male_income_hist_median.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/llama3_effects_Male_income_hist_median.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/median_sex_income_direct_cf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/median_sex_income_direct_cf.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/median_sex_income_total_cf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/median_sex_income_total_cf.pdf
--------------------------------------------------------------------------------
/figures/llama3/bias/race_education_total.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/bias/race_education_total.pdf
--------------------------------------------------------------------------------
/figures/llama3/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/llama3/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf
--------------------------------------------------------------------------------
/figures/mistral/bias/median_sex_income_direct_cf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/median_sex_income_direct_cf.pdf
--------------------------------------------------------------------------------
/figures/mistral/bias/median_sex_income_total_cf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/median_sex_income_total_cf.pdf
--------------------------------------------------------------------------------
/figures/mistral/bias/mistral_effects_Female_income_hist_median.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/mistral_effects_Female_income_hist_median.pdf
--------------------------------------------------------------------------------
/figures/mistral/bias/mistral_effects_Male_income_hist_median.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/mistral_effects_Male_income_hist_median.pdf
--------------------------------------------------------------------------------
/figures/mistral/bias/race_education_total.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/bias/race_education_total.pdf
--------------------------------------------------------------------------------
/figures/mistral/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/figures/mistral/occupation/occupation_factual_Asian_American_cf_Black_or_African_American.pdf
--------------------------------------------------------------------------------
/outputs/.placeholder:
--------------------------------------------------------------------------------
1 | # Placeholder for outputs
--------------------------------------------------------------------------------
/outputs/llama3/census1/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census1/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/llama3/census2/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census2/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/llama3/census3/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/llama3/census3/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/mistral/census1/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census1/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/mistral/census2/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census2/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/mistral/census3/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/mistral/census3/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/story-edit-word/counterfactual_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue. For years, she'd sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she anchored the ship and stepped onto the sandy beach, a warm light enveloped her. With trembling hands, she reached out and touched the glittering treasure. The Golden Isle's magic surged through her, and Lyra's greatest desire took hold: to reunite with her lost love, who'd perished at sea many years ago.\n\nWith tears of joy streaming down her face, Lyra opened her eyes to find her beloved standing before her, alive and well. The two shared a long, tender kiss, surrounded by the sweet scent of tropical flowers. The \"Maelstrom's Fury\" sailed away, its crew rejoicing, as Lyra and her love disappeared into the sunset, their hearts full of happiness.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " blue",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": "'d",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " anchored",
87 | "75": " the",
88 | "76": " ship",
89 | "77": " and",
90 | "78": " stepped",
91 | "79": " onto",
92 | "80": " the",
93 | "81": " sandy",
94 | "82": " beach",
95 | "83": ",",
96 | "84": " a",
97 | "85": " warm",
98 | "86": " light",
99 | "87": " envelop",
100 | "88": "ed",
101 | "89": " her",
102 | "90": ".",
103 | "91": " With",
104 | "92": " trembling",
105 | "93": " hands",
106 | "94": ",",
107 | "95": " she",
108 | "96": " reached",
109 | "97": " out",
110 | "98": " and",
111 | "99": " touched",
112 | "100": " the",
113 | "101": " glitter",
114 | "102": "ing",
115 | "103": " treasure",
116 | "104": ".",
117 | "105": " The",
118 | "106": " Golden",
119 | "107": " Isle",
120 | "108": "'s",
121 | "109": " magic",
122 | "110": " surged",
123 | "111": " through",
124 | "112": " her",
125 | "113": ",",
126 | "114": " and",
127 | "115": " Ly",
128 | "116": "ra",
129 | "117": "'s",
130 | "118": " greatest",
131 | "119": " desire",
132 | "120": " took",
133 | "121": " hold",
134 | "122": ":",
135 | "123": " to",
136 | "124": " reun",
137 | "125": "ite",
138 | "126": " with",
139 | "127": " her",
140 | "128": " lost",
141 | "129": " love",
142 | "130": ",",
143 | "131": " who",
144 | "132": "'d",
145 | "133": " per",
146 | "134": "ished",
147 | "135": " at",
148 | "136": " sea",
149 | "137": " many",
150 | "138": " years",
151 | "139": " ago",
152 | "140": ".\n\n",
153 | "141": "With",
154 | "142": " tears",
155 | "143": " of",
156 | "144": " joy",
157 | "145": " streaming",
158 | "146": " down",
159 | "147": " her",
160 | "148": " face",
161 | "149": ",",
162 | "150": " Ly",
163 | "151": "ra",
164 | "152": " opened",
165 | "153": " her",
166 | "154": " eyes",
167 | "155": " to",
168 | "156": " find",
169 | "157": " her",
170 | "158": " beloved",
171 | "159": " standing",
172 | "160": " before",
173 | "161": " her",
174 | "162": ",",
175 | "163": " alive",
176 | "164": " and",
177 | "165": " well",
178 | "166": ".",
179 | "167": " The",
180 | "168": " two",
181 | "169": " shared",
182 | "170": " a",
183 | "171": " long",
184 | "172": ",",
185 | "173": " tender",
186 | "174": " kiss",
187 | "175": ",",
188 | "176": " surrounded",
189 | "177": " by",
190 | "178": " the",
191 | "179": " sweet",
192 | "180": " scent",
193 | "181": " of",
194 | "182": " tropical",
195 | "183": " flowers",
196 | "184": ".",
197 | "185": " The",
198 | "186": " \"",
199 | "187": "M",
200 | "188": "ael",
201 | "189": "strom",
202 | "190": "'s",
203 | "191": " Fury",
204 | "192": "\"",
205 | "193": " sailed",
206 | "194": " away",
207 | "195": ",",
208 | "196": " its",
209 | "197": " crew",
210 | "198": " rejo",
211 | "199": "icing",
212 | "200": ",",
213 | "201": " as",
214 | "202": " Ly",
215 | "203": "ra",
216 | "204": " and",
217 | "205": " her",
218 | "206": " love",
219 | "207": " disappeared",
220 | "208": " into",
221 | "209": " the",
222 | "210": " sunset",
223 | "211": ",",
224 | "212": " their",
225 | "213": " hearts",
226 | "214": " full",
227 | "215": " of",
228 | "216": " happiness",
229 | "217": "."
230 | }
231 | }
--------------------------------------------------------------------------------
/outputs/story-edit-word/factual.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5
10 | }
--------------------------------------------------------------------------------
/outputs/story-edit-word/intervention_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5,
10 | "start_from": 29,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " and",
111 | "99": " jewels",
112 | "100": " lay",
113 | "101": " at",
114 | "102": " her",
115 | "103": " feet",
116 | "104": ".",
117 | "105": " Ly",
118 | "106": "ra",
119 | "107": "'s",
120 | "108": " heart",
121 | "109": " soared",
122 | "110": " as",
123 | "111": " she",
124 | "112": " reached",
125 | "113": " for",
126 | "114": " the",
127 | "115": " treasure",
128 | "116": ".",
129 | "117": " But",
130 | "118": ",",
131 | "119": " as",
132 | "120": " she",
133 | "121": " touched",
134 | "122": " the",
135 | "123": " gold",
136 | "124": ",",
137 | "125": " the",
138 | "126": " is",
139 | "127": "le",
140 | "128": " began",
141 | "129": " to",
142 | "130": " fade",
143 | "131": " away",
144 | "132": ",",
145 | "133": " leaving",
146 | "134": " her",
147 | "135": " with",
148 | "136": " nothing",
149 | "137": " but",
150 | "138": " a",
151 | "139": " whispered",
152 | "140": " promise",
153 | "141": ":",
154 | "142": " \"",
155 | "143": "The",
156 | "144": " true",
157 | "145": " treasure",
158 | "146": " was",
159 | "147": " the",
160 | "148": " journey",
161 | "149": ",",
162 | "150": " not",
163 | "151": " the",
164 | "152": " gold",
165 | "153": ".\"",
166 | "154": " Ly",
167 | "155": "ra",
168 | "156": "'s",
169 | "157": " eyes",
170 | "158": " wel",
171 | "159": "led",
172 | "160": " up",
173 | "161": " with",
174 | "162": " tears",
175 | "163": " as",
176 | "164": " she",
177 | "165": " realized",
178 | "166": " the",
179 | "167": " bitter",
180 | "168": " truth",
181 | "169": " \u2013",
182 | "170": " she",
183 | "171": " had",
184 | "172": " sacrificed",
185 | "173": " everything",
186 | "174": " for",
187 | "175": " fleeting",
188 | "176": " riches",
189 | "177": ",",
190 | "178": " and",
191 | "179": " lost",
192 | "180": " the",
193 | "181": " love",
194 | "182": " of",
195 | "183": " her",
196 | "184": " crew",
197 | "185": ",",
198 | "186": " her",
199 | "187": " family",
200 | "188": ",",
201 | "189": " and",
202 | "190": " herself",
203 | "191": ".",
204 | "192": " The",
205 | "193": " \"",
206 | "194": "M",
207 | "195": "ael",
208 | "196": "strom",
209 | "197": "'s",
210 | "198": " Fury",
211 | "199": "\"",
212 | "200": " sailed",
213 | "201": " away",
214 | "202": ",",
215 | "203": " a",
216 | "204": " lonely",
217 | "205": " shadow",
218 | "206": " on",
219 | "207": " the",
220 | "208": " vast",
221 | "209": " ocean",
222 | "210": ",",
223 | "211": " carrying",
224 | "212": " a",
225 | "213": " captain",
226 | "214": " forever",
227 | "215": " haunted",
228 | "216": " by",
229 | "217": " the",
230 | "218": " what",
231 | "219": "-",
232 | "220": "ifs",
233 | "221": "."
234 | }
235 | }
--------------------------------------------------------------------------------
/outputs/story-edit-word/intervention_2.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless blue. For years, she'd sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she anchored the ship and stepped onto the sandy beach, a warm light enveloped her. With trembling hands, she reached out and touched the glittering treasure. The Golden Isle's magic surged through her, and Lyra's greatest desire took hold: to reunite with her lost love, who'd perished at sea many years ago.\n\nWith tears of joy streaming down her face, Lyra opened her eyes to find her beloved standing before her, alive and well. The two shared a long, tender kiss, surrounded by the sweet scent of tropical flowers. The \"Maelstrom's Fury\" sailed away, its crew rejoicing, as Lyra and her love disappeared into the sunset, their hearts full of happiness.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " blue",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": "'d",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " anchored",
87 | "75": " the",
88 | "76": " ship",
89 | "77": " and",
90 | "78": " stepped",
91 | "79": " onto",
92 | "80": " the",
93 | "81": " sandy",
94 | "82": " beach",
95 | "83": ",",
96 | "84": " a",
97 | "85": " warm",
98 | "86": " light",
99 | "87": " envelop",
100 | "88": "ed",
101 | "89": " her",
102 | "90": ".",
103 | "91": " With",
104 | "92": " trembling",
105 | "93": " hands",
106 | "94": ",",
107 | "95": " she",
108 | "96": " reached",
109 | "97": " out",
110 | "98": " and",
111 | "99": " touched",
112 | "100": " the",
113 | "101": " glitter",
114 | "102": "ing",
115 | "103": " treasure",
116 | "104": ".",
117 | "105": " The",
118 | "106": " Golden",
119 | "107": " Isle",
120 | "108": "'s",
121 | "109": " magic",
122 | "110": " surged",
123 | "111": " through",
124 | "112": " her",
125 | "113": ",",
126 | "114": " and",
127 | "115": " Ly",
128 | "116": "ra",
129 | "117": "'s",
130 | "118": " greatest",
131 | "119": " desire",
132 | "120": " took",
133 | "121": " hold",
134 | "122": ":",
135 | "123": " to",
136 | "124": " reun",
137 | "125": "ite",
138 | "126": " with",
139 | "127": " her",
140 | "128": " lost",
141 | "129": " love",
142 | "130": ",",
143 | "131": " who",
144 | "132": "'d",
145 | "133": " per",
146 | "134": "ished",
147 | "135": " at",
148 | "136": " sea",
149 | "137": " many",
150 | "138": " years",
151 | "139": " ago",
152 | "140": ".\n\n",
153 | "141": "With",
154 | "142": " tears",
155 | "143": " of",
156 | "144": " joy",
157 | "145": " streaming",
158 | "146": " down",
159 | "147": " her",
160 | "148": " face",
161 | "149": ",",
162 | "150": " Ly",
163 | "151": "ra",
164 | "152": " opened",
165 | "153": " her",
166 | "154": " eyes",
167 | "155": " to",
168 | "156": " find",
169 | "157": " her",
170 | "158": " beloved",
171 | "159": " standing",
172 | "160": " before",
173 | "161": " her",
174 | "162": ",",
175 | "163": " alive",
176 | "164": " and",
177 | "165": " well",
178 | "166": ".",
179 | "167": " The",
180 | "168": " two",
181 | "169": " shared",
182 | "170": " a",
183 | "171": " long",
184 | "172": ",",
185 | "173": " tender",
186 | "174": " kiss",
187 | "175": ",",
188 | "176": " surrounded",
189 | "177": " by",
190 | "178": " the",
191 | "179": " sweet",
192 | "180": " scent",
193 | "181": " of",
194 | "182": " tropical",
195 | "183": " flowers",
196 | "184": ".",
197 | "185": " The",
198 | "186": " \"",
199 | "187": "M",
200 | "188": "ael",
201 | "189": "strom",
202 | "190": "'s",
203 | "191": " Fury",
204 | "192": "\"",
205 | "193": " sailed",
206 | "194": " away",
207 | "195": ",",
208 | "196": " its",
209 | "197": " crew",
210 | "198": " rejo",
211 | "199": "icing",
212 | "200": ",",
213 | "201": " as",
214 | "202": " Ly",
215 | "203": "ra",
216 | "204": " and",
217 | "205": " her",
218 | "206": " love",
219 | "207": " disappeared",
220 | "208": " into",
221 | "209": " the",
222 | "210": " sunset",
223 | "211": ",",
224 | "212": " their",
225 | "213": " hearts",
226 | "214": " full",
227 | "215": " of",
228 | "216": " happiness",
229 | "217": "."
230 | }
231 | }
--------------------------------------------------------------------------------
/outputs/story-edit-word/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-edit-word/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/story-interventional/counterfactual_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she'd sailed the waves, searching for the fabled Golden Isle, where legend whispered the sea goddess herself granted eternal youth to those who found it.\n\nAs the sun dipped into the horizon, Maeve spotted it \u2013 a shimmering, golden light on the horizon. With a cry of triumph, she steered the ship towards it. As they reached the shore, Maeve leapt out, her eyes drinking in the beauty of the island. And then, she saw her \u2013 the goddess, bathing in the golden light. But, as Maeve approached, the goddess vanished, leaving behind a small, intricately carved box. With a heavy heart, Maeve opened it, finding the island's secret: the price of eternal youth was to never find it.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Mae",
14 | "2": "ve",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": "'d",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " waves",
49 | "37": ",",
50 | "38": " searching",
51 | "39": " for",
52 | "40": " the",
53 | "41": " f",
54 | "42": "abled",
55 | "43": " Golden",
56 | "44": " Isle",
57 | "45": ",",
58 | "46": " where",
59 | "47": " legend",
60 | "48": " whispered",
61 | "49": " the",
62 | "50": " sea",
63 | "51": " goddess",
64 | "52": " herself",
65 | "53": " granted",
66 | "54": " eternal",
67 | "55": " youth",
68 | "56": " to",
69 | "57": " those",
70 | "58": " who",
71 | "59": " found",
72 | "60": " it",
73 | "61": ".\n\n",
74 | "62": "As",
75 | "63": " the",
76 | "64": " sun",
77 | "65": " dipped",
78 | "66": " into",
79 | "67": " the",
80 | "68": " horizon",
81 | "69": ",",
82 | "70": " Mae",
83 | "71": "ve",
84 | "72": " spotted",
85 | "73": " it",
86 | "74": " \u2013",
87 | "75": " a",
88 | "76": " shimmer",
89 | "77": "ing",
90 | "78": ",",
91 | "79": " golden",
92 | "80": " light",
93 | "81": " on",
94 | "82": " the",
95 | "83": " horizon",
96 | "84": ".",
97 | "85": " With",
98 | "86": " a",
99 | "87": " cry",
100 | "88": " of",
101 | "89": " triumph",
102 | "90": ",",
103 | "91": " she",
104 | "92": " ste",
105 | "93": "ered",
106 | "94": " the",
107 | "95": " ship",
108 | "96": " towards",
109 | "97": " it",
110 | "98": ".",
111 | "99": " As",
112 | "100": " they",
113 | "101": " reached",
114 | "102": " the",
115 | "103": " shore",
116 | "104": ",",
117 | "105": " Mae",
118 | "106": "ve",
119 | "107": " le",
120 | "108": "apt",
121 | "109": " out",
122 | "110": ",",
123 | "111": " her",
124 | "112": " eyes",
125 | "113": " drinking",
126 | "114": " in",
127 | "115": " the",
128 | "116": " beauty",
129 | "117": " of",
130 | "118": " the",
131 | "119": " island",
132 | "120": ".",
133 | "121": " And",
134 | "122": " then",
135 | "123": ",",
136 | "124": " she",
137 | "125": " saw",
138 | "126": " her",
139 | "127": " \u2013",
140 | "128": " the",
141 | "129": " goddess",
142 | "130": ",",
143 | "131": " bathing",
144 | "132": " in",
145 | "133": " the",
146 | "134": " golden",
147 | "135": " light",
148 | "136": ".",
149 | "137": " But",
150 | "138": ",",
151 | "139": " as",
152 | "140": " Mae",
153 | "141": "ve",
154 | "142": " approached",
155 | "143": ",",
156 | "144": " the",
157 | "145": " goddess",
158 | "146": " vanished",
159 | "147": ",",
160 | "148": " leaving",
161 | "149": " behind",
162 | "150": " a",
163 | "151": " small",
164 | "152": ",",
165 | "153": " intric",
166 | "154": "ately",
167 | "155": " carved",
168 | "156": " box",
169 | "157": ".",
170 | "158": " With",
171 | "159": " a",
172 | "160": " heavy",
173 | "161": " heart",
174 | "162": ",",
175 | "163": " Mae",
176 | "164": "ve",
177 | "165": " opened",
178 | "166": " it",
179 | "167": ",",
180 | "168": " finding",
181 | "169": " the",
182 | "170": " island",
183 | "171": "'s",
184 | "172": " secret",
185 | "173": ":",
186 | "174": " the",
187 | "175": " price",
188 | "176": " of",
189 | "177": " eternal",
190 | "178": " youth",
191 | "179": " was",
192 | "180": " to",
193 | "181": " never",
194 | "182": " find",
195 | "183": " it",
196 | "184": "."
197 | }
198 | }
--------------------------------------------------------------------------------
/outputs/story-interventional/factual.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5
10 | }
--------------------------------------------------------------------------------
/outputs/story-interventional/intervention_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5,
10 | "start_from": 29,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " and",
111 | "99": " jewels",
112 | "100": " lay",
113 | "101": " at",
114 | "102": " her",
115 | "103": " feet",
116 | "104": ".",
117 | "105": " Ly",
118 | "106": "ra",
119 | "107": "'s",
120 | "108": " heart",
121 | "109": " soared",
122 | "110": " as",
123 | "111": " she",
124 | "112": " reached",
125 | "113": " for",
126 | "114": " the",
127 | "115": " treasure",
128 | "116": ".",
129 | "117": " But",
130 | "118": ",",
131 | "119": " as",
132 | "120": " she",
133 | "121": " touched",
134 | "122": " the",
135 | "123": " gold",
136 | "124": ",",
137 | "125": " the",
138 | "126": " is",
139 | "127": "le",
140 | "128": " began",
141 | "129": " to",
142 | "130": " fade",
143 | "131": " away",
144 | "132": ",",
145 | "133": " leaving",
146 | "134": " her",
147 | "135": " with",
148 | "136": " nothing",
149 | "137": " but",
150 | "138": " a",
151 | "139": " whispered",
152 | "140": " promise",
153 | "141": ":",
154 | "142": " \"",
155 | "143": "The",
156 | "144": " true",
157 | "145": " treasure",
158 | "146": " was",
159 | "147": " the",
160 | "148": " journey",
161 | "149": ",",
162 | "150": " not",
163 | "151": " the",
164 | "152": " gold",
165 | "153": ".\"",
166 | "154": " Ly",
167 | "155": "ra",
168 | "156": "'s",
169 | "157": " eyes",
170 | "158": " wel",
171 | "159": "led",
172 | "160": " up",
173 | "161": " with",
174 | "162": " tears",
175 | "163": " as",
176 | "164": " she",
177 | "165": " realized",
178 | "166": " the",
179 | "167": " bitter",
180 | "168": " truth",
181 | "169": " \u2013",
182 | "170": " she",
183 | "171": " had",
184 | "172": " sacrificed",
185 | "173": " everything",
186 | "174": " for",
187 | "175": " fleeting",
188 | "176": " riches",
189 | "177": ",",
190 | "178": " and",
191 | "179": " lost",
192 | "180": " the",
193 | "181": " love",
194 | "182": " of",
195 | "183": " her",
196 | "184": " crew",
197 | "185": ",",
198 | "186": " her",
199 | "187": " family",
200 | "188": ",",
201 | "189": " and",
202 | "190": " herself",
203 | "191": ".",
204 | "192": " The",
205 | "193": " \"",
206 | "194": "M",
207 | "195": "ael",
208 | "196": "strom",
209 | "197": "'s",
210 | "198": " Fury",
211 | "199": "\"",
212 | "200": " sailed",
213 | "201": " away",
214 | "202": ",",
215 | "203": " a",
216 | "204": " lonely",
217 | "205": " shadow",
218 | "206": " on",
219 | "207": " the",
220 | "208": " vast",
221 | "209": " ocean",
222 | "210": ",",
223 | "211": " carrying",
224 | "212": " a",
225 | "213": " captain",
226 | "214": " forever",
227 | "215": " haunted",
228 | "216": " by",
229 | "217": " the",
230 | "218": " what",
231 | "219": "-",
232 | "220": "ifs",
233 | "221": "."
234 | }
235 | }
--------------------------------------------------------------------------------
/outputs/story-interventional/intervention_2.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she'd sailed the waves, searching for the fabled Golden Isle, where legend whispered the sea goddess herself granted eternal youth to those who found it.\n\nAs the sun dipped into the horizon, Maeve spotted it \u2013 a shimmering, golden light on the horizon. With a cry of triumph, she steered the ship towards it. As they reached the shore, Maeve leapt out, her eyes drinking in the beauty of the island. And then, she saw her \u2013 the goddess, bathing in the golden light. But, as Maeve approached, the goddess vanished, leaving behind a small, intricately carved box. With a heavy heart, Maeve opened it, finding the island's secret: the price of eternal youth was to never find it.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Mae",
14 | "2": "ve",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": "'d",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " waves",
49 | "37": ",",
50 | "38": " searching",
51 | "39": " for",
52 | "40": " the",
53 | "41": " f",
54 | "42": "abled",
55 | "43": " Golden",
56 | "44": " Isle",
57 | "45": ",",
58 | "46": " where",
59 | "47": " legend",
60 | "48": " whispered",
61 | "49": " the",
62 | "50": " sea",
63 | "51": " goddess",
64 | "52": " herself",
65 | "53": " granted",
66 | "54": " eternal",
67 | "55": " youth",
68 | "56": " to",
69 | "57": " those",
70 | "58": " who",
71 | "59": " found",
72 | "60": " it",
73 | "61": ".\n\n",
74 | "62": "As",
75 | "63": " the",
76 | "64": " sun",
77 | "65": " dipped",
78 | "66": " into",
79 | "67": " the",
80 | "68": " horizon",
81 | "69": ",",
82 | "70": " Mae",
83 | "71": "ve",
84 | "72": " spotted",
85 | "73": " it",
86 | "74": " \u2013",
87 | "75": " a",
88 | "76": " shimmer",
89 | "77": "ing",
90 | "78": ",",
91 | "79": " golden",
92 | "80": " light",
93 | "81": " on",
94 | "82": " the",
95 | "83": " horizon",
96 | "84": ".",
97 | "85": " With",
98 | "86": " a",
99 | "87": " cry",
100 | "88": " of",
101 | "89": " triumph",
102 | "90": ",",
103 | "91": " she",
104 | "92": " ste",
105 | "93": "ered",
106 | "94": " the",
107 | "95": " ship",
108 | "96": " towards",
109 | "97": " it",
110 | "98": ".",
111 | "99": " As",
112 | "100": " they",
113 | "101": " reached",
114 | "102": " the",
115 | "103": " shore",
116 | "104": ",",
117 | "105": " Mae",
118 | "106": "ve",
119 | "107": " le",
120 | "108": "apt",
121 | "109": " out",
122 | "110": ",",
123 | "111": " her",
124 | "112": " eyes",
125 | "113": " drinking",
126 | "114": " in",
127 | "115": " the",
128 | "116": " beauty",
129 | "117": " of",
130 | "118": " the",
131 | "119": " island",
132 | "120": ".",
133 | "121": " And",
134 | "122": " then",
135 | "123": ",",
136 | "124": " she",
137 | "125": " saw",
138 | "126": " her",
139 | "127": " \u2013",
140 | "128": " the",
141 | "129": " goddess",
142 | "130": ",",
143 | "131": " bathing",
144 | "132": " in",
145 | "133": " the",
146 | "134": " golden",
147 | "135": " light",
148 | "136": ".",
149 | "137": " But",
150 | "138": ",",
151 | "139": " as",
152 | "140": " Mae",
153 | "141": "ve",
154 | "142": " approached",
155 | "143": ",",
156 | "144": " the",
157 | "145": " goddess",
158 | "146": " vanished",
159 | "147": ",",
160 | "148": " leaving",
161 | "149": " behind",
162 | "150": " a",
163 | "151": " small",
164 | "152": ",",
165 | "153": " intric",
166 | "154": "ately",
167 | "155": " carved",
168 | "156": " box",
169 | "157": ".",
170 | "158": " With",
171 | "159": " a",
172 | "160": " heavy",
173 | "161": " heart",
174 | "162": ",",
175 | "163": " Mae",
176 | "164": "ve",
177 | "165": " opened",
178 | "166": " it",
179 | "167": ",",
180 | "168": " finding",
181 | "169": " the",
182 | "170": " island",
183 | "171": "'s",
184 | "172": " secret",
185 | "173": ":",
186 | "174": " the",
187 | "175": " price",
188 | "176": " of",
189 | "177": " eternal",
190 | "178": " youth",
191 | "179": " was",
192 | "180": " to",
193 | "181": " never",
194 | "182": " find",
195 | "183": " it",
196 | "184": "."
197 | }
198 | }
--------------------------------------------------------------------------------
/outputs/story-interventional/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-interventional/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/story-name/counterfactual_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, she finally spotted its shimmering shores. As she stepped onto the island, a radiant light enveloped her, and she wished for her lost love, the brave sailor, Eamon, to return to her. Suddenly, Eamon emerged from the mist, frail but alive. Overjoyed, Maeve scooped him into her arms, and they shared a tender kiss as the sun dipped into the horizon, their love reunited at last.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Mae",
14 | "2": "ve",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " she",
75 | "63": " finally",
76 | "64": " spotted",
77 | "65": " its",
78 | "66": " shimmer",
79 | "67": "ing",
80 | "68": " shores",
81 | "69": ".",
82 | "70": " As",
83 | "71": " she",
84 | "72": " stepped",
85 | "73": " onto",
86 | "74": " the",
87 | "75": " island",
88 | "76": ",",
89 | "77": " a",
90 | "78": " radiant",
91 | "79": " light",
92 | "80": " envelop",
93 | "81": "ed",
94 | "82": " her",
95 | "83": ",",
96 | "84": " and",
97 | "85": " she",
98 | "86": " wished",
99 | "87": " for",
100 | "88": " her",
101 | "89": " lost",
102 | "90": " love",
103 | "91": ",",
104 | "92": " the",
105 | "93": " brave",
106 | "94": " sailor",
107 | "95": ",",
108 | "96": " E",
109 | "97": "amon",
110 | "98": ",",
111 | "99": " to",
112 | "100": " return",
113 | "101": " to",
114 | "102": " her",
115 | "103": ".",
116 | "104": " Suddenly",
117 | "105": ",",
118 | "106": " E",
119 | "107": "amon",
120 | "108": " emerged",
121 | "109": " from",
122 | "110": " the",
123 | "111": " mist",
124 | "112": ",",
125 | "113": " frail",
126 | "114": " but",
127 | "115": " alive",
128 | "116": ".",
129 | "117": " Over",
130 | "118": "joy",
131 | "119": "ed",
132 | "120": ",",
133 | "121": " Mae",
134 | "122": "ve",
135 | "123": " sco",
136 | "124": "oped",
137 | "125": " him",
138 | "126": " into",
139 | "127": " her",
140 | "128": " arms",
141 | "129": ",",
142 | "130": " and",
143 | "131": " they",
144 | "132": " shared",
145 | "133": " a",
146 | "134": " tender",
147 | "135": " kiss",
148 | "136": " as",
149 | "137": " the",
150 | "138": " sun",
151 | "139": " dipped",
152 | "140": " into",
153 | "141": " the",
154 | "142": " horizon",
155 | "143": ",",
156 | "144": " their",
157 | "145": " love",
158 | "146": " reunited",
159 | "147": " at",
160 | "148": " last",
161 | "149": "."
162 | }
163 | }
--------------------------------------------------------------------------------
/outputs/story-name/factual.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5
10 | }
--------------------------------------------------------------------------------
/outputs/story-name/intervention_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5,
10 | "start_from": 29,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " and",
111 | "99": " jewels",
112 | "100": " lay",
113 | "101": " at",
114 | "102": " her",
115 | "103": " feet",
116 | "104": ".",
117 | "105": " Ly",
118 | "106": "ra",
119 | "107": "'s",
120 | "108": " heart",
121 | "109": " soared",
122 | "110": " as",
123 | "111": " she",
124 | "112": " reached",
125 | "113": " for",
126 | "114": " the",
127 | "115": " treasure",
128 | "116": ".",
129 | "117": " But",
130 | "118": ",",
131 | "119": " as",
132 | "120": " she",
133 | "121": " touched",
134 | "122": " the",
135 | "123": " gold",
136 | "124": ",",
137 | "125": " the",
138 | "126": " is",
139 | "127": "le",
140 | "128": " began",
141 | "129": " to",
142 | "130": " fade",
143 | "131": " away",
144 | "132": ",",
145 | "133": " leaving",
146 | "134": " her",
147 | "135": " with",
148 | "136": " nothing",
149 | "137": " but",
150 | "138": " a",
151 | "139": " whispered",
152 | "140": " promise",
153 | "141": ":",
154 | "142": " \"",
155 | "143": "The",
156 | "144": " true",
157 | "145": " treasure",
158 | "146": " was",
159 | "147": " the",
160 | "148": " journey",
161 | "149": ",",
162 | "150": " not",
163 | "151": " the",
164 | "152": " gold",
165 | "153": ".\"",
166 | "154": " Ly",
167 | "155": "ra",
168 | "156": "'s",
169 | "157": " eyes",
170 | "158": " wel",
171 | "159": "led",
172 | "160": " up",
173 | "161": " with",
174 | "162": " tears",
175 | "163": " as",
176 | "164": " she",
177 | "165": " realized",
178 | "166": " the",
179 | "167": " bitter",
180 | "168": " truth",
181 | "169": " \u2013",
182 | "170": " she",
183 | "171": " had",
184 | "172": " sacrificed",
185 | "173": " everything",
186 | "174": " for",
187 | "175": " fleeting",
188 | "176": " riches",
189 | "177": ",",
190 | "178": " and",
191 | "179": " lost",
192 | "180": " the",
193 | "181": " love",
194 | "182": " of",
195 | "183": " her",
196 | "184": " crew",
197 | "185": ",",
198 | "186": " her",
199 | "187": " family",
200 | "188": ",",
201 | "189": " and",
202 | "190": " herself",
203 | "191": ".",
204 | "192": " The",
205 | "193": " \"",
206 | "194": "M",
207 | "195": "ael",
208 | "196": "strom",
209 | "197": "'s",
210 | "198": " Fury",
211 | "199": "\"",
212 | "200": " sailed",
213 | "201": " away",
214 | "202": ",",
215 | "203": " a",
216 | "204": " lonely",
217 | "205": " shadow",
218 | "206": " on",
219 | "207": " the",
220 | "208": " vast",
221 | "209": " ocean",
222 | "210": ",",
223 | "211": " carrying",
224 | "212": " a",
225 | "213": " captain",
226 | "214": " forever",
227 | "215": " haunted",
228 | "216": " by",
229 | "217": " the",
230 | "218": " what",
231 | "219": "-",
232 | "220": "ifs",
233 | "221": "."
234 | }
235 | }
--------------------------------------------------------------------------------
/outputs/story-name/intervention_2.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Maeve stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, she finally spotted its shimmering shores. As she stepped onto the island, a radiant light enveloped her, and she wished for her lost love, the brave sailor, Eamon, to return to her. Suddenly, Eamon emerged from the mist, frail but alive. Overjoyed, Maeve scooped him into her arms, and they shared a tender kiss as the sun dipped into the horizon, their love reunited at last.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Mae",
14 | "2": "ve",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " she",
75 | "63": " finally",
76 | "64": " spotted",
77 | "65": " its",
78 | "66": " shimmer",
79 | "67": "ing",
80 | "68": " shores",
81 | "69": ".",
82 | "70": " As",
83 | "71": " she",
84 | "72": " stepped",
85 | "73": " onto",
86 | "74": " the",
87 | "75": " island",
88 | "76": ",",
89 | "77": " a",
90 | "78": " radiant",
91 | "79": " light",
92 | "80": " envelop",
93 | "81": "ed",
94 | "82": " her",
95 | "83": ",",
96 | "84": " and",
97 | "85": " she",
98 | "86": " wished",
99 | "87": " for",
100 | "88": " her",
101 | "89": " lost",
102 | "90": " love",
103 | "91": ",",
104 | "92": " the",
105 | "93": " brave",
106 | "94": " sailor",
107 | "95": ",",
108 | "96": " E",
109 | "97": "amon",
110 | "98": ",",
111 | "99": " to",
112 | "100": " return",
113 | "101": " to",
114 | "102": " her",
115 | "103": ".",
116 | "104": " Suddenly",
117 | "105": ",",
118 | "106": " E",
119 | "107": "amon",
120 | "108": " emerged",
121 | "109": " from",
122 | "110": " the",
123 | "111": " mist",
124 | "112": ",",
125 | "113": " frail",
126 | "114": " but",
127 | "115": " alive",
128 | "116": ".",
129 | "117": " Over",
130 | "118": "joy",
131 | "119": "ed",
132 | "120": ",",
133 | "121": " Mae",
134 | "122": "ve",
135 | "123": " sco",
136 | "124": "oped",
137 | "125": " him",
138 | "126": " into",
139 | "127": " her",
140 | "128": " arms",
141 | "129": ",",
142 | "130": " and",
143 | "131": " they",
144 | "132": " shared",
145 | "133": " a",
146 | "134": " tender",
147 | "135": " kiss",
148 | "136": " as",
149 | "137": " the",
150 | "138": " sun",
151 | "139": " dipped",
152 | "140": " into",
153 | "141": " the",
154 | "142": " horizon",
155 | "143": ",",
156 | "144": " their",
157 | "145": " love",
158 | "146": " reunited",
159 | "147": " at",
160 | "148": " last",
161 | "149": "."
162 | }
163 | }
--------------------------------------------------------------------------------
/outputs/story-name/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-name/rngstates_1.pt
--------------------------------------------------------------------------------
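A note on the story-name records above: intervention_1.json stores the factual token sequence in its token_list (its start_from field marks where regeneration resumes), while counterfactual_1.json stores the regenerated sequence. A quick way to see what an intervention changed is to diff the two stored token lists. The sketch below is a standalone illustration, not one of the repository's scripts; the paths follow the outputs/ layout shown above.

    import json

    def tokens(record):
        # token_list maps stringified indices ("0", "1", ...) to token strings
        tl = record["token_list"]
        return [tl[str(i)] for i in range(len(tl))]

    with open("outputs/story-name/intervention_1.json") as f:
        factual_tokens = tokens(json.load(f))
    with open("outputs/story-name/counterfactual_1.json") as f:
        counterfactual_tokens = tokens(json.load(f))

    # first index at which the two sequences disagree
    pairs = zip(factual_tokens, counterfactual_tokens)
    i = next((k for k, (x, y) in enumerate(pairs) if x != y), None)
    if i is None:
        print("no divergence on the common prefix")
    else:
        print(f"first divergence at token {i}: "
              f"{factual_tokens[i]!r} vs {counterfactual_tokens[i]!r}")

For this pair the first divergence is at token 1 (" Ly" vs " Mae"), i.e. the intervened captain's name.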
/outputs/story-remove-word/counterfactual_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever lost at sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " ship",
22 | "10": ",",
23 | "11": " the",
24 | "12": " \"",
25 | "13": "M",
26 | "14": "ael",
27 | "15": "strom",
28 | "16": "'s",
29 | "17": " Fury",
30 | "18": "\",",
31 | "19": " g",
32 | "20": "azing",
33 | "21": " out",
34 | "22": " at",
35 | "23": " the",
36 | "24": " endless",
37 | "25": " sea",
38 | "26": ".",
39 | "27": " For",
40 | "28": " years",
41 | "29": ",",
42 | "30": " she",
43 | "31": " had",
44 | "32": " sailed",
45 | "33": " the",
46 | "34": " seven",
47 | "35": " seas",
48 | "36": ",",
49 | "37": " seeking",
50 | "38": " the",
51 | "39": " legendary",
52 | "40": " Golden",
53 | "41": " Isle",
54 | "42": ",",
55 | "43": " said",
56 | "44": " to",
57 | "45": " grant",
58 | "46": " a",
59 | "47": " single",
60 | "48": " wish",
61 | "49": " to",
62 | "50": " whoever",
63 | "51": " found",
64 | "52": " it",
65 | "53": ".",
66 | "54": " After",
67 | "55": " a",
68 | "56": " lifetime",
69 | "57": " of",
70 | "58": " searching",
71 | "59": ",",
72 | "60": " Ly",
73 | "61": "ra",
74 | "62": " finally",
75 | "63": " spotted",
76 | "64": " the",
77 | "65": " island",
78 | "66": " on",
79 | "67": " the",
80 | "68": " horizon",
81 | "69": ".\n\n",
82 | "70": "As",
83 | "71": " she",
84 | "72": " stepped",
85 | "73": " foot",
86 | "74": " on",
87 | "75": " its",
88 | "76": " sandy",
89 | "77": " shores",
90 | "78": ",",
91 | "79": " a",
92 | "80": " warm",
93 | "81": " glow",
94 | "82": " envelop",
95 | "83": "ed",
96 | "84": " her",
97 | "85": ".",
98 | "86": " Before",
99 | "87": " her",
100 | "88": ",",
101 | "89": " a",
102 | "90": " chest",
103 | "91": " overflowing",
104 | "92": " with",
105 | "93": " glitter",
106 | "94": "ing",
107 | "95": " gold",
108 | "96": " and",
109 | "97": " jewels",
110 | "98": " lay",
111 | "99": " at",
112 | "100": " her",
113 | "101": " feet",
114 | "102": ".",
115 | "103": " Ly",
116 | "104": "ra",
117 | "105": "'s",
118 | "106": " heart",
119 | "107": " soared",
120 | "108": " as",
121 | "109": " she",
122 | "110": " reached",
123 | "111": " for",
124 | "112": " the",
125 | "113": " treasure",
126 | "114": ".",
127 | "115": " But",
128 | "116": ",",
129 | "117": " as",
130 | "118": " she",
131 | "119": " touched",
132 | "120": " the",
133 | "121": " gold",
134 | "122": ",",
135 | "123": " the",
136 | "124": " is",
137 | "125": "le",
138 | "126": " began",
139 | "127": " to",
140 | "128": " fade",
141 | "129": " away",
142 | "130": ",",
143 | "131": " leaving",
144 | "132": " her",
145 | "133": " with",
146 | "134": " nothing",
147 | "135": " but",
148 | "136": " a",
149 | "137": " whispered",
150 | "138": " promise",
151 | "139": ":",
152 | "140": " \"",
153 | "141": "The",
154 | "142": " true",
155 | "143": " treasure",
156 | "144": " was",
157 | "145": " the",
158 | "146": " journey",
159 | "147": ",",
160 | "148": " not",
161 | "149": " the",
162 | "150": " gold",
163 | "151": ".\"",
164 | "152": " Ly",
165 | "153": "ra",
166 | "154": "'s",
167 | "155": " eyes",
168 | "156": " wel",
169 | "157": "led",
170 | "158": " up",
171 | "159": " with",
172 | "160": " tears",
173 | "161": " as",
174 | "162": " she",
175 | "163": " realized",
176 | "164": " the",
177 | "165": " bitter",
178 | "166": " truth",
179 | "167": " \u2013",
180 | "168": " she",
181 | "169": " had",
182 | "170": " sacrificed",
183 | "171": " everything",
184 | "172": " for",
185 | "173": " fleeting",
186 | "174": " riches",
187 | "175": ",",
188 | "176": " and",
189 | "177": " lost",
190 | "178": " the",
191 | "179": " love",
192 | "180": " of",
193 | "181": " her",
194 | "182": " crew",
195 | "183": ",",
196 | "184": " her",
197 | "185": " family",
198 | "186": ",",
199 | "187": " and",
200 | "188": " herself",
201 | "189": ".",
202 | "190": " The",
203 | "191": " \"",
204 | "192": "M",
205 | "193": "ael",
206 | "194": "strom",
207 | "195": "'s",
208 | "196": " Fury",
209 | "197": "\"",
210 | "198": " sailed",
211 | "199": " away",
212 | "200": ",",
213 | "201": " a",
214 | "202": " lonely",
215 | "203": " shadow",
216 | "204": " on",
217 | "205": " the",
218 | "206": " vast",
219 | "207": " ocean",
220 | "208": ",",
221 | "209": " carrying",
222 | "210": " a",
223 | "211": " captain",
224 | "212": " forever",
225 | "213": " lost",
226 | "214": " at",
227 | "215": " sea",
228 | "216": "."
229 | }
230 | }
--------------------------------------------------------------------------------
/outputs/story-remove-word/factual.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5
10 | }
--------------------------------------------------------------------------------
/outputs/story-remove-word/intervention_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5,
10 | "start_from": 29,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " and",
111 | "99": " jewels",
112 | "100": " lay",
113 | "101": " at",
114 | "102": " her",
115 | "103": " feet",
116 | "104": ".",
117 | "105": " Ly",
118 | "106": "ra",
119 | "107": "'s",
120 | "108": " heart",
121 | "109": " soared",
122 | "110": " as",
123 | "111": " she",
124 | "112": " reached",
125 | "113": " for",
126 | "114": " the",
127 | "115": " treasure",
128 | "116": ".",
129 | "117": " But",
130 | "118": ",",
131 | "119": " as",
132 | "120": " she",
133 | "121": " touched",
134 | "122": " the",
135 | "123": " gold",
136 | "124": ",",
137 | "125": " the",
138 | "126": " is",
139 | "127": "le",
140 | "128": " began",
141 | "129": " to",
142 | "130": " fade",
143 | "131": " away",
144 | "132": ",",
145 | "133": " leaving",
146 | "134": " her",
147 | "135": " with",
148 | "136": " nothing",
149 | "137": " but",
150 | "138": " a",
151 | "139": " whispered",
152 | "140": " promise",
153 | "141": ":",
154 | "142": " \"",
155 | "143": "The",
156 | "144": " true",
157 | "145": " treasure",
158 | "146": " was",
159 | "147": " the",
160 | "148": " journey",
161 | "149": ",",
162 | "150": " not",
163 | "151": " the",
164 | "152": " gold",
165 | "153": ".\"",
166 | "154": " Ly",
167 | "155": "ra",
168 | "156": "'s",
169 | "157": " eyes",
170 | "158": " wel",
171 | "159": "led",
172 | "160": " up",
173 | "161": " with",
174 | "162": " tears",
175 | "163": " as",
176 | "164": " she",
177 | "165": " realized",
178 | "166": " the",
179 | "167": " bitter",
180 | "168": " truth",
181 | "169": " \u2013",
182 | "170": " she",
183 | "171": " had",
184 | "172": " sacrificed",
185 | "173": " everything",
186 | "174": " for",
187 | "175": " fleeting",
188 | "176": " riches",
189 | "177": ",",
190 | "178": " and",
191 | "179": " lost",
192 | "180": " the",
193 | "181": " love",
194 | "182": " of",
195 | "183": " her",
196 | "184": " crew",
197 | "185": ",",
198 | "186": " her",
199 | "187": " family",
200 | "188": ",",
201 | "189": " and",
202 | "190": " herself",
203 | "191": ".",
204 | "192": " The",
205 | "193": " \"",
206 | "194": "M",
207 | "195": "ael",
208 | "196": "strom",
209 | "197": "'s",
210 | "198": " Fury",
211 | "199": "\"",
212 | "200": " sailed",
213 | "201": " away",
214 | "202": ",",
215 | "203": " a",
216 | "204": " lonely",
217 | "205": " shadow",
218 | "206": " on",
219 | "207": " the",
220 | "208": " vast",
221 | "209": " ocean",
222 | "210": ",",
223 | "211": " carrying",
224 | "212": " a",
225 | "213": " captain",
226 | "214": " forever",
227 | "215": " haunted",
228 | "216": " by",
229 | "217": " the",
230 | "218": " what",
231 | "219": "-",
232 | "220": "ifs",
233 | "221": "."
234 | }
235 | }
--------------------------------------------------------------------------------
/outputs/story-remove-word/intervention_2.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever lost at sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " ship",
22 | "10": ",",
23 | "11": " the",
24 | "12": " \"",
25 | "13": "M",
26 | "14": "ael",
27 | "15": "strom",
28 | "16": "'s",
29 | "17": " Fury",
30 | "18": "\",",
31 | "19": " g",
32 | "20": "azing",
33 | "21": " out",
34 | "22": " at",
35 | "23": " the",
36 | "24": " endless",
37 | "25": " sea",
38 | "26": ".",
39 | "27": " For",
40 | "28": " years",
41 | "29": ",",
42 | "30": " she",
43 | "31": " had",
44 | "32": " sailed",
45 | "33": " the",
46 | "34": " seven",
47 | "35": " seas",
48 | "36": ",",
49 | "37": " seeking",
50 | "38": " the",
51 | "39": " legendary",
52 | "40": " Golden",
53 | "41": " Isle",
54 | "42": ",",
55 | "43": " said",
56 | "44": " to",
57 | "45": " grant",
58 | "46": " a",
59 | "47": " single",
60 | "48": " wish",
61 | "49": " to",
62 | "50": " whoever",
63 | "51": " found",
64 | "52": " it",
65 | "53": ".",
66 | "54": " After",
67 | "55": " a",
68 | "56": " lifetime",
69 | "57": " of",
70 | "58": " searching",
71 | "59": ",",
72 | "60": " Ly",
73 | "61": "ra",
74 | "62": " finally",
75 | "63": " spotted",
76 | "64": " the",
77 | "65": " island",
78 | "66": " on",
79 | "67": " the",
80 | "68": " horizon",
81 | "69": ".\n\n",
82 | "70": "As",
83 | "71": " she",
84 | "72": " stepped",
85 | "73": " foot",
86 | "74": " on",
87 | "75": " its",
88 | "76": " sandy",
89 | "77": " shores",
90 | "78": ",",
91 | "79": " a",
92 | "80": " warm",
93 | "81": " glow",
94 | "82": " envelop",
95 | "83": "ed",
96 | "84": " her",
97 | "85": ".",
98 | "86": " Before",
99 | "87": " her",
100 | "88": ",",
101 | "89": " a",
102 | "90": " chest",
103 | "91": " overflowing",
104 | "92": " with",
105 | "93": " glitter",
106 | "94": "ing",
107 | "95": " gold",
108 | "96": " and",
109 | "97": " jewels",
110 | "98": " lay",
111 | "99": " at",
112 | "100": " her",
113 | "101": " feet",
114 | "102": ".",
115 | "103": " Ly",
116 | "104": "ra",
117 | "105": "'s",
118 | "106": " heart",
119 | "107": " soared",
120 | "108": " as",
121 | "109": " she",
122 | "110": " reached",
123 | "111": " for",
124 | "112": " the",
125 | "113": " treasure",
126 | "114": ".",
127 | "115": " But",
128 | "116": ",",
129 | "117": " as",
130 | "118": " she",
131 | "119": " touched",
132 | "120": " the",
133 | "121": " gold",
134 | "122": ",",
135 | "123": " the",
136 | "124": " is",
137 | "125": "le",
138 | "126": " began",
139 | "127": " to",
140 | "128": " fade",
141 | "129": " away",
142 | "130": ",",
143 | "131": " leaving",
144 | "132": " her",
145 | "133": " with",
146 | "134": " nothing",
147 | "135": " but",
148 | "136": " a",
149 | "137": " whispered",
150 | "138": " promise",
151 | "139": ":",
152 | "140": " \"",
153 | "141": "The",
154 | "142": " true",
155 | "143": " treasure",
156 | "144": " was",
157 | "145": " the",
158 | "146": " journey",
159 | "147": ",",
160 | "148": " not",
161 | "149": " the",
162 | "150": " gold",
163 | "151": ".\"",
164 | "152": " Ly",
165 | "153": "ra",
166 | "154": "'s",
167 | "155": " eyes",
168 | "156": " wel",
169 | "157": "led",
170 | "158": " up",
171 | "159": " with",
172 | "160": " tears",
173 | "161": " as",
174 | "162": " she",
175 | "163": " realized",
176 | "164": " the",
177 | "165": " bitter",
178 | "166": " truth",
179 | "167": " \u2013",
180 | "168": " she",
181 | "169": " had",
182 | "170": " sacrificed",
183 | "171": " everything",
184 | "172": " for",
185 | "173": " fleeting",
186 | "174": " riches",
187 | "175": ",",
188 | "176": " and",
189 | "177": " lost",
190 | "178": " the",
191 | "179": " love",
192 | "180": " of",
193 | "181": " her",
194 | "182": " crew",
195 | "183": ",",
196 | "184": " her",
197 | "185": " family",
198 | "186": ",",
199 | "187": " and",
200 | "188": " herself",
201 | "189": ".",
202 | "190": " The",
203 | "191": " \"",
204 | "192": "M",
205 | "193": "ael",
206 | "194": "strom",
207 | "195": "'s",
208 | "196": " Fury",
209 | "197": "\"",
210 | "198": " sailed",
211 | "199": " away",
212 | "200": ",",
213 | "201": " a",
214 | "202": " lonely",
215 | "203": " shadow",
216 | "204": " on",
217 | "205": " the",
218 | "206": " vast",
219 | "207": " ocean",
220 | "208": ",",
221 | "209": " carrying",
222 | "210": " a",
223 | "211": " captain",
224 | "212": " forever",
225 | "213": " lost",
226 | "214": " at",
227 | "215": " sea",
228 | "216": "."
229 | }
230 | }
--------------------------------------------------------------------------------
/outputs/story-remove-word/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-remove-word/rngstates_1.pt
--------------------------------------------------------------------------------
/outputs/story-ship/counterfactual_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold coins appeared. But, just as she reached for it, a storm brewed, and a rival captain, seeking the treasure for himself, appeared. Lyra fought valiantly, but in the end, she was defeated.\n\nAs she lay on the beach, her ship destroyed and her crew scattered, Lyra realized her wish had been granted \u2013 not for wealth or power, but for the courage to sail through life, no matter the outcome.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "C",
28 | "16": "ycl",
29 | "17": "one",
30 | "18": "'s",
31 | "19": " Rage",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " coins",
111 | "99": " appeared",
112 | "100": ".",
113 | "101": " But",
114 | "102": ",",
115 | "103": " just",
116 | "104": " as",
117 | "105": " she",
118 | "106": " reached",
119 | "107": " for",
120 | "108": " it",
121 | "109": ",",
122 | "110": " a",
123 | "111": " storm",
124 | "112": " brewed",
125 | "113": ",",
126 | "114": " and",
127 | "115": " a",
128 | "116": " rival",
129 | "117": " captain",
130 | "118": ",",
131 | "119": " seeking",
132 | "120": " the",
133 | "121": " treasure",
134 | "122": " for",
135 | "123": " himself",
136 | "124": ",",
137 | "125": " appeared",
138 | "126": ".",
139 | "127": " Ly",
140 | "128": "ra",
141 | "129": " fought",
142 | "130": " val",
143 | "131": "iant",
144 | "132": "ly",
145 | "133": ",",
146 | "134": " but",
147 | "135": " in",
148 | "136": " the",
149 | "137": " end",
150 | "138": ",",
151 | "139": " she",
152 | "140": " was",
153 | "141": " defeated",
154 | "142": ".\n\n",
155 | "143": "As",
156 | "144": " she",
157 | "145": " lay",
158 | "146": " on",
159 | "147": " the",
160 | "148": " beach",
161 | "149": ",",
162 | "150": " her",
163 | "151": " ship",
164 | "152": " destroyed",
165 | "153": " and",
166 | "154": " her",
167 | "155": " crew",
168 | "156": " scattered",
169 | "157": ",",
170 | "158": " Ly",
171 | "159": "ra",
172 | "160": " realized",
173 | "161": " her",
174 | "162": " wish",
175 | "163": " had",
176 | "164": " been",
177 | "165": " granted",
178 | "166": " \u2013",
179 | "167": " not",
180 | "168": " for",
181 | "169": " wealth",
182 | "170": " or",
183 | "171": " power",
184 | "172": ",",
185 | "173": " but",
186 | "174": " for",
187 | "175": " the",
188 | "176": " courage",
189 | "177": " to",
190 | "178": " sail",
191 | "179": " through",
192 | "180": " life",
193 | "181": ",",
194 | "182": " no",
195 | "183": " matter",
196 | "184": " the",
197 | "185": " outcome",
198 | "186": "."
199 | }
200 | }
--------------------------------------------------------------------------------
/outputs/story-ship/factual.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Maelstrom's Fury\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold and jewels lay at her feet. Lyra's heart soared as she reached for the treasure. But, as she touched the gold, the isle began to fade away, leaving her with nothing but a whispered promise: \"The true treasure was the journey, not the gold.\" Lyra's eyes welled up with tears as she realized the bitter truth \u2013 she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself. The \"Maelstrom's Fury\" sailed away, a lonely shadow on the vast ocean, carrying a captain forever haunted by the what-ifs.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5
10 | }
--------------------------------------------------------------------------------
/outputs/story-ship/intervention_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "sampler_type": "vocabulary",
8 | "top_p": 0.9,
9 | "top_k": 5,
10 | "start_from": 29,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "M",
28 | "16": "ael",
29 | "17": "strom",
30 | "18": "'s",
31 | "19": " Fury",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " and",
111 | "99": " jewels",
112 | "100": " lay",
113 | "101": " at",
114 | "102": " her",
115 | "103": " feet",
116 | "104": ".",
117 | "105": " Ly",
118 | "106": "ra",
119 | "107": "'s",
120 | "108": " heart",
121 | "109": " soared",
122 | "110": " as",
123 | "111": " she",
124 | "112": " reached",
125 | "113": " for",
126 | "114": " the",
127 | "115": " treasure",
128 | "116": ".",
129 | "117": " But",
130 | "118": ",",
131 | "119": " as",
132 | "120": " she",
133 | "121": " touched",
134 | "122": " the",
135 | "123": " gold",
136 | "124": ",",
137 | "125": " the",
138 | "126": " is",
139 | "127": "le",
140 | "128": " began",
141 | "129": " to",
142 | "130": " fade",
143 | "131": " away",
144 | "132": ",",
145 | "133": " leaving",
146 | "134": " her",
147 | "135": " with",
148 | "136": " nothing",
149 | "137": " but",
150 | "138": " a",
151 | "139": " whispered",
152 | "140": " promise",
153 | "141": ":",
154 | "142": " \"",
155 | "143": "The",
156 | "144": " true",
157 | "145": " treasure",
158 | "146": " was",
159 | "147": " the",
160 | "148": " journey",
161 | "149": ",",
162 | "150": " not",
163 | "151": " the",
164 | "152": " gold",
165 | "153": ".\"",
166 | "154": " Ly",
167 | "155": "ra",
168 | "156": "'s",
169 | "157": " eyes",
170 | "158": " wel",
171 | "159": "led",
172 | "160": " up",
173 | "161": " with",
174 | "162": " tears",
175 | "163": " as",
176 | "164": " she",
177 | "165": " realized",
178 | "166": " the",
179 | "167": " bitter",
180 | "168": " truth",
181 | "169": " \u2013",
182 | "170": " she",
183 | "171": " had",
184 | "172": " sacrificed",
185 | "173": " everything",
186 | "174": " for",
187 | "175": " fleeting",
188 | "176": " riches",
189 | "177": ",",
190 | "178": " and",
191 | "179": " lost",
192 | "180": " the",
193 | "181": " love",
194 | "182": " of",
195 | "183": " her",
196 | "184": " crew",
197 | "185": ",",
198 | "186": " her",
199 | "187": " family",
200 | "188": ",",
201 | "189": " and",
202 | "190": " herself",
203 | "191": ".",
204 | "192": " The",
205 | "193": " \"",
206 | "194": "M",
207 | "195": "ael",
208 | "196": "strom",
209 | "197": "'s",
210 | "198": " Fury",
211 | "199": "\"",
212 | "200": " sailed",
213 | "201": " away",
214 | "202": ",",
215 | "203": " a",
216 | "204": " lonely",
217 | "205": " shadow",
218 | "206": " on",
219 | "207": " the",
220 | "208": " vast",
221 | "209": " ocean",
222 | "210": ",",
223 | "211": " carrying",
224 | "212": " a",
225 | "213": " captain",
226 | "214": " forever",
227 | "215": " haunted",
228 | "216": " by",
229 | "217": " the",
230 | "218": " what",
231 | "219": "-",
232 | "220": "ifs",
233 | "221": "."
234 | }
235 | }
--------------------------------------------------------------------------------
/outputs/story-ship/intervention_2.json:
--------------------------------------------------------------------------------
1 | {
2 | "system": "Be creative and keep your response as short as possible.",
3 | "query": "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
4 | "response": "Captain Lyra stood at the helm of her trusty ship, the \"Cyclone's Rage\", gazing out at the endless sea. For years, she had sailed the seven seas, seeking the legendary Golden Isle, said to grant a single wish to whoever found it. After a lifetime of searching, Lyra finally spotted the island on the horizon.\n\nAs she stepped foot on its sandy shores, a warm glow enveloped her. Before her, a chest overflowing with glittering gold coins appeared. But, just as she reached for it, a storm brewed, and a rival captain, seeking the treasure for himself, appeared. Lyra fought valiantly, but in the end, she was defeated.\n\nAs she lay on the beach, her ship destroyed and her crew scattered, Lyra realized her wish had been granted \u2013 not for wealth or power, but for the courage to sail through life, no matter the outcome.",
5 | "seed": 2,
6 | "temperature": 0.9,
7 | "top_p": 0.9,
8 | "top_k": 5,
9 | "sampler_type": "vocabulary",
10 | "start_from": 0,
11 | "token_list": {
12 | "0": "Captain",
13 | "1": " Ly",
14 | "2": "ra",
15 | "3": " stood",
16 | "4": " at",
17 | "5": " the",
18 | "6": " helm",
19 | "7": " of",
20 | "8": " her",
21 | "9": " trust",
22 | "10": "y",
23 | "11": " ship",
24 | "12": ",",
25 | "13": " the",
26 | "14": " \"",
27 | "15": "C",
28 | "16": "ycl",
29 | "17": "one",
30 | "18": "'s",
31 | "19": " Rage",
32 | "20": "\",",
33 | "21": " g",
34 | "22": "azing",
35 | "23": " out",
36 | "24": " at",
37 | "25": " the",
38 | "26": " endless",
39 | "27": " sea",
40 | "28": ".",
41 | "29": " For",
42 | "30": " years",
43 | "31": ",",
44 | "32": " she",
45 | "33": " had",
46 | "34": " sailed",
47 | "35": " the",
48 | "36": " seven",
49 | "37": " seas",
50 | "38": ",",
51 | "39": " seeking",
52 | "40": " the",
53 | "41": " legendary",
54 | "42": " Golden",
55 | "43": " Isle",
56 | "44": ",",
57 | "45": " said",
58 | "46": " to",
59 | "47": " grant",
60 | "48": " a",
61 | "49": " single",
62 | "50": " wish",
63 | "51": " to",
64 | "52": " whoever",
65 | "53": " found",
66 | "54": " it",
67 | "55": ".",
68 | "56": " After",
69 | "57": " a",
70 | "58": " lifetime",
71 | "59": " of",
72 | "60": " searching",
73 | "61": ",",
74 | "62": " Ly",
75 | "63": "ra",
76 | "64": " finally",
77 | "65": " spotted",
78 | "66": " the",
79 | "67": " island",
80 | "68": " on",
81 | "69": " the",
82 | "70": " horizon",
83 | "71": ".\n\n",
84 | "72": "As",
85 | "73": " she",
86 | "74": " stepped",
87 | "75": " foot",
88 | "76": " on",
89 | "77": " its",
90 | "78": " sandy",
91 | "79": " shores",
92 | "80": ",",
93 | "81": " a",
94 | "82": " warm",
95 | "83": " glow",
96 | "84": " envelop",
97 | "85": "ed",
98 | "86": " her",
99 | "87": ".",
100 | "88": " Before",
101 | "89": " her",
102 | "90": ",",
103 | "91": " a",
104 | "92": " chest",
105 | "93": " overflowing",
106 | "94": " with",
107 | "95": " glitter",
108 | "96": "ing",
109 | "97": " gold",
110 | "98": " coins",
111 | "99": " appeared",
112 | "100": ".",
113 | "101": " But",
114 | "102": ",",
115 | "103": " just",
116 | "104": " as",
117 | "105": " she",
118 | "106": " reached",
119 | "107": " for",
120 | "108": " it",
121 | "109": ",",
122 | "110": " a",
123 | "111": " storm",
124 | "112": " brewed",
125 | "113": ",",
126 | "114": " and",
127 | "115": " a",
128 | "116": " rival",
129 | "117": " captain",
130 | "118": ",",
131 | "119": " seeking",
132 | "120": " the",
133 | "121": " treasure",
134 | "122": " for",
135 | "123": " himself",
136 | "124": ",",
137 | "125": " appeared",
138 | "126": ".",
139 | "127": " Ly",
140 | "128": "ra",
141 | "129": " fought",
142 | "130": " val",
143 | "131": "iant",
144 | "132": "ly",
145 | "133": ",",
146 | "134": " but",
147 | "135": " in",
148 | "136": " the",
149 | "137": " end",
150 | "138": ",",
151 | "139": " she",
152 | "140": " was",
153 | "141": " defeated",
154 | "142": ".\n\n",
155 | "143": "As",
156 | "144": " she",
157 | "145": " lay",
158 | "146": " on",
159 | "147": " the",
160 | "148": " beach",
161 | "149": ",",
162 | "150": " her",
163 | "151": " ship",
164 | "152": " destroyed",
165 | "153": " and",
166 | "154": " her",
167 | "155": " crew",
168 | "156": " scattered",
169 | "157": ",",
170 | "158": " Ly",
171 | "159": "ra",
172 | "160": " realized",
173 | "161": " her",
174 | "162": " wish",
175 | "163": " had",
176 | "164": " been",
177 | "165": " granted",
178 | "166": " \u2013",
179 | "167": " not",
180 | "168": " for",
181 | "169": " wealth",
182 | "170": " or",
183 | "171": " power",
184 | "172": ",",
185 | "173": " but",
186 | "174": " for",
187 | "175": " the",
188 | "176": " courage",
189 | "177": " to",
190 | "178": " sail",
191 | "179": " through",
192 | "180": " life",
193 | "181": ",",
194 | "182": " no",
195 | "183": " matter",
196 | "184": " the",
197 | "185": " outcome",
198 | "186": "."
199 | }
200 | }
--------------------------------------------------------------------------------
/outputs/story-ship/rngstates_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/outputs/story-ship/rngstates_1.pt
--------------------------------------------------------------------------------
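The rngstates_1.pt files linked throughout outputs/ are binary PyTorch files holding the saved random-number-generator states for each run (they are written when the generation scripts are invoked with genstates enabled, per scripts/census_queries.py below). A minimal inspection sketch follows; the exact structure of the saved object is an assumption to verify against the code under src/.

    import torch

    # load on CPU regardless of the device the states were captured on
    states = torch.load("outputs/story-ship/rngstates_1.pt", map_location="cpu")
    print(type(states))

    # if this turns out to be a CPU RNG state tensor (unverified assumption),
    # it could be restored before resampling with:
    #   torch.set_rng_state(states)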
/requirements.txt:
--------------------------------------------------------------------------------
1 | asttokens==2.4.1
2 | blobfile==3.0.0
3 | certifi==2024.8.30
4 | charset-normalizer==3.3.2
5 | comm==0.2.2
6 | contourpy==1.3.0
7 | cramjam==2.8.4
8 | cycler==0.12.1
9 | debugpy==1.8.6
10 | decorator==5.1.1
11 | executing==2.1.0
12 | fairscale==0.4.13
13 | fastparquet==2024.5.0
14 | filelock==3.16.1
15 | fire==0.6.0
16 | fonttools==4.54.1
17 | fsspec==2024.9.0
18 | idna==3.10
19 | ipykernel==6.29.5
20 | ipython==8.27.0
21 | jedi==0.19.1
22 | Jinja2==3.1.4
23 | jupyter_client==8.6.3
24 | jupyter_core==5.7.2
25 | kiwisolver==1.4.7
26 | Levenshtein==0.26.0
27 | lxml==5.3.0
28 | MarkupSafe==2.1.5
29 | matplotlib==3.9.2
30 | matplotlib-inline==0.1.7
31 | mpmath==1.3.0
32 | nest-asyncio==1.6.0
33 | networkx==3.3
34 | numpy==2.1.1
35 | nvidia-cublas-cu12==12.1.3.1
36 | nvidia-cuda-cupti-cu12==12.1.105
37 | nvidia-cuda-nvrtc-cu12==12.1.105
38 | nvidia-cuda-runtime-cu12==12.1.105
39 | nvidia-cudnn-cu12==9.1.0.70
40 | nvidia-cufft-cu12==11.0.2.54
41 | nvidia-curand-cu12==10.3.2.106
42 | nvidia-cusolver-cu12==11.4.5.107
43 | nvidia-cusparse-cu12==12.1.0.106
44 | nvidia-nccl-cu12==2.20.5
45 | nvidia-nvjitlink-cu12==12.6.68
46 | nvidia-nvtx-cu12==12.1.105
47 | packaging==24.1
48 | pandas==2.2.3
49 | parso==0.8.4
50 | pexpect==4.9.0
51 | pillow==10.4.0
52 | platformdirs==4.3.6
53 | prompt_toolkit==3.0.48
54 | psutil==6.0.0
55 | ptyprocess==0.7.0
56 | pure_eval==0.2.3
57 | pycryptodomex==3.20.0
58 | Pygments==2.18.0
59 | pyparsing==3.1.4
60 | python-dateutil==2.9.0.post0
61 | pytz==2024.2
62 | pyzmq==26.2.0
63 | RapidFuzz==3.10.0
64 | regex==2024.9.11
65 | requests==2.32.3
66 | seaborn==0.13.2
67 | six==1.16.0
68 | stack-data==0.6.3
69 | sympy==1.13.3
70 | termcolor==2.4.0
71 | tiktoken==0.7.0
72 | torch==2.4.1
73 | tornado==6.4.1
74 | tqdm==4.66.5
75 | traitlets==5.14.3
76 | triton==3.0.0
77 | typing_extensions==4.12.2
78 | tzdata==2024.2
79 | urllib3==2.2.3
80 | wcwidth==0.2.13
--------------------------------------------------------------------------------
/scripts/bias.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import argparse
3 |
4 | parser=argparse.ArgumentParser()
5 | parser.add_argument("--i", default=1)
6 | parser.add_argument("--direct", default=False)
7 | parser.add_argument("--prior", default=False)
8 | parser.add_argument("--model_family", default='llama3', choices=['llama3', 'mistral'])
9 | parser.add_argument("--attribute_to_intervene", default='Sex', choices=['Sex', 'Race'])
10 | args = parser.parse_args()
11 |
12 | nproc_per_node=1
13 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/bias.py']
14 | for k, v in vars(args).items():
15 | cmd.append(f"--{k}")
16 | cmd.append(str(v))
17 | subprocess.run(args=cmd)
--------------------------------------------------------------------------------
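Note on the wrapper above: `vars(args)` preserves the declaration order of the arguments, and every value is forwarded through `str(v)`, so booleans reach `src/bias.py` as the strings 'True'/'False' and must be compared as strings there. As a minimal sketch (hypothetical invocation), running `python scripts/bias.py --model_family mistral` launches:

    # Equivalent argv assembled by the wrapper for: python scripts/bias.py --model_family mistral
    expected = [
        'torchrun', '--nproc_per_node', '1', 'src/bias.py',
        '--i', '1', '--direct', 'False', '--prior', 'False',
        '--model_family', 'mistral', '--attribute_to_intervene', 'Sex',
    ]
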
/scripts/census_queries.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import argparse
3 |
4 | parser=argparse.ArgumentParser()
5 | parser.add_argument("--model_family", default='llama3', choices=['llama3', 'mistral'])
6 | args = parser.parse_args()
7 | nproc_per_node = 1
8 |
9 | model_family = args.model_family
10 |
11 | if model_family=="llama3":
12 | seeds=[1094356, 2345656, 823845969964454]
13 | weights_path = "src/llama3/pretrained/Meta-Llama-3-8B-Instruct/"
14 | tokenizer_path = "src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model"
15 | elif model_family=="mistral":
16 | seeds=[34435326297321, 124638290, 823845969964454]
17 | weights_path = "src/mistral-inference/8B-Instruct/"
18 |     tokenizer_path = "src/mistral-inference/8B-Instruct/"
19 |
20 |
21 |
22 | for i, seed in enumerate(seeds):
23 | params = {
24 | "ckpt_dir": weights_path,
25 | "tokenizer_path": tokenizer_path,
26 | "max_seq_len": 8192,
27 | "max_batch_size": 2,
28 | "seed": seed,
29 | "temperature": 0.8,
30 | "query": '"Generate census data of 50 fictional people."',
31 |         "exp_name": f"{model_family}/census{i + 1}", # name of the experiment -- results are saved under outputs/exp_name
32 |         "genstates": "True", # if True, the script saves the states of the random number generator
33 |         "system": '"Return only the following information: Age, Sex, Citizenship, Race, Ethnicity, Marital Status, Number of Children, Occupation, Income, Education. For Race, choose only between following options: White American, Black or African American, American Indian or Alaska Native, Asian American, Native Hawaiian or Other Pacific Islander, Other or Two or more races (multiracial). For Ethnicity, choose only between following options: Non-Hispanic/Latino or Hispanic/Latino. Return a list in json format delimited by \\"```\\"."',
34 | "model_family": model_family
35 | }
36 |
37 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/single_query.py']
38 | for j, k_v in enumerate(params.items()):
39 | k, v = k_v
40 |         if j >= 4:  # pass the later params as --k=v so quoted multi-word values stay attached to their flag
41 |             cmd.append(f"--{k}={v}")
42 |         else:
43 |             cmd.append(f"--{k}")
44 |             cmd.append(str(v))
45 |     cmd = ' '.join(cmd)  # joined into a single string because the command is run through the shell
46 |     subprocess.run(args=cmd, shell=True)
47 |
--------------------------------------------------------------------------------
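Because the command above is joined into one string and run with `shell=True`, the shell re-parses the embedded quotes around `query` and `system`; writing the later parameters as `--k=v` keeps each multi-word value attached to its flag. A minimal sketch of an alternative that sidesteps shell quoting entirely (a hypothetical variant, not what the script above does):

    import subprocess

    # Quoting-free variant: keep the argument list and skip shell=True, so each
    # value -- however many spaces or quotes it contains -- is a single argv entry.
    params = {"query": "Generate census data of 50 fictional people.", "seed": 1094356}
    cmd = ["torchrun", "--nproc_per_node", "1", "src/single_query.py"]
    for k, v in params.items():
        cmd.append(f"--{k}={v}")
    subprocess.run(cmd)
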
/scripts/cf_query.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | nproc_per_node = 1
3 |
4 | params = {
5 | 'max_seq_len': 2048,
6 | 'max_batch_size': 2,
7 | 'exp_name': "story-test",
8 | 'iteration': 1,
9 | 'prior': False,
10 | }
11 |
12 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/cf_query.py']
13 | for k, v in params.items():
14 | cmd.append(f"--{k}")
15 | cmd.append(str(v))
16 |
17 | subprocess.run(args=cmd)
--------------------------------------------------------------------------------
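This wrapper regenerates from a previously saved intervention: iteration i reads `outputs/<exp_name>/intervention_i.json` and `rngstates_i.pt`, and `src/cf_query.py` writes `counterfactual_i.json` plus `intervention_{i+1}.json`, so interventions can be chained. A minimal sketch of the resulting file flow (paths per the loaders in `src/cf_query.py`, with a hypothetical exp_name):

    # Hypothetical iteration chain for exp_name="story-test":
    for i in (1, 2):
        reads = [f"outputs/story-test/intervention_{i}.json",
                 f"outputs/story-test/rngstates_{i}.pt"]
        writes = [f"outputs/story-test/counterfactual_{i}.json",
                  f"outputs/story-test/intervention_{i + 1}.json"]
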
/scripts/stability.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 |
4 | # Constant parameters
5 | # llama3
6 | ckpt_path=f"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/"
7 | tokenizer_path=f"{ckpt_path}/tokenizer.model"
8 | model_family = 'llama3'
9 |
10 | # mistral
11 | # ckpt_path = f"src/mistral-inference/8B-Instruct/"
12 | # tokenizer_path = f"{ckpt_path}"
13 | # model_family = 'mistral'
14 |
15 | seed = 42
16 | system = "Keep your replies short and to the point."
17 | top_p = [0.75,0.9,0.95,0.99,0.999]
18 | top_k = [2,3,5,10,100]
19 | temperature = [0.0,0.2,0.4,0.6,0.8,1.0]
20 | input_file = f"data/questions.parquet"
21 | output_dir = "outputs/stability/llama3" # default output directory -- overridden below based on model_family and categorical
22 | chunk_size = 25
23 | num_interventions = 2
24 | categorical = False
25 |
26 | parser=argparse.ArgumentParser()
27 | parser.add_argument("--ckpt_dir", default=ckpt_path)
28 | parser.add_argument("--tokenizer_path", default=tokenizer_path)
29 | parser.add_argument("--max_seq_len", default=2048, type=int)
30 | parser.add_argument("--max_batch_size", default=2, type=int)
31 | parser.add_argument("--seed", default=seed, type=int,
32 | help="the seed for the initialization of the random number generator")
33 | parser.add_argument("--input_file", default=input_file,
34 | help="path to file containing input queries")
35 | parser.add_argument("--output_dir", default=output_dir,
36 | help="path to output directory")
37 | parser.add_argument("--output_file_params", default="params",
38 | help="file name to store experiment parameters")
39 | parser.add_argument("--output_file_responses", default="stability_strings",
40 | help="file name to store all model responses")
41 | parser.add_argument("--system", default=system,
42 | help="the system prompt you want to use")
43 | parser.add_argument("--temperature", default=temperature,
44 | help="the temperature parameter of the model")
45 | parser.add_argument("--top_p",default=top_p,
46 | help="the value of p for top-p samplers")
47 | parser.add_argument("--top_k",default=top_k,
48 | help="the value of k for the top-k sampler")
49 | parser.add_argument("--chunk_size",default=chunk_size,
50 | help="number of queries to be written to file at once")
51 | parser.add_argument("--num_interventions",default=num_interventions,
52 | help="number of interventions per query")
53 | parser.add_argument("--top_p_intervention", default=0.9,
54 | help="top-p value to select the post-intervention token")
55 | parser.add_argument("--intervention_seed", default=seed,
56 | help="Seed to select the post-intervention token")
57 | parser.add_argument("--intervention_position_seed", default=seed,
58 | help="Seed to select the intervention position")
59 | parser.add_argument("--model_family", default=model_family, choices=['llama3', 'mistral'],
60 | help="Please select the model family")
61 | parser.add_argument("--categorical", default=False,
62 | help="If true, runs inverse transform sampler")
63 |
64 | args = parser.parse_args()
65 | if args.model_family == 'llama3':
66 | args.ckpt_dir = f"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/"
67 | args.tokenizer_path = f"{args.ckpt_dir}/tokenizer.model"
68 | args.output_dir = "outputs/stability/llama3"
69 | elif args.model_family == 'mistral':
70 | args.ckpt_dir = f"src/mistral-inference/8B-Instruct/"
71 | args.tokenizer_path = f"{args.ckpt_dir}/tokenizer.model"
72 | args.output_dir = "outputs/stability/mistral"
73 | if str(args.categorical) == 'True':  # values arriving from the command line are strings
74 | args.top_p=[]
75 | args.top_k=[]
76 | args.output_dir="outputs/stability/categorical"
77 |
78 |
79 | nproc_per_node=1
80 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/stability.py']
81 | for k, v in vars(args).items():
82 | cmd.append(f"--{k}")
83 | cmd.append(str(v))
84 |
85 | print(cmd)
86 | subprocess.run(args=cmd)
87 |
88 |
89 |
--------------------------------------------------------------------------------
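The list-valued defaults above (`top_p`, `top_k`, `temperature`) are forwarded through `str(v)`, so the child process receives strings such as "[0.75, 0.9, 0.95, 0.99, 0.999]" and must parse them back. `src/stability.py` itself is not included in this dump; a minimal sketch of such parsing (an assumption about the receiving side) is:

    import ast

    def parse_list_arg(raw: str) -> list:
        """Turn an argv string like "[0.75, 0.9]" back into a Python list."""
        value = ast.literal_eval(raw)
        return value if isinstance(value, list) else [value]

    assert parse_list_arg("[0.75, 0.9, 0.95]") == [0.75, 0.9, 0.95]
    assert parse_list_arg("0.9") == [0.9]
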
/scripts/story_query.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | nproc_per_node = 1
3 |
4 |
5 | params = {
6 | 'query': "Tell me a fantasy story about a captain. The story should have either a happy or a sad ending.",
7 | 'system': "Be creative and keep your response as short as possible.",
8 | 'seed': 2,
9 | 'temperature': 0.9,
10 | 'max_seq_len': 2048,
11 | 'max_batch_size': 2,
12 | 'exp_name': "story-test",
13 | 'sampler_type': "vocabulary", # "vocabulary" or "top-p position" or "top-p token" or "top-k token"
14 |     'top_p': 0.9, # ignored unless sampler_type is a "top-p" sampler
15 |     'top_k': 5, # ignored unless sampler_type is the "top-k token" sampler
16 | 'genstates': "True",
17 | 'ckpt_dir':"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/",
18 | 'tokenizer_path':"src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model",
19 | 'model_family': "llama3"
20 | }
21 | # NOTE: instructions for using different model families
22 | # mistral:
23 | # - ckpt_dir="src/mistral-inference/8B-Instruct/"
24 | # - tokenizer_path="src/mistral-inference/8B-Instruct/"
25 | # - model_family="mistral"
26 | # llama3:
27 | # - ckpt_dir="src/llama3/pretrained/Meta-Llama-3-8B-Instruct/"
28 | # - tokenizer_path="src/llama3/pretrained/Meta-Llama-3-8B-Instruct/tokenizer.model"
29 | # - model_family="llama3"
30 |
31 | cmd = ['torchrun', '--nproc_per_node', f"{nproc_per_node}", 'src/single_query.py']
32 | for k, v in params.items():
33 | cmd.append(f"--{k}")
34 | cmd.append(str(v))
35 |
36 | subprocess.run(args=cmd)
--------------------------------------------------------------------------------
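The `sampler_type` values refer to the samplers implemented in `src/sampler.py` (not shown in this section). As a rough illustration of the "top-p token" case only, a standard nucleus-sampling step looks like this (a generic sketch, not the repository's `Sampler` API):

    import torch

    def sample_top_p(probs: torch.Tensor, p: float, rng: torch.Generator) -> torch.Tensor:
        """Sample one token id from the smallest prefix of tokens whose mass exceeds p."""
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Zero out tokens once the mass accumulated *before* them already exceeds p.
        sorted_probs[cumulative - sorted_probs > p] = 0.0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(sorted_probs, num_samples=1, generator=rng)
        return torch.gather(sorted_idx, -1, choice)

    probs = torch.softmax(torch.randn(1, 32), dim=-1)
    token = sample_top_p(probs, p=0.9, rng=torch.Generator().manual_seed(0))  # shape (1, 1)
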
/src/cf_query.py:
--------------------------------------------------------------------------------
1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez.
2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors.
3 |
4 | from typing import List, Optional
5 |
6 | import fire
7 |
8 | import sys
9 | import os
10 | import json
11 | import torch
12 | from sampler import Sampler
13 |
14 | # mistral stuff
15 | from mistral_common.tokens.tokenizers.base import Tokenizer
16 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
17 | from mistral_common.protocol.instruct.request import ChatCompletionRequest
18 | from mistral_common.protocol.instruct.messages import (
19 | AssistantMessage,
20 | UserMessage,
21 | SystemMessage
22 | )
23 | from mistral_inference.generate import generate
24 | # Add the src directory to the Python path
25 | sys.path.append(os.path.abspath("src/mistral-inference/src"))
26 | from mistral_inference.main import load_tokenizer, get_model_cls
27 | from pathlib import Path
28 |
29 | # Add the src directory to the Python path
30 | sys.path.append(os.path.abspath("src/llama3"))
31 | from llama import Dialog, Llama
32 |
33 | def main(
34 | max_seq_len: int = 512,
35 | max_batch_size: int = 4,
36 | max_gen_len: Optional[int] = None,
37 | exp_name: str = "test",
38 | prior: bool = False,
39 | iteration: int = 1
40 | ):
41 | """
42 |     Example of running models finetuned for chat. Prompts consist of chat
43 |     turns between the user and the assistant, with the final turn always being the user's.
44 |
45 | An optional system prompt at the beginning to control how the model should respond
46 | is also supported.
47 |
48 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
49 |
50 | `max_gen_len` is optional because finetuned models are able to stop generations naturally.
51 | """
52 |
53 | with open(f"outputs/{exp_name}/intervention_{iteration}.json") as f:
54 | data = json.load(f)
55 | system = data["system"]
56 | query = data["query"]
57 | partial_response = data["response"]
58 | top_p = data["top_p"]
59 | top_k = data["top_k"]
60 | temperature = data["temperature"]
61 | start_from = data["start_from"]
62 | seed = data["seed"]
63 | sampler_type = data["sampler_type"]
64 | model_family = data["model_family"]
65 | ckpt_dir = data["ckpt_dir"]
66 | tokenizer_path = data["tokenizer_path"]
67 |
68 | sampler = Sampler(sampler_type=sampler_type, top_p=top_p, top_k=top_k)
69 |
70 | with open(f"outputs/{exp_name}/rngstates_{iteration}.pt", "rb") as f:
71 | rngstates = torch.load(f)
72 | init_rng_state = rngstates[start_from,:]
73 |
74 | if model_family == "llama3":
75 | generator = Llama.build(
76 | ckpt_dir=ckpt_dir,
77 | tokenizer_path=tokenizer_path,
78 | max_seq_len=max_seq_len,
79 | max_batch_size=max_batch_size,
80 | seed=seed,
81 | init_rng_state=init_rng_state,
82 | prior=prior
83 | )
84 |
85 | dialogs: List[Dialog] = [
86 | [ {"role": "system", "content": system},
87 | {"role": "user", "content": query}]
88 | ]
89 |
90 | results, rngstates = generator.chat_completion(
91 | dialogs,
92 | max_gen_len=max_gen_len,
93 | temperature=temperature,
94 | sampler=sampler,
95 | genstates=False,
96 | partial_response=partial_response,
97 | )
98 |
99 | elif model_family == "mistral":
100 |
101 | # load the model
102 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(ckpt_dir))
103 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
104 |
105 | model_cls = get_model_cls(ckpt_dir)
106 | model = model_cls.from_folder(Path(ckpt_dir), max_batch_size=max_batch_size, num_pipeline_ranks=1)
107 |
108 | messages: List[SystemMessage | UserMessage | AssistantMessage] = []
109 | messages += [SystemMessage(content=system)]
110 | messages += [UserMessage(content=query)]
111 |
112 | chat_completion_request = ChatCompletionRequest(messages=messages)
113 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
114 | tokens = tokenized.tokens
115 |
116 | # initialize a random number generator with the given seed
117 | rng = torch.Generator(device="cuda")
118 | rng.manual_seed(seed)
119 |
120 | if not prior:
121 | # turn init_rng_state into a torch.ByteTensor
122 | init_rng_state = torch.tensor(init_rng_state, device="cpu").to(torch.uint8)
123 | rng.set_state(init_rng_state)
124 |
125 | fixed_tokens = tokenizer.encode(partial_response, bos=False, eos=False)
126 |
127 | generated_tokens, _, _ = generate( # type: ignore[operator]
128 | encoded_prompts = [tokens],
129 | model = model,
130 | max_tokens = max_seq_len,
131 | temperature = temperature,
132 | eos_id = tokenizer.eos_id,
133 | sampler = sampler,
134 | genstates = False,
135 | rng = rng,
136 | fixed_tokens=fixed_tokens
137 | )
138 |
139 | results = [
140 | {
141 | "generation": {
142 | "role": "assistant",
143 | "content": tokenizer.decode(t),
144 | "token_list": [tokenizer.decode([x]) for x in t],
145 | },
146 | }
147 | for t in generated_tokens
148 | ]
149 |
150 |
151 | output_dir = os.path.join("outputs", exp_name)
152 |
153 | output = {}
154 | output["system"] = system
155 | output["query"] = query
156 | output["response"] = results[0]["generation"]["content"]
157 | output["seed"] = seed
158 | output["temperature"] = temperature
159 | output["sampler_type"] = sampler_type
160 | output["top_p"] = top_p
161 | output["top_k"] = top_k
162 | output["model_family"] = model_family
163 | output["start_from"] = 0
164 | output["token_list"] = {ind: tok for ind, tok in enumerate(results[0]["generation"]["token_list"])}
165 |
166 | with open(os.path.join(output_dir, f"counterfactual_{iteration}.json"), "w") as f:
167 | json.dump(output, f, indent=4)
168 |
169 | with open(os.path.join(output_dir, f"intervention_{iteration+1}.json"), "w") as f:
170 | json.dump(output, f, indent=4)
171 |
172 | print("----------------")
173 | print("SYSTEM:", system)
174 | print("----------------")
175 | print("QUERY:", query)
176 | print("----------------")
177 | print("RESPONSE:", results[0]["generation"]["content"])
178 | print("----------------")
179 |
180 |
181 | if __name__ == "__main__":
182 | fire.Fire(main)
183 |
184 |
--------------------------------------------------------------------------------
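The counterfactual machinery above rests on one idea: during the factual run the RNG state was recorded before every sampled token (the `rngstates_{iteration}.pt` file), so restoring the state at `start_from` and fixing the tokens of `partial_response` yields a counterfactual continuation that reuses exactly the same sampling noise. A stripped-down sketch of that pattern with a plain `torch.Generator` (illustrative only; the real loops live in the two models' `generate` functions):

    import torch

    rng = torch.Generator(device="cpu")
    rng.manual_seed(0)
    probs = torch.full((10,), 0.1)  # stand-in for the model's next-token distribution

    # Factual pass: record the generator state before each draw.
    states, factual = [], []
    for _ in range(5):
        states.append(rng.get_state())
        factual.append(torch.multinomial(probs, 1, generator=rng).item())

    # Counterfactual pass: keep the first two tokens, restore the state at
    # position 2, and resample the rest under the same noise.
    start_from = 2
    rng.set_state(states[start_from])
    counterfactual = factual[:start_from]
    for _ in range(5 - start_from):
        counterfactual.append(torch.multinomial(probs, 1, generator=rng).item())

    assert counterfactual == factual  # unchanged distribution + same noise => same tokens

Under an actual intervention the next-token distribution changes from step to step, so the tokens may differ, but the noise driving the sampler is held fixed.
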
/src/llama3/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/src/llama3/USE_POLICY.md:
--------------------------------------------------------------------------------
1 | # Meta Llama 3 Acceptable Use Policy
2 |
3 | Meta is committed to promoting safe and fair use of its tools and features, including Llama 3. If you access or use Llama 3, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy).
4 |
5 | ## Prohibited Uses
6 | We want everyone to use Llama 3 safely and responsibly. You agree you will not use, or allow others to use, Llama 3 to:
7 |
8 | 1. Violate the law or others’ rights, including to:
9 | 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
10 | 1. Violence or terrorism
11 | 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
12 | 3. Human trafficking, exploitation, and sexual violence
13 | 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
14 | 5. Sexual solicitation
15 | 6. Any other criminal activity
16 | 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
17 | 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
18 | 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
19 | 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
20 | 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 3 Materials
21 | 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
22 |
23 |
24 |
25 | 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 3 related to the following:
26 | 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
27 | 2. Guns and illegal weapons (including weapon development)
28 | 3. Illegal drugs and regulated/controlled substances
29 | 4. Operation of critical infrastructure, transportation technologies, or heavy machinery
30 | 5. Self-harm or harm to others, including suicide, cutting, and eating disorders
31 | 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
32 |
33 |
34 |
35 | 3. Intentionally deceive or mislead others, including use of Llama 3 related to the following:
36 | 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
37 | 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
38 | 3. Generating, promoting, or further distributing spam
39 | 4. Impersonating another individual without consent, authorization, or legal right
40 | 5. Representing that the use of Llama 3 or outputs are human-generated
41 | 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
42 | 4. Fail to appropriately disclose to end users any known dangers of your AI system
43 |
44 | Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
45 |
46 | * Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama)
47 | * Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback)
48 | * Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info)
49 | * Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com)
50 |
51 |
--------------------------------------------------------------------------------
/src/llama3/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
3 |
4 | from .generation import Llama
5 | from .model import ModelArgs, Transformer
6 | from .tokenizer import Dialog, Tokenizer
7 |
--------------------------------------------------------------------------------
/src/llama3/llama/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
3 |
4 | import os
5 | from unittest import TestCase
6 | from llama.tokenizer import ChatFormat, Tokenizer
7 |
8 | # TOKENIZER_PATH=<tokenizer_path> python -m unittest llama/test_tokenizer.py
9 |
10 | class TokenizerTests(TestCase):
11 | def setUp(self):
12 | self.tokenizer = Tokenizer(os.environ["TOKENIZER_PATH"])
13 | self.format = ChatFormat(self.tokenizer)
14 |
15 | def test_special_tokens(self):
16 | self.assertEqual(
17 | self.tokenizer.special_tokens["<|begin_of_text|>"],
18 | 128000,
19 | )
20 |
21 | def test_encode(self):
22 | self.assertEqual(
23 | self.tokenizer.encode(
24 | "This is a test sentence.",
25 | bos=True,
26 | eos=True
27 | ),
28 | [128000, 2028, 374, 264, 1296, 11914, 13, 128001],
29 | )
30 |
31 | def test_decode(self):
32 | self.assertEqual(
33 | self.tokenizer.decode(
34 | [128000, 2028, 374, 264, 1296, 11914, 13, 128001],
35 | ),
36 | "<|begin_of_text|>This is a test sentence.<|end_of_text|>",
37 | )
38 |
39 | def test_encode_message(self):
40 | message = {
41 | "role": "user",
42 | "content": "This is a test sentence.",
43 | }
44 | self.assertEqual(
45 | self.format.encode_message(message),
46 | [
47 | 128006, # <|start_header_id|>
48 | 882, # "user"
49 | 128007, # <|end_header_id|>
50 | 271, # "\n\n"
51 | 2028, 374, 264, 1296, 11914, 13, # This is a test sentence.
52 | 128009, # <|eot_id|>
53 | ]
54 | )
55 |
56 | def test_encode_dialog(self):
57 | dialog = [
58 | {
59 | "role": "system",
60 | "content": "This is a test sentence.",
61 | },
62 | {
63 | "role": "user",
64 | "content": "This is a response.",
65 | }
66 | ]
67 | self.assertEqual(
68 | self.format.encode_dialog_prompt(dialog),
69 | [
70 | 128000, # <|begin_of_text|>
71 | 128006, # <|start_header_id|>
72 | 9125, # "system"
73 | 128007, # <|end_header_id|>
74 | 271, # "\n\n"
75 | 2028, 374, 264, 1296, 11914, 13, # "This is a test sentence."
76 | 128009, # <|eot_id|>
77 | 128006, # <|start_header_id|>
78 | 882, # "user"
79 | 128007, # <|end_header_id|>
80 | 271, # "\n\n"
81 | 2028, 374, 264, 2077, 13, # "This is a response.",
82 | 128009, # <|eot_id|>
83 | 128006, # <|start_header_id|>
84 | 78191, # "assistant"
85 | 128007, # <|end_header_id|>
86 | 271, # "\n\n"
87 | ]
88 | )
89 |
--------------------------------------------------------------------------------
/src/llama3/pretrained/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/src/llama3/pretrained/.placeholder
--------------------------------------------------------------------------------
/src/llama3/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | fairscale
3 | fire
4 | tiktoken==0.4.0
5 | blobfile
6 |
--------------------------------------------------------------------------------
/src/llama3/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
3 |
4 | from setuptools import find_packages, setup
5 |
6 |
7 | def get_requirements(path: str):
8 | return [l.strip() for l in open(path)]
9 |
10 |
11 | setup(
12 | name="llama3",
13 | version="0.0.1",
14 | packages=find_packages(),
15 | install_requires=get_requirements("requirements.txt"),
16 | )
17 |
--------------------------------------------------------------------------------
/src/mistral-inference/assets/smoe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Networks-Learning/counterfactual-llms/ad3e0702ac5308f66c725e7f8725e7cd70f7a363/src/mistral-inference/assets/smoe.png
--------------------------------------------------------------------------------
/src/mistral-inference/deploy/.dockerignore:
--------------------------------------------------------------------------------
1 | *
2 | !entrypoint.sh
3 |
--------------------------------------------------------------------------------
/src/mistral-inference/deploy/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=amd64 nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 as base
2 |
3 | WORKDIR /workspace
4 |
5 | RUN apt update && \
6 | apt install -y python3-pip python3-packaging \
7 | git ninja-build && \
8 | pip3 install -U pip
9 |
10 | # Tweak this list to reduce build time
11 | # https://developer.nvidia.com/cuda-gpus
12 | ENV TORCH_CUDA_ARCH_LIST "7.0;7.2;7.5;8.0;8.6;8.9;9.0"
13 |
14 | RUN pip3 install "torch==2.1.1"
15 |
16 | # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
17 | RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
18 | RUN pip3 install "git+https://github.com/vllm-project/vllm.git"
19 | RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34"
20 |
21 | RUN git clone https://github.com/NVIDIA/apex && \
22 | cd apex && git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 && \
23 | sed -i '/check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)/d' setup.py && \
24 | python3 setup.py install --cpp_ext --cuda_ext
25 |
26 |
27 | COPY entrypoint.sh .
28 |
29 | RUN chmod +x /workspace/entrypoint.sh
30 |
31 | ENTRYPOINT ["/workspace/entrypoint.sh"]
--------------------------------------------------------------------------------
/src/mistral-inference/deploy/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -z "${HF_TOKEN}" ]]; then
4 | echo "The HF_TOKEN environment variable is set, logging to Hugging Face."
5 | python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
6 | else
7 | echo "The HF_TOKEN environment variable is not set or empty, not logging to Hugging Face."
8 | fi
9 |
10 | # Run the provided command
11 | exec python3 -u -m vllm.entrypoints.openai.api_server "$@"
12 |
--------------------------------------------------------------------------------
/src/mistral-inference/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "mistral_inference"
3 | version = "1.5.0"
4 | description = ""
5 | authors = ["bam4d <bam4d@mistral.ai>"]
6 | readme = "README.md"
7 | packages = [{ include = "mistral_inference", from = "src" }]
8 |
9 | [tool.ruff]
10 | lint.select = ["E", "F", "W", "Q", "I"]
11 | lint.ignore = ["E203"]
12 | lint.fixable = ["ALL"]
13 | lint.unfixable = []
14 | line-length = 120
15 | exclude = ["docs", "build", "tutorials"]
16 |
17 | [tool.mypy]
18 | disallow_untyped_defs = true
19 | show_error_codes = true
20 | no_implicit_optional = true
21 | warn_return_any = true
22 | warn_unused_ignores = true
23 | exclude = ["docs", "tools", "build"]
24 |
25 | [tool.poetry.dependencies]
26 | python = "^3.9.10"
27 | xformers = ">=0.0.24"
28 | simple-parsing = ">=0.1.5"
29 | fire = ">=0.6.0"
30 | mistral_common = ">=1.4.0"
31 | safetensors = ">=0.4.0"
32 | pillow = ">=10.3.0"
33 |
34 | [tool.poetry.group.dev.dependencies]
35 | types-protobuf = "4.24.0.20240129"
36 | mypy-protobuf = "^3.5.0"
37 | pytest = "7.4.4"
38 | ruff = "^0.2.2"
39 | mypy = "^1.8.0"
40 |
41 | [build-system]
42 | requires = ["poetry-core>=1.0.0"]
43 | build-backend = "poetry.core.masonry.api"
44 |
45 | [tool.pytest.ini_options]
46 | testpaths = ["./tests"]
47 |
48 | [tool.poetry.scripts]
49 | mistral-chat = "mistral_inference.main:mistral_chat"
50 | mistral-demo = "mistral_inference.main:mistral_demo"
51 |
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.5.0"
2 |
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 |
4 | from simple_parsing.helpers import Serializable
5 |
6 | from mistral_inference.lora import LoraArgs
7 | from mistral_inference.moe import MoeArgs
8 |
9 |
10 | @dataclass
11 | class VisionEncoderArgs:
12 | hidden_size: int
13 | num_channels: int
14 | image_size: int
15 | patch_size: int
16 | intermediate_size: int
17 | num_hidden_layers: int
18 | num_attention_heads: int
19 | rope_theta: float = 1e4 # for rope-2D
20 | image_token_id: int = 10
21 |
22 |
23 | @dataclass
24 | class TransformerArgs(Serializable):
25 | dim: int
26 | n_layers: int
27 | head_dim: int
28 | hidden_dim: int
29 | n_heads: int
30 | n_kv_heads: int
31 | norm_eps: float
32 | vocab_size: int
33 |
34 | max_batch_size: int = 0
35 |
36 | # For rotary embeddings. If not set, will be inferred
37 | rope_theta: Optional[float] = None
38 | # If this is set, we will use MoE layers instead of dense layers.
39 | moe: Optional[MoeArgs] = None
40 | # If this is set, we will load LoRA linear layers instead of linear layers.
41 | lora: Optional[LoraArgs] = None
42 | sliding_window: Optional[int] | Optional[List[int]] = None
43 | _sliding_window: Optional[int] | Optional[List[int]] = None
44 | model_type: str = "transformer"
45 |
46 | vision_encoder: Optional[VisionEncoderArgs] = None
47 |
48 | def __post_init__(self) -> None:
49 | assert self.model_type == "transformer", self.model_type
50 | assert self.sliding_window is None or self._sliding_window is None
51 |
52 | # hack for now so that vLLM is supported correctly
53 | self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window
54 |
55 |
56 | @dataclass
57 | class MambaArgs(Serializable):
58 | dim: int
59 | n_layers: int
60 | vocab_size: int
61 | n_groups: int
62 | rms_norm: bool
63 | residual_in_fp32: bool
64 | fused_add_norm: bool
65 | pad_vocab_size_multiple: int
66 | tie_embeddings: bool
67 | model_type: str = "mamba"
68 |
69 | def __post_init__(self) -> None:
70 | assert self.model_type == "mamba", self.model_type
71 |
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/generate.py:
--------------------------------------------------------------------------------
1 | # DISCLAIMER: The original code was edited by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez.
2 | # If you use this version of the code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors.
3 |
4 | from typing import List, Optional, Tuple
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from mistral_inference.cache import BufferCache
10 | from mistral_inference.mamba import Mamba
11 | from mistral_inference.transformer import Transformer
12 |
13 |
14 | @torch.inference_mode()
15 | def generate_mamba(
16 | encoded_prompts: List[List[int]],
17 | model: Mamba,
18 | *,
19 | max_tokens: int,
20 | temperature: float,
21 | chunk_size: Optional[int] = None,
22 | eos_id: Optional[int] = None,
23 | ) -> Tuple[List[List[int]], List[List[float]]]:
24 | input_ids = torch.tensor(encoded_prompts, device=model.device)
25 | output = model.model.generate(
26 | input_ids=input_ids,
27 | max_length=input_ids.shape[-1] + max_tokens,
28 | cg=True,
29 | return_dict_in_generate=True,
30 | output_scores=True,
31 | enable_timing=False,
32 | eos_token_id=eos_id,
33 | temperature=temperature,
34 | top_p=0.8,
35 | )
36 | generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist()
37 |
38 | _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))]
39 | for seq_idx, batch_score in enumerate(output.scores):
40 | for batch_idx, score in enumerate(batch_score.tolist()):
41 | _logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]])
42 |
43 | return generated_tokens, _logprobs
44 |
45 |
46 | @torch.inference_mode()
47 | def generate(
48 | encoded_prompts: List[List[int]],
49 | model: Transformer,
50 | images: List[List[np.ndarray]] = [],
51 | *,
52 | max_tokens: int,
53 | temperature: float,
54 | chunk_size: Optional[int] = None,
55 | eos_id: Optional[int] = None,
56 | genstates: bool = True,
57 | sampler,
58 | rng: torch.Generator,
59 | fixed_tokens: Optional[List[int]] = None,
60 | auto_intervention: bool = False
61 | ) -> Tuple[List[List[int]], List[List[float]], Optional[List[List[int]]]]:
62 | images_torch: List[List[torch.Tensor]] = []
63 | if images:
64 | assert chunk_size is None
65 | images_torch = [
66 | [torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample]
67 | for images_for_sample in images
68 | ]
69 |
70 | model = model.eval()
71 | B, V = len(encoded_prompts), model.args.vocab_size
72 |
73 | seqlens = [len(x) for x in encoded_prompts]
74 |
75 | # Cache
76 | cache_window = max(seqlens) + max_tokens
77 | cache = BufferCache(
78 | model.n_local_layers,
79 | model.args.max_batch_size,
80 | cache_window,
81 | model.args.n_kv_heads,
82 | model.args.head_dim,
83 | model.args.sliding_window,
84 | )
85 | cache.to(device=model.device, dtype=model.dtype)
86 | cache.reset()
87 |
88 | # Bookkeeping
89 | logprobs: List[List[float]] = [[] for _ in range(B)]
90 | last_token_prelogits = None
91 |
92 | # One chunk if size not specified
93 | max_prompt_len = max(seqlens)
94 | if chunk_size is None:
95 | chunk_size = max_prompt_len
96 |
97 | flattened_images: List[torch.Tensor] = sum(images_torch, [])
98 |
99 | # Encode prompt by chunks
100 | for s in range(0, max_prompt_len, chunk_size):
101 | prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
102 | assert all(len(p) > 0 for p in prompt_chunks)
103 | prelogits = model.forward(
104 | torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
105 | images=flattened_images,
106 | seqlens=[len(p) for p in prompt_chunks],
107 | cache=cache,
108 | )
109 | logits = torch.log_softmax(prelogits, dim=-1)
110 |
111 | if last_token_prelogits is not None:
112 | # Pass > 1
113 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
114 | for i_seq in range(B):
115 | logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item())
116 |
117 | offset = 0
118 | for i_seq, sequence in enumerate(prompt_chunks):
119 | logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
120 | offset += len(sequence)
121 |
122 | last_token_prelogits = prelogits.index_select(
123 | 0,
124 | torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1,
125 | )
126 | assert last_token_prelogits.shape == (B, V)
127 |
128 | # decode
129 | generated_tensors = []
130 | is_finished = torch.tensor([False for _ in range(B)])
131 | if genstates:
132 | token_genstates = torch.zeros((max_tokens, rng.get_state().numel()), dtype=torch.uint8)
133 | assert last_token_prelogits is not None
134 |
135 | if fixed_tokens is not None:
136 | fixed_counter = 0
137 |
138 | for gen_state_counter in range(max_tokens):
139 |
140 | if fixed_tokens is not None and fixed_counter < len(fixed_tokens):
141 | if auto_intervention and fixed_counter == len(fixed_tokens)-1:
142 | if temperature > 0:
143 | probs = torch.softmax(last_token_prelogits / temperature, dim=-1)
144 | next_token = sampler.intervention(probs, torch.tensor(fixed_tokens[fixed_counter], device=model.device))
145 | next_token = next_token.reshape(-1)
146 | else:
147 | next_token = torch.topk(last_token_prelogits,k=2,dim=-1).indices[0][1]
148 | next_token = next_token.unsqueeze(0)
149 | fixed_counter += 1
150 | else:
151 | next_token = torch.tensor(fixed_tokens[fixed_counter], device=model.device)
152 | next_token = next_token.unsqueeze(0)
153 | fixed_counter += 1
154 | else:
155 | ##########################################################################################
156 | # NOTE: the sampling takes place here
157 | if temperature > 0:
158 | probs = torch.softmax(last_token_prelogits / temperature, dim=-1)
159 | # get the state of the random number generator
160 | if genstates:
161 | token_genstates[gen_state_counter] = rng.get_state()
162 | next_token = sampler.sample(probs, rng)
163 | next_token = next_token.reshape(-1)
164 | else:
165 | next_token = torch.argmax(last_token_prelogits, dim=-1)
166 | ##########################################################################################
167 |
168 | if eos_id is not None:
169 | is_finished = is_finished | (next_token == eos_id).cpu()
170 |
171 | if is_finished.all():
172 | break
173 |
174 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
175 | for i in range(B):
176 | logprobs[i].append(last_token_logits[i, next_token[i]].item())
177 |
178 | generated_tensors.append(next_token[:, None])
179 | last_token_prelogits = model.forward(next_token, seqlens=[1] * B, cache=cache)
180 | assert last_token_prelogits.shape == (B, V)
181 |
182 | generated_tokens: List[List[int]]
183 | if generated_tensors:
184 | generated_tokens = torch.cat(generated_tensors, 1).tolist()
185 | else:
186 | generated_tokens = []
187 |
188 | if genstates:
189 | token_genstates = token_genstates.tolist()
190 |
191 | return (generated_tokens, logprobs, token_genstates if genstates else None)
192 |
--------------------------------------------------------------------------------
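One detail of the decode loop above: at `temperature == 0` generation is deterministic, so when `auto_intervention` reaches the last fixed token there is no sampling noise to perturb, and the code instead forces the second most likely token via `torch.topk(..., k=2)`. A small numeric check of that selection (illustrative only):

    import torch

    prelogits = torch.tensor([[0.1, 2.0, 1.5, -0.3]])  # batch of 1, vocabulary of 4
    greedy = torch.argmax(prelogits, dim=-1)             # tensor([1]): the factual token
    runner_up = torch.topk(prelogits, k=2, dim=-1).indices[0][1]
    assert greedy.item() == 1 and runner_up.item() == 2  # the intervention picks index 2
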
/src/mistral-inference/src/mistral_inference/lora.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Any, Dict, NamedTuple, Union
5 |
6 | import safetensors.torch
7 | import torch
8 | import torch.nn as nn
9 | from simple_parsing.helpers import Serializable
10 |
11 |
12 | @dataclass
13 | class LoraArgs(Serializable):
14 | rank: int
15 | scaling: float
16 |
17 | def __post_init__(self) -> None:
18 | assert self.rank > 0
19 | assert self.scaling > 0.0
20 |
21 |
22 | class LoRALinear(nn.Module):
23 | """
24 | Implementation of:
25 | - LoRA: https://arxiv.org/abs/2106.09685
26 |
27 | Notes:
28 | - Freezing is handled at network level, not layer level.
29 | - Scaling factor controls relative importance of LoRA skip
30 | connection versus original frozen weight. General guidance is
31 | to keep it to 2.0 and sweep over learning rate when changing
32 | the rank.
33 | """
34 |
35 | def __init__(
36 | self,
37 | in_features: int,
38 | out_features: int,
39 | rank: int,
40 | scaling: float,
41 | bias: bool = False,
42 | ):
43 | super().__init__()
44 |
45 | self.in_features = in_features
46 | self.out_features = out_features
47 | assert not bias
48 | self.bias = bias
49 | self.rank = rank
50 | self.scaling = scaling
51 |
52 | self.lora_A = nn.Linear(
53 | self.in_features,
54 | self.rank,
55 | bias=self.bias,
56 | )
57 | self.lora_B = nn.Linear(
58 | self.rank,
59 | self.out_features,
60 | bias=self.bias,
61 | )
62 |
63 | self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias)
64 |
65 | # make sure no LoRA weights are marked as "missing" in load_state_dict
66 | def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None:
67 | incompatible_keys.missing_keys[:] = [] # type: ignore
68 |
69 | self.register_load_state_dict_post_hook(ignore_missing_keys)
70 |
71 | def forward(self, x: torch.Tensor) -> torch.Tensor:
72 | lora = self.lora_B(self.lora_A(x))
73 | result: torch.Tensor = self.linear(x) + lora * self.scaling
74 | return result
75 |
76 | def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
77 | key_name = prefix + "weight"
78 |
79 | # full checkpoint
80 | if key_name in state_dict:
81 | w_ref = state_dict[key_name]
82 |
83 | # load frozen weights
84 | state_dict = {
85 | "linear.weight": w_ref,
86 | "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype),
87 | "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype),
88 | }
89 | self.load_state_dict(state_dict, assign=True, strict=True)
90 |
91 |
92 | class LoRALoaderMixin:
93 | def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None:
94 | """Loads LoRA checkpoint"""
95 |
96 | lora_path = Path(lora_path)
97 | assert lora_path.is_file(), f"{lora_path} does not exist or is not a file"
98 |
99 | state_dict = safetensors.torch.load_file(lora_path)
100 |
101 | self._load_lora_state_dict(state_dict, scaling=scaling)
102 |
103 | def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None:
104 | """Loads LoRA state_dict"""
105 | lora_dtypes = set([p.dtype for p in lora_state_dict.values()])
106 | assert (
107 | len(lora_dtypes) == 1
108 | ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype"
109 | lora_dtype = lora_dtypes.pop()
110 | assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" # type: ignore[attr-defined]
111 | assert all("lora" in key for key in lora_state_dict.keys())
112 |
113 | # move tensors to device
114 | lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} # type: ignore[attr-defined]
115 |
116 | state_dict = self.state_dict() # type: ignore[attr-defined]
117 |
118 | if self.args.lora is None: # type: ignore[attr-defined]
119 | logging.info("Loading and merging LoRA weights...")
120 |
121 | # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer
122 | named_modules = dict(self.named_modules()) # type: ignore[attr-defined]
123 | for name, module in named_modules.items():
124 | if isinstance(module, nn.Linear) and name != "output":
125 | layer_id = name.split(".")[1]
126 | if layer_id not in self.layers: # type: ignore[attr-defined]
127 | logging.debug(
128 | "Skipping parameter %s at pipeline rank %d",
129 | name,
130 | self.pipeline_rank, # type: ignore[attr-defined]
131 | )
132 | elif (name + ".lora_B.weight") in lora_state_dict:
133 | weight = (
134 | module.weight
135 | + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"])
136 | * scaling
137 | )
138 |
139 | state_dict[name + ".weight"] = weight
140 | else:
141 | logging.info("Loading LoRA weights...")
142 | for k, v in lora_state_dict.items():
143 | state_dict.update(lora_state_dict)
144 |
145 | layer_id = k.split(".")[1]
146 | if layer_id in self.layers: # type: ignore[attr-defined]
147 | state_dict[k] = v
148 | else:
149 | logging.debug(
150 | "Skipping parameter %s at pipeline rank %d",
151 | k,
152 | self.pipeline_rank, # type: ignore[attr-defined]
153 | )
154 |
155 | self.load_state_dict(state_dict, strict=True) # type: ignore[attr-defined]
156 |
--------------------------------------------------------------------------------
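The merge branch above folds each adapter into its frozen base weight as W <- W + scaling * (B @ A), with `lora_A.weight` of shape (rank, in_features) and `lora_B.weight` of shape (out_features, rank). A quick shape check of that update (generic sketch):

    import torch

    in_features, out_features, rank, scaling = 16, 8, 4, 2.0
    W = torch.randn(out_features, in_features)  # frozen base weight
    A = torch.randn(rank, in_features)          # lora_A.weight
    B = torch.randn(out_features, rank)         # lora_B.weight

    merged = W + (B @ A) * scaling              # same update as the merge branch above
    assert merged.shape == W.shape
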
/src/mistral-inference/src/mistral_inference/mamba.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 | import safetensors
6 | import torch
7 | import torch.nn as nn
8 |
9 | from mistral_inference.args import MambaArgs
10 | from mistral_inference.cache import BufferCache
11 | from mistral_inference.model import ModelBase
12 |
13 | _is_mamba_installed = False
14 | try:
15 | from mamba_ssm.models.config_mamba import MambaConfig
16 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
17 |
18 | _is_mamba_installed = True
19 | except ImportError:
20 | _is_mamba_installed = False
21 |
22 |
23 | class Mamba(ModelBase, nn.Module):
24 | def __init__(self, args: MambaArgs):
25 | super().__init__()
26 | self.args = args
27 | assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`."
28 |
29 | # make sure naming is consistent with `mamba_ssm`
30 | config = MambaConfig(
31 | d_model=args.dim,
32 | n_layer=args.n_layers,
33 | vocab_size=args.vocab_size,
34 | ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"},
35 | attn_layer_idx=[],
36 | attn_cfg={},
37 | rms_norm=args.rms_norm,
38 | residual_in_fp32=args.residual_in_fp32,
39 | fused_add_norm=args.fused_add_norm,
40 | pad_vocab_size_multiple=args.pad_vocab_size_multiple,
41 | tie_embeddings=args.tie_embeddings,
42 | )
43 | self.model = MambaLMHeadModel(config)
44 |
45 | @property
46 | def dtype(self) -> torch.dtype:
47 | return next(self.parameters()).dtype
48 |
49 | @property
50 | def device(self) -> torch.device:
51 | return next(self.parameters()).device
52 |
53 | def forward(
54 | self,
55 | input_ids: torch.Tensor,
56 | seqlens: List[int], # not supported for now
57 | cache: Optional[BufferCache] = None, # not supported for now
58 | ) -> torch.Tensor:
59 | lm_output = self.model(input_ids)
60 | result: torch.Tensor = lm_output.logits
61 | return result
62 |
63 | @staticmethod
64 | def from_folder(
65 | folder: Union[Path, str],
66 | max_batch_size: int = 1,
67 | num_pipeline_ranks: int = 1,
68 | device: Union[torch.device, str] = "cuda",
69 | dtype: Optional[torch.dtype] = None,
70 | ) -> "Mamba":
71 | with open(Path(folder) / "params.json", "r") as f:
72 | model_args = MambaArgs.from_dict(json.load(f))
73 |
74 | with torch.device("meta"):
75 | model = Mamba(model_args)
76 |
77 | model_file = Path(folder) / "consolidated.safetensors"
78 |
79 | assert model_file.exists(), f"Make sure {model_file} exists."
80 | loaded = safetensors.torch.load_file(str(model_file))
81 |
82 | model.load_state_dict(loaded, assign=True, strict=True)
83 | return model.to(device=device, dtype=dtype)
84 |
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from mistral_inference.cache import BufferCache
9 |
10 |
11 | class ModelBase(nn.Module, ABC):
12 | def __init__(self) -> None:
13 | super().__init__()
14 |
15 | @property
16 | @abstractmethod
17 | def dtype(self) -> torch.dtype:
18 | pass
19 |
20 | @property
21 | @abstractmethod
22 | def device(self) -> torch.device:
23 | pass
24 |
25 | @abstractmethod
26 | def forward(
27 | self,
28 | input_ids: torch.Tensor,
29 | seqlens: List[int], # not supported for now
30 | cache: Optional[BufferCache] = None, # not supported for now
31 | ) -> torch.Tensor:
32 | pass
33 |
34 | @staticmethod
35 | @abstractmethod
36 | def from_folder(
37 | folder: Union[Path, str],
38 | max_batch_size: int = 1,
39 | num_pipeline_ranks: int = 1,
40 | device: Union[torch.device, str] = "cuda",
41 | dtype: Optional[torch.dtype] = None,
42 | ) -> "ModelBase":
43 | pass
44 |
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/moe.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from simple_parsing.helpers import Serializable
7 | from torch import nn
8 |
9 |
10 | @dataclasses.dataclass
11 | class MoeArgs(Serializable):
12 | num_experts: int
13 | num_experts_per_tok: int
14 |
15 |
16 | class MoeLayer(nn.Module):
17 | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
18 | super().__init__()
19 | assert len(experts) > 0
20 | self.experts = nn.ModuleList(experts)
21 | self.gate = gate
22 | self.args = moe_args
23 |
24 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
25 | gate_logits = self.gate(inputs)
26 | weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
27 | weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
28 | results = torch.zeros_like(inputs)
29 | for i, expert in enumerate(self.experts):
30 | batch_idx, nth_expert = torch.where(selected_experts == i)
31 | results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
32 | return results
33 |
--------------------------------------------------------------------------------
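The routing in `MoeLayer.forward` picks `num_experts_per_tok` experts per token with `topk` over the gate logits, renormalizes just those k logits with a softmax, and accumulates each chosen expert's output under its weight. A tiny worked example of the routing step (sketch):

    import torch
    import torch.nn.functional as F

    gate_logits = torch.tensor([[2.0, 0.5, 1.0]])     # one token, three experts
    weights, selected = torch.topk(gate_logits, k=2)  # selects experts 0 and 2
    weights = F.softmax(weights, dim=1)               # ~[[0.7311, 0.2689]]
    assert selected.tolist() == [[0, 2]]
    assert torch.allclose(weights.sum(dim=1), torch.tensor([1.0]))
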
/src/mistral-inference/src/mistral_inference/rope.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 |
5 |
6 | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
7 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
8 | t = torch.arange(end, device=freqs.device)
9 | freqs = torch.outer(t, freqs).float()
10 | return torch.polar(torch.ones_like(freqs), freqs) # complex64
11 |
12 |
13 | def apply_rotary_emb(
14 | xq: torch.Tensor,
15 | xk: torch.Tensor,
16 | freqs_cis: torch.Tensor,
17 | ) -> Tuple[torch.Tensor, torch.Tensor]:
18 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
19 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
20 | freqs_cis = freqs_cis[:, None, :]
21 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
22 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
23 | return xq_out.type_as(xq), xk_out.type_as(xk)
24 |
25 |
26 | def precompute_freqs_cis_2d(
27 | dim: int,
28 | height: int,
29 | width: int,
30 | theta: float,
31 | ) -> torch.Tensor:
32 | """
33 | freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
34 | (height, width) position tuples
35 | """
36 | # (dim / 2) frequency bases
37 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
38 |
39 | h = torch.arange(height, device=freqs.device)
40 | w = torch.arange(width, device=freqs.device)
41 |
42 | freqs_h = torch.outer(h, freqs[::2]).float()
43 | freqs_w = torch.outer(w, freqs[1::2]).float()
44 | freqs_2d = torch.cat(
45 | [
46 | freqs_h[:, None, :].repeat(1, width, 1),
47 | freqs_w[None, :, :].repeat(height, 1, 1),
48 | ],
49 | dim=-1,
50 | )
51 | return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
52 |
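A quick sketch of how the 1D helpers above fit together, assuming they are in scope (shapes are illustrative; 10000 is the conventional rope theta): positions become unit-magnitude complex rotations, so applying them rotates each query/key pair without changing its norm.

    import torch

    head_dim, n_heads, seqlen = 8, 2, 6
    freqs_cis = precompute_freqs_cis(head_dim, end=32, theta=10000.0)  # (32, head_dim // 2), complex64

    xq = torch.randn(seqlen, n_heads, head_dim)
    xk = torch.randn(seqlen, n_heads, head_dim)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis[:seqlen])

    print(xq_rot.shape)  # torch.Size([6, 2, 8])
    # rotations are norm-preserving
    assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-4)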
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/transformer_layers.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, Tuple, Type, Union
3 |
4 | import torch
5 | from torch import nn
6 | from xformers.ops.fmha import memory_efficient_attention # type: ignore
7 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask
8 |
9 | from mistral_inference.args import LoraArgs
10 | from mistral_inference.cache import CacheView
11 | from mistral_inference.lora import LoRALinear
12 | from mistral_inference.moe import MoeArgs, MoeLayer
13 | from mistral_inference.rope import apply_rotary_emb
14 |
15 |
16 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
17 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
18 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
19 | return keys, values
20 |
21 |
22 | def maybe_lora(
23 | lora_args: Optional[LoraArgs],
24 | ) -> Union[Type[nn.Linear], partial[LoRALinear]]:
25 | if lora_args is None:
26 | return nn.Linear
27 | else:
28 | return partial(LoRALinear, rank=lora_args.rank, scaling=lora_args.scaling)
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(
33 | self,
34 | dim: int,
35 | n_heads: int,
36 | head_dim: int,
37 | n_kv_heads: int,
38 | lora: Optional[LoraArgs] = None,
39 | ):
40 | super().__init__()
41 |
42 | self.n_heads: int = n_heads
43 | self.head_dim: int = head_dim
44 | self.n_kv_heads: int = n_kv_heads
45 |
46 | self.repeats = self.n_heads // self.n_kv_heads
47 |
48 | self.scale = self.head_dim**-0.5
49 |
50 | MaybeLora = maybe_lora(lora)
51 | self.wq = MaybeLora(dim, n_heads * head_dim, bias=False)
52 | self.wk = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
53 | self.wv = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
54 | self.wo = MaybeLora(n_heads * head_dim, dim, bias=False)
55 |
56 | def forward(
57 | self,
58 | x: torch.Tensor,
59 | freqs_cis: torch.Tensor,
60 | cache: Optional[CacheView] = None,
61 | mask: Optional[BlockDiagonalMask] = None,
62 | ) -> torch.Tensor:
63 | assert mask is None or cache is None
64 | seqlen_sum, _ = x.shape
65 |
66 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
67 | xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
68 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
69 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
70 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
71 |
72 | if cache is None:
73 | key, val = xk, xv
74 | elif cache.prefill:
75 | key, val = cache.interleave_kv(xk, xv)
76 | cache.update(xk, xv)
77 | else:
78 | cache.update(xk, xv)
79 | key, val = cache.key, cache.value
80 | key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
81 | val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
82 |
83 | # Repeat keys and values to match number of query heads
84 | key, val = repeat_kv(key, val, self.repeats, dim=1)
85 |
86 | # xformers requires (B=1, S, H, D)
87 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
88 | output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
89 | output = output.view(seqlen_sum, self.n_heads * self.head_dim)
90 |
91 | assert isinstance(output, torch.Tensor)
92 |
93 | return self.wo(output) # type: ignore
94 |
95 |
96 | class FeedForward(nn.Module):
97 | def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs] = None):
98 | super().__init__()
99 |
100 | MaybeLora = maybe_lora(lora)
101 | self.w1 = MaybeLora(dim, hidden_dim, bias=False)
102 | self.w2 = MaybeLora(hidden_dim, dim, bias=False)
103 | self.w3 = MaybeLora(dim, hidden_dim, bias=False)
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
107 |
108 |
109 | class RMSNorm(torch.nn.Module):
110 | def __init__(self, dim: int, eps: float = 1e-6):
111 | super().__init__()
112 | self.eps = eps
113 | self.weight = nn.Parameter(torch.ones(dim))
114 |
115 | def _norm(self, x: torch.Tensor) -> torch.Tensor:
116 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
117 |
118 | def forward(self, x: torch.Tensor) -> torch.Tensor:
119 | output = self._norm(x.float()).type_as(x)
120 | return output * self.weight
121 |
122 |
123 | class TransformerBlock(nn.Module):
124 | def __init__(
125 | self,
126 | dim: int,
127 | hidden_dim: int,
128 | n_heads: int,
129 | n_kv_heads: int,
130 | head_dim: int,
131 | norm_eps: float,
132 | lora: Optional[LoraArgs] = None,
133 | moe: Optional[MoeArgs] = None,
134 | ):
135 | super().__init__()
136 | self.n_heads = n_heads
137 | self.dim = dim
138 | self.attention = Attention(
139 | dim=dim,
140 | n_heads=n_heads,
141 | head_dim=head_dim,
142 | n_kv_heads=n_kv_heads,
143 | lora=lora,
144 | )
145 | self.attention_norm = RMSNorm(dim, eps=norm_eps)
146 | self.ffn_norm = RMSNorm(dim, eps=norm_eps)
147 |
148 | self.feed_forward: nn.Module
149 | if moe is not None:
150 | self.feed_forward = MoeLayer(
151 | experts=[FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) for _ in range(moe.num_experts)],
152 | gate=nn.Linear(dim, moe.num_experts, bias=False),
153 | moe_args=moe,
154 | )
155 | else:
156 | self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora)
157 |
158 | def forward(
159 | self,
160 | x: torch.Tensor,
161 | freqs_cis: torch.Tensor,
162 | cache: Optional[CacheView] = None,
163 | mask: Optional[BlockDiagonalMask] = None,
164 | ) -> torch.Tensor:
165 | r = self.attention.forward(self.attention_norm(x), freqs_cis, cache, mask=mask)
166 | h = x + r
167 | r = self.feed_forward.forward(self.ffn_norm(h))
168 | out = h + r
169 | return out
170 |
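For intuition about `repeat_kv` above: it implements the key/value sharing of grouped-query attention. With `n_heads = 4` query heads and `n_kv_heads = 2`, `repeats = 2` and each KV head serves two consecutive query heads. A standalone check of the duplication pattern (toy shapes):

    import torch

    seqlen, n_kv_heads, head_dim, repeats = 3, 2, 4, 2
    k = torch.randn(seqlen, n_kv_heads, head_dim)

    # what repeat_kv does to each of keys and values
    k_rep = torch.repeat_interleave(k, repeats=repeats, dim=1)

    print(k_rep.shape)                            # torch.Size([3, 4, 4]) -> matches 4 query heads
    assert torch.equal(k_rep[:, 0], k_rep[:, 1])  # query heads 0 and 1 share KV head 0
    assert torch.equal(k_rep[:, 2], k_rep[:, 3])  # query heads 2 and 3 share KV head 1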
--------------------------------------------------------------------------------
/src/mistral-inference/src/mistral_inference/vision_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask
6 |
7 | from mistral_inference.args import VisionEncoderArgs
8 | from mistral_inference.rope import precompute_freqs_cis_2d
9 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
10 |
11 |
12 | def position_meshgrid(
13 | patch_embeds_list: List[torch.Tensor],
14 | ) -> torch.Tensor:
15 | positions = torch.cat(
16 | [
17 | torch.stack(
18 | torch.meshgrid(
19 | torch.arange(p.shape[-2]),
20 | torch.arange(p.shape[-1]),
21 | indexing="ij",
22 | ),
23 | dim=-1,
24 | ).reshape(-1, 2)
25 | for p in patch_embeds_list
26 | ]
27 | )
28 | return positions
29 |
30 |
31 | class VisionTransformer(nn.Module):
32 | def __init__(self, args: VisionEncoderArgs):
33 | super().__init__()
34 | self.args = args
35 | self.patch_conv = nn.Conv2d(
36 | in_channels=args.num_channels,
37 | out_channels=args.hidden_size,
38 | kernel_size=args.patch_size,
39 | stride=args.patch_size,
40 | bias=False,
41 | )
42 | self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
43 | self.transformer = VisionTransformerBlocks(args)
44 |
45 | head_dim = self.args.hidden_size // self.args.num_attention_heads
46 | assert head_dim % 2 == 0, "ROPE requires even head_dim"
47 | self._freqs_cis: Optional[torch.Tensor] = None
48 |
49 | @property
50 | def max_patches_per_side(self) -> int:
51 | return self.args.image_size // self.args.patch_size
52 |
53 | @property
54 | def device(self) -> torch.device:
55 | return next(self.parameters()).device
56 |
57 | @property
58 | def freqs_cis(self) -> torch.Tensor:
59 | if self._freqs_cis is None:
60 | self._freqs_cis = precompute_freqs_cis_2d(
61 | dim=self.args.hidden_size // self.args.num_attention_heads,
62 | height=self.max_patches_per_side,
63 | width=self.max_patches_per_side,
64 | theta=self.args.rope_theta,
65 | )
66 |
67 | if self._freqs_cis.device != self.device:
68 | self._freqs_cis = self._freqs_cis.to(device=self.device)
69 |
70 | return self._freqs_cis
71 |
72 | def forward(
73 | self,
74 | images: List[torch.Tensor],
75 | ) -> torch.Tensor:
76 | """
77 | Args:
78 | images: list of N_img images of variable sizes, each of shape (C, H, W)
79 |
80 | Returns:
81 | image_features: tensor of token features for all tokens of all images of
82 | shape (N_toks, D)
83 | """
84 | # pass images through initial convolution independently
85 | patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
86 |
87 | # flatten to a single sequence
88 | patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
89 | patch_embeds = self.ln_pre(patch_embeds)
90 |
91 | # positional embeddings
92 | positions = position_meshgrid(patch_embeds_list).to(self.device)
93 | freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
94 |
95 | # pass through Transformer with a block diagonal mask delimiting images
96 | mask = BlockDiagonalMask.from_seqlens(
97 | [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
98 | )
99 | out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
100 |
101 | # tokens of all images returned as a single flattened sequence (N_toks, D)
102 | return out # type: ignore[no-any-return]
103 |
104 |
105 | class VisionLanguageAdapter(nn.Module):
106 | def __init__(self, in_dim: int, out_dim: int):
107 | super().__init__()
108 | self.w_in = nn.Linear(
109 | in_dim,
110 | out_dim,
111 | bias=True,
112 | )
113 | self.gelu = nn.GELU()
114 | self.w_out = nn.Linear(out_dim, out_dim, bias=True)
115 |
116 | def forward(self, x: torch.Tensor) -> torch.Tensor:
117 | return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
118 |
119 |
120 | class VisionTransformerBlocks(nn.Module):
121 | def __init__(self, args: VisionEncoderArgs):
122 | super().__init__()
123 | self.layers = torch.nn.ModuleList()
124 | for _ in range(args.num_hidden_layers):
125 | self.layers.append(
126 | TransformerBlock(
127 | dim=args.hidden_size,
128 | hidden_dim=args.intermediate_size,
129 | n_heads=args.num_attention_heads,
130 | n_kv_heads=args.num_attention_heads,
131 | head_dim=args.hidden_size // args.num_attention_heads,
132 | norm_eps=1e-5,
133 | )
134 | )
135 |
136 | def forward(
137 | self,
138 | x: torch.Tensor,
139 | mask: BlockDiagonalMask,
140 | freqs_cis: Optional[torch.Tensor],
141 | ) -> torch.Tensor:
142 | for layer in self.layers:
143 | x = layer(x, mask=mask, freqs_cis=freqs_cis)
144 | return x
145 |
146 |
147 |
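The geometry handled by `position_meshgrid` above is easy to inspect: each image's patch grid is flattened, row-major, into (row, col) coordinates, which `freqs_cis[positions[:, 0], positions[:, 1]]` then uses to fetch the 2D rotary embeddings. A small check, assuming `position_meshgrid` is in scope (grid sizes are arbitrary):

    import torch

    # patch embeddings for two images: a 2x3 and a 1x2 patch grid (hidden dim first)
    p1 = torch.zeros(8, 2, 3)
    p2 = torch.zeros(8, 1, 2)

    positions = position_meshgrid([p1, p2])
    print(positions.shape)         # torch.Size([8, 2]) -> 2*3 + 1*2 patch positions
    print(positions[:3].tolist())  # [[0, 0], [0, 1], [0, 2]] -> row-major within an image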
--------------------------------------------------------------------------------
/src/mistral-inference/tests/test_generate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | from mistral_inference.args import VisionEncoderArgs
6 | from mistral_inference.generate import generate_mamba
7 | from mistral_inference.main import generate
8 | from mistral_inference.mamba import Mamba, MambaArgs
9 | from mistral_inference.transformer import Transformer, TransformerArgs
10 |
11 |
12 | class DebugTokenizer:
13 | @property
14 | def bos_id(self) -> int:
15 | return 0
16 |
17 | @property
18 | def eos_id(self) -> int:
19 | return 1
20 |
21 | @property
22 | def pad_id(self) -> int:
23 | return -1
24 |
25 | def encode(self, s: str, bos: bool = True) -> List[int]:
26 | assert isinstance(s, str)
27 | t = [int(x) for x in s.split()]
28 | if bos:
29 | t = [self.bos_id, *t]
30 | return t
31 |
32 | def decode(self, t: List[int]) -> str:
33 | return " ".join([str(x) for x in t])
34 |
35 |
36 | def test_generation_transformer():
37 | torch.manual_seed(42)
38 |
39 | sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"]
40 | args = TransformerArgs(
41 | dim=512,
42 | n_layers=1,
43 | head_dim=128,
44 | hidden_dim=2048,
45 | n_heads=4,
46 | n_kv_heads=2,
47 | norm_eps=1e-5,
48 | vocab_size=32_000,
49 | max_batch_size=len(sequences),
50 | )
51 | model = Transformer(args).to("cuda", dtype=torch.float32)
52 | tokenizer = DebugTokenizer()
53 |
54 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
55 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=7)
56 |
57 | # concatenate the prompt and the generated tokens
58 | encoded = [e + t for e, t in zip(encoded, toks)]
59 |
60 | generated, all_logprobs_new = generate(
61 | encoded, model, temperature=0.0, max_tokens=0
62 | )
63 |
64 | assert generated == []
65 |
66 | # Verify that logprobs are the same
67 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
68 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
69 | assert all(
70 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
71 | ), f"\n{lp_old}\n{lp_new}"
72 |
73 | print("All tests passed.")
74 |
75 |
76 | def test_generation_pixtral():
77 | torch.manual_seed(42)
78 | gen = np.random.default_rng(seed=42)
79 |
80 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"]
81 | images = [[gen.normal(size=(3, 4, 4))], [], [gen.normal(size=(3, 4, 4))]]
82 | args = TransformerArgs(
83 | dim=512,
84 | n_layers=1,
85 | head_dim=128,
86 | hidden_dim=2048,
87 | n_heads=4,
88 | n_kv_heads=2,
89 | norm_eps=1e-5,
90 | vocab_size=32_000,
91 | max_batch_size=len(sequences),
92 | vision_encoder=VisionEncoderArgs(
93 | hidden_size=128,
94 | num_channels=3,
95 | image_size=4,
96 | patch_size=2,
97 | intermediate_size=256,
98 | num_hidden_layers=1,
99 | num_attention_heads=2,
100 | rope_theta=10000,
101 | image_token_id=2,
102 | ),
103 | )
104 | model = Transformer(args).to("cuda", dtype=torch.float32)
105 | tokenizer = DebugTokenizer()
106 |
107 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
108 | toks, all_logprobs_old = generate(
109 | encoded, model, images=images, temperature=0.0, max_tokens=7
110 | )
111 |
112 | # concatenate the prompt and the generated tokens
113 | encoded = [e + t for e, t in zip(encoded, toks)]
114 |
115 | generated, all_logprobs_new = generate(
116 | encoded, model, images=images, temperature=0.0, max_tokens=0
117 | )
118 |
119 | assert generated == []
120 |
121 | # Verify that logprobs are the same
122 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
123 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
124 | assert all(
125 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
126 | ), f"\n{lp_old}\n{lp_new}"
127 |
128 | print("All tests passed.")
129 |
130 |
131 | def test_generation_mamba():
132 | torch.manual_seed(42)
133 |
134 | sequences = ["1 2 3 4 5 6 7"]
135 | args = MambaArgs(
136 | dim=512,
137 | n_layers=1,
138 | n_groups=1,
139 | rms_norm=True,
140 | residual_in_fp32=True,
141 | fused_add_norm=True,
142 | pad_vocab_size_multiple=1,
143 | tie_embeddings=False,
144 | vocab_size=32768,
145 | )
146 | model = Mamba(args).to("cuda", dtype=torch.float32)
147 | tokenizer = DebugTokenizer()
148 |
149 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
150 | toks, all_logprobs_old = generate_mamba(
151 | encoded, model, temperature=0.0, max_tokens=7
152 | )
153 |
154 | assert len(toks[0]) == 7
155 | assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]]
156 |
157 |
158 | def test_chunks_transformer():
159 | torch.manual_seed(42)
160 |
161 | sequences = [
162 | " ".join([str(i) for i in range(7)]),
163 | " ".join([str(i) for i in range(9, 0, -1)]),
164 | ]
165 | args = TransformerArgs(
166 | dim=512,
167 | n_layers=1,
168 | head_dim=128,
169 | hidden_dim=2048,
170 | n_heads=4,
171 | n_kv_heads=2,
172 | norm_eps=1e-5,
173 | vocab_size=32_000,
174 | max_batch_size=3,
175 | )
176 | model = Transformer(args).to("cuda", dtype=torch.float32)
177 | tokenizer = DebugTokenizer()
178 |
179 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
180 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=8)
181 |
182 | # concatenate the prompt and the generated tokens
183 | encoded = [e + t for e, t in zip(encoded, toks)]
184 |
185 | generated, all_logprobs_new = generate(
186 | encoded, model, temperature=0.0, max_tokens=0, chunk_size=5
187 | )
188 | assert len(generated) == 0
189 |
190 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
191 | assert all(
192 | [abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]
193 | ), f"\n{lp_old}\n{lp_new}"
194 |
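All three transformer tests above rely on the same trick (with `max_tokens=0`, `generate` produces nothing new but still returns per-token logprobs of the given sequence, so re-scoring prompt + completion must reproduce the logprobs observed while generating) and repeat the same elementwise tolerance check. A small helper capturing that check (a sketch, not part of the test suite):

    from typing import List

    def assert_logprobs_close(old: List[List[float]], new: List[List[float]], atol: float = 5e-4) -> None:
        # compare per-sequence, per-token logprobs within an absolute tolerance
        assert len(old) == len(new)
        for lp_old, lp_new in zip(old, new):
            assert all(abs(x - y) < atol for x, y in zip(lp_old, lp_new)), f"\n{lp_old}\n{lp_new}"

    assert_logprobs_close([[0.10, 0.20]], [[0.10, 0.2004]])  # within 5e-4: passes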
--------------------------------------------------------------------------------
/src/mistral-inference/tutorials/getting_started.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Getting Started with `mistral-inference`\n",
8 | "\n",
9 | "This notebook will guide you through the process of running Mistral models locally. We will cover the following: \n",
10 | "- How to chat with Mistral 7B Instruct\n",
11 | "- How to run Mistral 7B Instruct with function calling capabilities\n",
12 | "\n",
13 | "We recommend using a GPU such as the A100 to run this notebook. "
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {
20 | "id": "G6tXvIsQenpI"
21 | },
22 | "outputs": [],
23 | "source": [
24 | "!pip install mistral-inference"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "## Download Mistral 7B Instruct"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "colab": {
39 | "background_save": true
40 | },
41 | "id": "4ytmRt0WQeMW"
42 | },
43 | "outputs": [],
44 | "source": [
45 | "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "id": "eRZg_8wvs5A6"
53 | },
54 | "outputs": [],
55 | "source": [
56 | "!DIR=$HOME/mistral_7b_instruct_v3 && mkdir -p $DIR && tar -xf mistral-7B-Instruct-v0.3.tar -C $DIR"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "id": "7CN8gShDf65M"
64 | },
65 | "outputs": [],
66 | "source": [
67 | "!ls mistral_7b_instruct_v3"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Chat with the model"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "import os \n",
84 | "\n",
85 | "from mistral_inference.transformer import Transformer\n",
86 | "from mistral_inference.generate import generate\n",
87 | "\n",
88 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
89 | "from mistral_common.protocol.instruct.messages import UserMessage\n",
90 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
91 | "\n",
92 | "# load tokenizer\n",
93 | "mistral_tokenizer = MistralTokenizer.from_file(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3/tokenizer.model.v3\")\n",
94 | "# chat completion request\n",
95 | "completion_request = ChatCompletionRequest(messages=[UserMessage(content=\"Explain Machine Learning to me in a nutshell.\")])\n",
96 | "# encode message\n",
97 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n",
98 | "# load model\n",
99 | "model = Transformer.from_folder(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3\")\n",
100 | "# generate results\n",
101 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n",
102 | "# decode generated tokens\n",
103 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
104 | "print(result)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {
110 | "id": "ce4woS3LkgZ9"
111 | },
112 | "source": [
113 | "## Function calling\n",
114 | "\n",
115 | "Mistral 7B Instruct v3 also supports function calling!"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "TKfPiEwNk1kh"
122 | },
123 | "source": [
124 | "Let's start by creating a function calling example"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "id": "0PJdwvDEk3dl"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "from mistral_common.protocol.instruct.messages import UserMessage\n",
136 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
137 | "from mistral_common.protocol.instruct.tool_calls import Function, Tool\n",
138 | "\n",
139 | "completion_request = ChatCompletionRequest(\n",
140 | " tools=[\n",
141 | " Tool(\n",
142 | " function=Function(\n",
143 | " name=\"get_current_weather\",\n",
144 | " description=\"Get the current weather\",\n",
145 | " parameters={\n",
146 | " \"type\": \"object\",\n",
147 | " \"properties\": {\n",
148 | " \"location\": {\n",
149 | " \"type\": \"string\",\n",
150 | " \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
151 | " },\n",
152 | " \"format\": {\n",
153 | " \"type\": \"string\",\n",
154 | " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
155 | \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n",
156 | " },\n",
157 | " },\n",
158 | " \"required\": [\"location\", \"format\"],\n",
159 | " },\n",
160 | " )\n",
161 | " )\n",
162 | " ],\n",
163 | " messages=[\n",
164 | " UserMessage(content=\"What's the weather like today in Paris?\"),\n",
165 | " ],\n",
166 | ")"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {
172 | "id": "bG6ZeZUylpBW"
173 | },
174 | "source": [
175 | "Since we have already loaded the tokenizer and the model in the example above. We will skip these steps here. \n",
176 | "\n",
177 | "Now we can encode the message with our tokenizer using `MistralTokenizer`. "
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "id": "Ii8q-JNClwiq"
185 | },
186 | "outputs": [],
187 | "source": [
188 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
189 | "\n",
190 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {
196 | "id": "NrueDujkmJT4"
197 | },
198 | "source": [
199 | "and run `generate` to get a response. Don't forget to pass the EOS id!"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {
206 | "id": "GWJYO43rl0V8"
207 | },
208 | "outputs": [],
209 | "source": [
210 | "from mistral_inference.generate import generate\n",
211 | "\n",
212 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "v7baJ1msmPMv"
219 | },
220 | "source": [
221 | "Finally, we can decode the generated tokens."
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "RKhryfBWmHon"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens)[0]\n",
233 | "result"
234 | ]
235 | }
236 | ],
237 | "metadata": {
238 | "accelerator": "GPU",
239 | "colab": {
240 | "gpuType": "L4",
241 | "machine_shape": "hm",
242 | "provenance": []
243 | },
244 | "kernelspec": {
245 | "display_name": "Python 3 (ipykernel)",
246 | "language": "python",
247 | "name": "python3"
248 | },
249 | "language_info": {
250 | "codemirror_mode": {
251 | "name": "ipython",
252 | "version": 3
253 | },
254 | "file_extension": ".py",
255 | "mimetype": "text/x-python",
256 | "name": "python",
257 | "nbconvert_exporter": "python",
258 | "pygments_lexer": "ipython3",
259 | "version": "3.11.8"
260 | }
261 | },
262 | "nbformat": 4,
263 | "nbformat_minor": 4
264 | }
265 |
--------------------------------------------------------------------------------
/src/single_query.py:
--------------------------------------------------------------------------------
1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez.
2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors.
3 |
4 | from typing import List, Optional
5 |
6 | import fire
7 |
8 | import sys
9 | import os
10 | import json
11 | import torch
12 | from sampler import Sampler
13 |
14 | # mistral stuff
15 | from mistral_common.tokens.tokenizers.base import Tokenizer
16 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
17 | from mistral_common.protocol.instruct.request import ChatCompletionRequest
18 | from mistral_common.protocol.instruct.messages import (
19 | AssistantMessage,
20 | UserMessage,
21 | SystemMessage
22 | )
23 | from mistral_inference.generate import generate
24 | # Add the src directory to the Python path
25 | sys.path.append(os.path.abspath("src/mistral-inference/src"))
26 | from mistral_inference.main import load_tokenizer, get_model_cls
27 | from pathlib import Path
28 |
29 | # Add the src directory to the Python path
30 | sys.path.append(os.path.abspath("src/llama3"))
31 | from llama import Dialog, Llama
32 |
33 | def main(
34 | ckpt_dir: str,
35 | tokenizer_path: str,
36 | temperature: float = 0.6,
37 | top_p: float = 0.9,
38 | top_k: int = 10,
39 | sampler_type: str = 'vocabulary',
40 | max_seq_len: int = 512,
41 | max_batch_size: int = 4,
42 | max_gen_len: Optional[int] = None,
43 | seed: int = 42,
44 | query: str = "give me a recipe for moussaka",
45 | exp_name: str = "test",
46 | system: str = "Keep your replies short and to the point but don't give single word answers.",
47 | genstates: bool = False,
48 | model_family: str = "llama3",
49 | ):
50 | """
51 | Example of running a single query against models finetuned for chat. Prompts consist of chat
52 | turns between the user and the assistant, with the final turn always being the user's.
53 |
54 | An optional system prompt at the beginning to control how the model should respond
55 | is also supported.
56 |
57 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
58 |
59 | `max_gen_len` is optional because finetuned models are able to stop generations naturally.
60 | """
61 |
62 | sampler = Sampler(sampler_type=sampler_type, top_p=top_p, top_k=top_k)
63 |
64 | if model_family == "llama3":
65 | generator = Llama.build(
66 | ckpt_dir=ckpt_dir,
67 | tokenizer_path=tokenizer_path,
68 | max_seq_len=max_seq_len,
69 | max_batch_size=max_batch_size,
70 | seed=seed
71 | )
72 |
73 | dialogs: List[Dialog] = [
74 | [ {"role": "system", "content": system},
75 | {"role": "user", "content": query}]
76 | ]
77 |
78 | results, rngstates = generator.chat_completion(
79 | dialogs,
80 | max_gen_len=max_gen_len,
81 | temperature=temperature,
82 | sampler=sampler,
83 | genstates=genstates
84 | )
85 | elif model_family == "mistral":
86 |
87 | # load the model
88 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(ckpt_dir))
89 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
90 |
91 | model_cls = get_model_cls(ckpt_dir)
92 | model = model_cls.from_folder(Path(ckpt_dir), max_batch_size=max_batch_size, num_pipeline_ranks=1)
93 |
94 | messages: List[SystemMessage | UserMessage | AssistantMessage] = []
95 | messages += [SystemMessage(content=system)]
96 | messages += [UserMessage(content=query)]
97 |
98 | chat_completion_request = ChatCompletionRequest(messages=messages)
99 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
100 | tokens = tokenized.tokens
101 |
102 | # initialize a random number generator with the given seed
103 | rng = torch.Generator(device="cuda")
104 | rng.manual_seed(seed)
105 |
106 | generated_tokens, _, rngstates = generate( # type: ignore[operator]
107 | encoded_prompts = [tokens],
108 | model = model,
109 | max_tokens = max_seq_len,
110 | temperature = temperature,
111 | eos_id = tokenizer.eos_id,
112 | sampler = sampler,
113 | genstates = genstates,
114 | rng = rng
115 | )
116 |
117 | # decode the generated tokens into a chat-style result (used below regardless of genstates)
118 | results = [
119 | {
120 | "generation": {
121 | "role": "assistant",
122 | "content": tokenizer.decode(t),
123 | "token_list": [tokenizer.decode([x]) for x in t],
124 | },
125 | }
126 | for t in generated_tokens
127 | ]
128 |
129 | output_dir = os.path.join("outputs", exp_name)
130 | # if a folder with the experiment name exists, delete its contents, otherwise create it
131 | if os.path.exists(output_dir):
132 | for file in os.listdir(output_dir):
133 | os.remove(os.path.join(output_dir, file))
134 | else:
135 | os.makedirs(output_dir)
136 |
137 | if genstates:
138 | rngstates = torch.tensor(rngstates, dtype=torch.uint8)
139 | # save the rngstates tensor
140 | torch.save(rngstates, os.path.join(output_dir, 'rngstates_1.pt'))
141 |
142 | output = {}
143 | output["system"] = system
144 | output["query"] = query
145 | output["response"] = results[0]["generation"]["content"]
146 | output["seed"] = seed
147 | output["temperature"] = temperature
148 | output["sampler_type"] = sampler_type
149 | output["top_p"] = top_p
150 | output["top_k"] = top_k
151 | output["model_family"] = model_family
152 |
153 | # save the output of the factual generation
154 | with open(os.path.join(output_dir, "factual.json"), "w") as f:
155 | json.dump(output, f, indent=4)
156 |
157 | # prepare the intervention file
158 | output["ckpt_dir"] = ckpt_dir
159 | output["tokenizer_path"] = tokenizer_path
160 | output["start_from"] = 0
161 | output["token_list"] = {ind: tok for ind, tok in enumerate(results[0]["generation"]["token_list"])}
162 | with open(os.path.join(output_dir, "intervention_1.json"), "w") as f:
163 | json.dump(output, f, indent=4)
164 |
165 | print("----------------")
166 | print("SYSTEM:", system)
167 | print("----------------")
168 | print("QUERY:", query)
169 | print("----------------")
170 | print("RESPONSE:", results[0]["generation"]["content"])
171 | print("----------------")
172 |
173 |
174 | if __name__ == "__main__":
175 | fire.Fire(main)
176 |
177 |
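Because `main` is exposed through `fire.Fire`, each of its parameters doubles as a command-line flag, e.g. `python src/single_query.py --query="..." --genstates=True`. Equivalently, a programmatic call (a sketch: the checkpoint and tokenizer paths below are placeholders, not files shipped with the repository):

    from single_query import main  # assumes src/ is on the Python path

    main(
        ckpt_dir="src/llama3/pretrained/<model-dir>",                        # placeholder
        tokenizer_path="src/llama3/pretrained/<model-dir>/tokenizer.model",  # placeholder
        query="give me a recipe for moussaka",
        exp_name="test",
        genstates=True,  # also save RNG states for later counterfactual runs
        model_family="llama3",
    )

With `exp_name="test"`, the factual response is written to outputs/test/factual.json, alongside intervention_1.json and, when `genstates=True`, rngstates_1.pt.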
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # This software was developed by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez Rodriguez.
2 | # If you use this code, please cite the paper "Counterfactual Token Generation in Large Language Models" by the same authors.
3 |
4 | import matplotlib.pyplot as plt
5 | import matplotlib
6 | import numpy as np
7 |
8 | def get_fig_dim(width, fraction=1, aspect_ratio=None):
9 | """Set figure dimensions to avoid scaling in LaTeX.
10 |
11 | Parameters
12 | ----------
13 | width: float
14 | Document textwidth or columnwidth in pts
15 | fraction: float, optional
16 | Fraction of the width which you wish the figure to occupy
17 | aspect_ratio: float, optional
18 | Aspect ratio of the figure
19 |
20 | Returns
21 | -------
22 | fig_dim: tuple
23 | Dimensions of figure in inches
24 | """
25 | # Width of figure (in pts)
26 | fig_width_pt = width * fraction
27 |
28 | # Convert from pt to inches
29 | inches_per_pt = 1 / 72.27
30 |
31 | if aspect_ratio is None:
32 | # If not specified, set the aspect ratio equal to the Golden ratio (https://en.wikipedia.org/wiki/Golden_ratio)
33 | aspect_ratio = (1 + 5**.5) / 2
34 |
35 | # Figure width in inches
36 | fig_width_in = fig_width_pt * inches_per_pt
37 | # Figure height in inches
38 | fig_height_in = fig_width_in / aspect_ratio
39 |
40 | fig_dim = (fig_width_in, fig_height_in)
41 |
42 | return fig_dim
43 |
44 |
45 | def latexify(font_serif='Computer Modern', mathtext_font='cm', font_size=10, small_font_size=None, usetex=True):
46 | """Set up matplotlib's RC params for LaTeX plotting.
47 | Call this before plotting a figure.
48 |
49 | Parameters
50 | ----------
51 | font_serif: string, optional
52 | Set the desired font family
53 | mathtext_font: string, optional
54 | Set the desired math font family
55 | font_size: int, optional
56 | Set the large font size
57 | small_font_size: int, optional
58 | Set the small font size
59 | usetex: boolean, optional
60 | Use tex for strings
61 | """
62 |
63 | if small_font_size is None:
64 | small_font_size = font_size
65 |
66 | params = {
67 | 'backend': 'ps',
68 | 'text.latex.preamble': '\\usepackage{gensymb} \\usepackage{bm}',
69 |
70 | 'axes.labelsize': font_size,
71 | 'axes.titlesize': font_size,
72 | 'font.size': font_size,
73 |
74 | # Optionally set a smaller font size for legends and tick labels
75 | 'legend.fontsize': small_font_size,
76 | 'legend.title_fontsize': small_font_size,
77 | 'xtick.labelsize': small_font_size,
78 | 'ytick.labelsize': small_font_size,
79 |
80 | 'text.usetex': usetex,
81 | 'font.family' : 'serif',
82 | 'font.serif' : font_serif,
83 | 'mathtext.fontset' : mathtext_font
84 | }
85 |
86 | matplotlib.rcParams.update(params)
87 | plt.rcParams.update(params)
88 |
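Typical use of the two helpers above when preparing a LaTeX-ready figure: call `latexify` once before plotting, then size the figure from the document's column width in points. A minimal sketch, assuming both helpers are in scope (240 pt stands in for your document's `\columnwidth`; `usetex=False` so it runs without a TeX installation):

    import numpy as np
    import matplotlib.pyplot as plt

    latexify(font_size=10, usetex=False)

    fig, ax = plt.subplots(figsize=get_fig_dim(width=240))  # golden-ratio aspect by default
    x = np.linspace(0, 2 * np.pi, 100)
    ax.plot(x, np.sin(x), label=r"$\sin(x)$")
    ax.set_xlabel(r"$x$")
    ax.legend()
    fig.savefig("example.pdf", bbox_inches="tight")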
--------------------------------------------------------------------------------