├── .gitignore
├── API
│   ├── api_service_main.py
│   ├── server_start.bat
│   └── server_start.sh
├── LICENSE
├── README.md
├── examples
│   ├── Annotation_Visualization.ipynb
│   ├── Embd_Interp.ipynb
│   ├── FET_Projection.ipynb
│   ├── GUI_TEST.ipynb
│   ├── KNN_RAG.ipynb
│   ├── KNN_RAG_interp.ipynb
│   ├── Metric_calculation_test.ipynb
│   ├── RL_Simalation_With_Interpretation_Ebeddings.ipynb
│   ├── SBertDistilTest.ipynb
│   ├── SBertSVDDistil.ipynb
│   └── example_data
│       ├── clusters.json
│       └── metrix.csv
├── explainitall
│   ├── QA
│   │   ├── extractive_qa_sbert
│   │   │   ├── QABotsBase.py
│   │   │   └── SVDBert.py
│   │   └── interp_qa
│   │       └── KNNWithGenerative.py
│   ├── __init__.py
│   ├── clusters.py
│   ├── embedder_interp
│   │   └── embd_interpret.py
│   ├── fast_tuning
│   │   ├── Embedder.py
│   │   ├── ExpertBase.py
│   │   ├── SimpleModelCreator.py
│   │   ├── generators
│   │   │   ├── GeneratorWithExpert.py
│   │   │   ├── MCExpert.py
│   │   │   ├── MCModel.py
│   │   │   └── SimpleGenerator.py
│   │   └── trainers
│   │       ├── DenceKerasTrainer.py
│   │       ├── HMMTrainer.py
│   │       └── ProjectionTrainer.py
│   ├── gpt_like_interp
│   │   ├── __init__.py
│   │   ├── downloader.py
│   │   ├── inseq_helpers.py
│   │   ├── interp.py
│   │   └── viz.py
│   ├── gui
│   │   ├── df_to_heatmap_plot.py
│   │   └── interface.py
│   ├── metrics
│   │   ├── CheckingForHallucinations.py
│   │   └── RougeAndPPL
│   │       ├── Metrics.py
│   │       ├── Metrics_calculator.py
│   │       ├── create_database.py
│   │       ├── helpers.py
│   │       ├── metric_calculation_interface.py
│   │       ├── rouge_L.py
│   │       └── rouge_N.py
│   ├── nlp.py
│   └── stat_helpers.py
├── main.py
├── pytest.ini
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── test_inseq_helpers.py
    └── test_stat_helpers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | /exclude_*
163 | /examples/dataset
164 | /examples/database.sqlite
165 | *cache*
166 | /examples/trained_model
167 | /examples/s
168 | /examples/new_gpt
169 |
--------------------------------------------------------------------------------
/API/api_service_main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import List
4 |
5 | import gensim
6 | import numpy as np
7 | from fastapi import FastAPI
8 | from inseq import load_model
9 | from pydantic import BaseModel, constr
10 | from sentence_transformers import SentenceTransformer
11 | from sklearn.neighbors import KNeighborsClassifier
12 | from starlette.responses import RedirectResponse
13 |
14 | from explainitall.QA.extractive_qa_sbert.QABotsBase import cos_dist
15 | from explainitall.QA.extractive_qa_sbert.SVDBert import SVDBertModel
16 | from explainitall.QA.interp_qa.KNNWithGenerative import FredStruct, PromptBot
17 | from explainitall.gpt_like_interp import interp
18 | from explainitall.gpt_like_interp.downloader import DownloadManager
19 |
20 |
21 | class GetAnswerItem(BaseModel):
22 | question: constr(min_length=1)
23 | top_k: int
24 |
25 |
26 | class LoadDatasetItem(BaseModel):
27 | questions: List[str]
28 | answers: List[str]
29 |
30 |
31 | class ClusterItem(BaseModel):
32 | name: constr(min_length=1)
33 | centroid: List[str]
34 | top_k: int
35 |
36 |
37 | class EvaluationItem(BaseModel):
38 | nlp_model_path: constr(min_length=1)
39 | nn_model_path: constr(min_length=1)
40 | clusters: List[ClusterItem]
41 | prompt: constr(min_length=1)
42 | generated_text: constr(min_length=1)
43 |
44 |
45 | class ApiServerInit:
46 | def __init__(self, sbert_path='FractalGPT/SbertSVDDistil', device='cpu'):
47 | self.clusters_r = None
48 | self.bot = None
49 | self.prompt = None
50 | self.generated_text = None
51 | self.explainer = None
52 | self.clusters = None
53 | self.device = device
54 | self.app = FastAPI()
55 | self.sbert = SentenceTransformer(sbert_path)
56 | self.sbert[0].auto_model = SVDBertModel.from_pretrained(sbert_path)
57 | # self.fred = FredStruct('SiberiaSoft/SiberianFredT5-instructor')
58 |
59 | if os.getenv('TEST_MODE_ON_LOW_SPEC_PC') == 'True':
60 | self.fred = FredStruct('ai-forever/FRED-T5-large')
61 | else:
62 | self.fred = FredStruct('FractalGPT/FRED-T5-Interp')
63 |
64 | def load_dataset(self, questions, answers):
65 | self.__init_knn__(questions, answers)
66 | self.bot = PromptBot(self.knn, self.sbert, self.fred, answers, device=self.device)
67 | return True
68 |
69 | @staticmethod
70 | def df_to_dict(data_frame):
71 | df_copy = data_frame.copy(deep=True)
72 |
73 | def make_columns_unique(df):
74 | new_columns = {}
75 | for column in df.columns:
76 | if column in new_columns:
77 | new_columns[column] += 1
78 | new_name = f"{column}_{new_columns[column]}"
79 | else:
80 | new_columns[column] = 0
81 | new_name = column
82 | yield new_name
83 |
84 | df_copy.columns = list(make_columns_unique(df_copy))
85 | return df_copy.replace([np.nan, np.inf, -np.inf], ["nan", "inf", "-inf"]).to_dict(orient="split")
86 |
87 | def evaluate(self, nlp_model_path, nn_model_path, clusters, prompt, generated_text):
88 | self.clusters_r = [{
89 | "name": cluster.name,
90 | "centroid": cluster.centroid,
91 | "top_k": cluster.top_k
92 | } for cluster in clusters]
93 | self.prompt = prompt
94 | self.generated_text = generated_text
95 | self.__load_nlp_model__(nlp_model_path)
96 | self.__load_nn_model__(nn_model_path)
97 | self.explainer = interp.ExplainerGPT2(gpt_model=self.nn_model, nlp_model=self.nlp_model)
98 | expl_data = self.explainer.interpret(
99 | input_texts=self.prompt,
100 | generated_texts=self.generated_text,
101 | clusters_description=self.clusters_r,
102 | batch_size=50,
103 | steps=34,
104 | # max_new_tokens=19
105 | )
106 | return {"word_importance_map": self.df_to_dict(expl_data.word_imp_df),
107 | "word_importance_map_normalized": self.df_to_dict(expl_data.word_imp_norm_df),
108 | "cluster_importance_map": self.df_to_dict(expl_data.cluster_imp_df),
109 | "cluster_importance_map_normalized": self.df_to_dict(expl_data.cluster_imp_aggr_df)}
110 |
111 | def get_answer(self, q, top_k):
112 | return self.bot.get_answers(q, top_k=top_k)
113 |
114 | def __init_knn__(self, questions, answers):
115 | vects_questions = self.sbert.encode(questions)
116 | m = vects_questions.mean(axis=0)
117 | s = vects_questions.std(axis=0)
118 | knn_vects_questions = (vects_questions - m) / s
119 |
120 | self.knn = KNeighborsClassifier(metric=cos_dist)
121 | self.knn.fit(knn_vects_questions, answers)
122 |
123 | def __load_nlp_model__(self, url):
124 | self.nlp_model_url = url
125 | nlp_model_path = DownloadManager.load_zip(url)
126 | self.nlp_model = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)
127 | return True
128 |
129 | def __load_nn_model__(self, model_name_or_path):
130 | path = os.path.normpath(model_name_or_path)
131 | path_list = path.split(os.sep)
132 | self.nn_model_name = path_list[-1]
133 | self.nn_model = load_model(model=model_name_or_path, attribution_method="integrated_gradients")
134 | return True
135 |
136 |
137 | api_server_init = ApiServerInit()
138 | app = api_server_init.app
139 |
140 |
141 | @app.get("/")
142 | async def redirect_to_docs():
143 | return RedirectResponse(url="/docs")
144 |
145 |
146 | @app.post(
147 | "/load_dataset",
148 | summary="Load a dataset for the Q&A model",
149 | response_description="Indicates success of dataset loading",
150 | openapi_extra={
151 | "requestBody": {
152 | "content": {
153 | "application/json": {
154 | "examples": {
155 | "example1": {
156 |                         "summary": "Load a basic question-and-answer dataset",
157 | "value": {
158 | "questions": ["Что такое коала?",
159 | "Опишите африканского слона"],
160 |
161 | "answers": ["Это вид медведей, обитающих в Австралии.",
162 | "Это крупное млекопитающее с длинным хоботом."]
163 | }
164 | }
165 | }
166 | }
167 | }
168 | }
169 | }
170 | )
171 | async def load_dataset(item: LoadDatasetItem):
172 |     """Loads a dataset of animal-related questions and answers into the question-answering model"""
173 | result = await asyncio.to_thread(api_server_init.load_dataset, item.questions, item.answers)
174 | return {"result": result}
175 |
176 |
177 | @app.post(
178 | "/get_answer",
179 |     summary="Get answers to questions",
180 |     response_description="The obtained answer(s) to the specified questions",
181 | openapi_extra={
182 | "requestBody": {
183 | "content": {
184 | "application/json": {
185 | "examples": {
186 | "example1": {
187 |                         "summary": "Get an answer to a question about an animal",
188 | "value": {
189 | "question": "Что за животное коала?",
190 | "top_k": 1
191 | }
192 | }
193 | }
194 | }
195 | }
196 | }
197 | }
198 | )
199 | async def get_answer(item: GetAnswerItem):
200 |     """
201 |     Retrieves answers to the given question using the loaded question-answering model
202 |     """
203 | result = await asyncio.to_thread(api_server_init.get_answer, item.question, item.top_k)
204 | return {"result": result}
205 |
206 |
207 | @app.post(
208 | "/evaluate",
209 |     summary="Evaluate generated text using the model",
210 |     response_description="Evaluation results",
211 | openapi_extra={
212 | "requestBody": {
213 | "content": {
214 | "application/json": {
215 | "examples": {
216 | "example1": {
217 |                         "summary": "Evaluate the interpretability of generated text",
218 | "value": {
219 | "nlp_model_path": 'http://vectors.nlpl.eu/repository/20/180.zip',
220 | "nn_model_path": "ai-forever/rugpt3small_based_on_gpt2",
221 | "clusters": [
222 | {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
223 | {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'противовирусное'],
224 | 'top_k': 160},
225 | {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},
226 | {'name': 'Симптомы', 'centroid': ['температура', 'насморк'], 'top_k': 20}
227 | ],
228 | "prompt": "я думаю что у моей кошки простуда, у нее температура, постоянный кашель: чем мне лечить мою кошку? ответ:",
229 | "generated_text": "На сегодняшний день существует специальное противовирусное лечение для кошек, так же можно применять антибиотики"
230 | }
231 | }
232 | }
233 | }
234 | }
235 | }
236 | }
237 | )
238 | async def evaluate(item: EvaluationItem):
239 |     """Evaluates the interpretability of the generated text with respect to the input prompt and the provided clusters"""
240 | data = await asyncio.to_thread(api_server_init.evaluate, item.nlp_model_path, item.nn_model_path, item.clusters,
241 | item.prompt, item.generated_text)
242 | return data
243 |
244 |
245 | if __name__ == "__main__":
246 | import uvicorn
247 |
248 | uvicorn.run(app, host="localhost", port=8000)
249 |
--------------------------------------------------------------------------------
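
A minimal client sketch for the service above, assuming the API is already running on localhost:8000 (per the `uvicorn.run` call) and that the `requests` package is installed; the endpoint paths and payloads mirror the OpenAPI examples embedded in `api_service_main.py`:

```python
import requests

BASE_URL = "http://localhost:8000"  # matches uvicorn.run(app, host="localhost", port=8000)

# Load a small Russian Q&A dataset into the bot (payload taken from the OpenAPI example)
resp = requests.post(f"{BASE_URL}/load_dataset", json={
    "questions": ["Что такое коала?", "Опишите африканского слона"],
    "answers": ["Это вид медведей, обитающих в Австралии.",
                "Это крупное млекопитающее с длинным хоботом."],
})
print(resp.json())  # expected: {"result": true}

# Ask a question against the loaded dataset
resp = requests.post(f"{BASE_URL}/get_answer",
                     json={"question": "Что за животное коала?", "top_k": 1})
print(resp.json())  # {"result": [...]}
```
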
/API/server_start.bat:
--------------------------------------------------------------------------------
1 | title API_EXPLAINITALL_PYTHON
2 | uvicorn api_service_main:app --port 8000 --reload
--------------------------------------------------------------------------------
/API/server_start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "API_EXPLAINITALL_PYTHON"
3 | uvicorn api_service_main:app --port 8000 --reload
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [PyPI](https://pypi.org/project/explainitall/)
2 | 
3 | # **ExplainitAll**
4 | 
5 | **ExplainitAll** is a library for interpretable AI, intended for interpreting generative (GPT-like)
6 | models and vectorizers such as Sbert. The library gives users tools for analyzing and understanding how
7 | these complex models work. It also contains RAG QA and fast_tuning modules and a graphical user interface.
8 | 
9 | ---
10 | 
11 | * [Usage examples](https://github.com/Bots-Avatar/ExplainitAll/tree/main/examples)
12 | * [Library source code](https://github.com/Bots-Avatar/ExplainitAll/tree/main/explainitall)
13 | * [Documentation](https://github.com/Bots-Avatar/ExplainitAll/wiki)
14 | 
15 | ## Models:
16 | 
17 | * Distilled [Sbert](https://huggingface.co/FractalGPT/SbertDistil)
18 | * Distilled [Sbert](https://huggingface.co/FractalGPT/SbertSVDDistil) with SVD decomposition, for faster
19 |   inference and training
20 | * [FRED T5](https://huggingface.co/FractalGPT/FRED-T5-Interp), fine-tuned for a RAG task: answering questions
21 |   about the interpretation of a generative GPT-like network
22 | * [FRED T5](https://huggingface.co/FractalGPT/FredT5-Large-Instruct-Context), a small T5 fine-tuned for context-aware instruct tasks
23 | ---
24 | 
25 | ## Applied use cases:
26 | 
27 | The results can be applied in any question-answering systems or classifiers for critical industries
28 | (medicine, construction, aerospace, law, etc.). A typical usage scenario, for medicine for example, is this:
29 | the developer of an end product, such as a system for finding drug contraindications, builds a set of domain
30 | clusters in close cooperation with the customer (a doctor, a clinic, etc.), fine-tunes a transformer model
31 | (GPT-like: e.g. from the ruGPT3 and GPT2 families) on question-answer texts, and then, while operating this
32 | finished question-answering system, plugs in the ExplainitAll library so that it provides an analytical
33 | estimate of how "reliable" and trustworthy the system's answers are, based on the interpretation result:
34 | whether the system really attended to the industry-critical clusters when answering the user's questions.
35 | 
36 | The library can be adapted as a module of an end product: an assistant for a doctor, a design engineer,
37 | a lawyer, or an accountant. For the public sector the library is useful because it helps build trust in
38 | RAG systems when they answer questions about taxes, procurement regulations, user manuals for information
39 | systems, and regulatory acts. For industrial enterprises the library is applicable to working with
40 | regulations and with operation and maintenance manuals for complex technical equipment, since it makes it
41 | possible to assess whether the answers of QA systems account for special, industry-critical abbreviations,
42 | names, acronyms, and nomenclature designations.
43 | 
44 | ## Specifications:
45 | 
46 | * Operating system: Ubuntu 22.04.3 LTS
47 | * Driver: NVIDIA version 535.104.05
48 | * CUDA version 12.2
49 | * Python 3.10.12
50 | 
51 | 
52 | * CPU: AMD Ryzen 3 3200G OEM (clock: 3600 MHz, cores: 4)
53 | * RAM: 16 GB
54 | 
55 | * GPU
56 |   * Model: NVIDIA TU104GL [Tesla T4]
57 |   * VRAM: 16 GB
58 | 
59 | ## License
60 | 
61 | * License: [Apache-2.0 license](https://github.com/Bots-Avatar/ExplainitAll/tree/main#Apache-2.0-1-ov-file)
62 | 
--------------------------------------------------------------------------------
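
The scenario the README describes (build domain clusters, fine-tune a GPT-like model, then check which clusters its answers attend to) can be sketched end to end with the library's public pieces. This is a condensed, non-authoritative rearrangement of `examples/KNN_RAG_interp.ipynb` and the API service code; the embedding archive, model id, and cluster values are the same illustrative ones used there:

```python
import gensim
from inseq import load_model

from explainitall.gpt_like_interp import interp
from explainitall.gpt_like_interp.downloader import DownloadManager

# Word2vec embeddings (Russian National Corpus) used to expand the thematic clusters
nlp_model_path = DownloadManager.load_zip('http://vectors.nlpl.eu/repository/20/180.zip')
nlp_model = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)

# GPT-like model wrapped by inseq for integrated-gradients attribution
gpt_model = load_model(model="ai-forever/rugpt3small_based_on_gpt2",
                       attribution_method="integrated_gradients")

clusters = [
    {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
    {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},
]

explainer = interp.ExplainerGPT2(gpt_model=gpt_model, nlp_model=nlp_model)
expl_data = explainer.interpret(
    input_texts='у кошки грипп: чем лечить кошку? ответ:',
    generated_texts='лечите ее уколами',
    clusters_description=clusters,
    batch_size=50,
    steps=34,
)
print(expl_data.cluster_imp_aggr_df)  # aggregated cluster-importance table
```
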
/examples/GUI_TEST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "26b6dccfbfdf3023",
7 | "metadata": {
8 | "ExecuteTime": {
9 | "end_time": "2024-03-31T23:35:27.783149Z",
10 | "start_time": "2024-03-31T23:35:27.780481Z"
11 | },
12 | "collapsed": false,
13 | "jupyter": {
14 | "outputs_hidden": false
15 | }
16 | },
17 | "outputs": [
18 | {
19 | "name": "stdout",
20 | "output_type": "stream",
21 | "text": [
22 | "\u001b[33mWARNING: typer 0.12.0 does not provide the extra 'all'\u001b[0m\u001b[33m\n",
23 | "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
24 | "\u001b[0m"
25 | ]
26 | }
27 | ],
28 | "source": [
29 | "!pip install git+https://github.com/Bots-Avatar/ExplainitAll -q"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "id": "eEJyA-XRmLUx",
36 | "metadata": {
37 | "ExecuteTime": {
38 | "end_time": "2024-04-05T18:22:43.620616Z",
39 | "start_time": "2024-04-05T18:22:43.613855Z"
40 | },
41 | "id": "eEJyA-XRmLUx"
42 | },
43 | "outputs": [],
44 | "source": [
45 | "import os\n",
46 | "# os.environ['TEST_MODE_ON_LOW_SPEC_PC'] = 'True'"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "id": "eb3267fe",
53 | "metadata": {
54 | "ExecuteTime": {
55 | "end_time": "2024-04-05T18:22:48.980250Z",
56 | "start_time": "2024-04-05T18:22:45.467270Z"
57 | },
58 | "id": "eb3267fe"
59 | },
60 | "outputs": [
61 | {
62 | "name": "stderr",
63 | "output_type": "stream",
64 | "text": [
65 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
66 | "[nltk_data] Package punkt is already up-to-date!\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "from explainitall.gui.interface import DemoInterface\n",
72 | "from explainitall.gui.interface import set_verbosity_error\n",
73 | "set_verbosity_error()"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 4,
79 | "id": "a63b1664",
80 | "metadata": {
81 | "ExecuteTime": {
82 | "end_time": "2024-04-05T18:22:59.430666Z",
83 | "start_time": "2024-04-05T18:22:50.245045Z"
84 | },
85 | "id": "a63b1664"
86 | },
87 | "outputs": [
88 | {
89 | "data": {
90 | "application/vnd.jupyter.widget-view+json": {
91 | "model_id": "bb93946321c54dcb881a78a165fb0457",
92 | "version_major": 2,
93 | "version_minor": 0
94 | },
95 | "text/plain": [
96 | "modules.json: 0%| | 0.00/341 [00:00, ?B/s]"
97 | ]
98 | },
99 | "metadata": {},
100 | "output_type": "display_data"
101 | },
102 | {
103 | "data": {
104 | "application/vnd.jupyter.widget-view+json": {
105 | "model_id": "47dfe2c2b5b34bd9b167e1101f9a445d",
106 | "version_major": 2,
107 | "version_minor": 0
108 | },
109 | "text/plain": [
110 | "config_sentence_transformers.json: 0%| | 0.00/116 [00:00, ?B/s]"
111 | ]
112 | },
113 | "metadata": {},
114 | "output_type": "display_data"
115 | },
116 | {
117 | "data": {
118 | "application/vnd.jupyter.widget-view+json": {
119 | "model_id": "d897dcc7adc54d65bfbfd7fd2e7afa8a",
120 | "version_major": 2,
121 | "version_minor": 0
122 | },
123 | "text/plain": [
124 | "README.md: 0%| | 0.00/6.24k [00:00, ?B/s]"
125 | ]
126 | },
127 | "metadata": {},
128 | "output_type": "display_data"
129 | },
130 | {
131 | "data": {
132 | "application/vnd.jupyter.widget-view+json": {
133 | "model_id": "2c81cccde0c94e70b2e3d61f66189558",
134 | "version_major": 2,
135 | "version_minor": 0
136 | },
137 | "text/plain": [
138 | "sentence_bert_config.json: 0%| | 0.00/53.0 [00:00, ?B/s]"
139 | ]
140 | },
141 | "metadata": {},
142 | "output_type": "display_data"
143 | },
144 | {
145 | "data": {
146 | "application/vnd.jupyter.widget-view+json": {
147 | "model_id": "0529fe94ef1546129408a4e2949cf0ab",
148 | "version_major": 2,
149 | "version_minor": 0
150 | },
151 | "text/plain": [
152 | "config.json: 0%| | 0.00/685 [00:00, ?B/s]"
153 | ]
154 | },
155 | "metadata": {},
156 | "output_type": "display_data"
157 | },
158 | {
159 | "data": {
160 | "application/vnd.jupyter.widget-view+json": {
161 | "model_id": "474a7812f4e047ec99e047d566c1ecd3",
162 | "version_major": 2,
163 | "version_minor": 0
164 | },
165 | "text/plain": [
166 | "model.safetensors: 0%| | 0.00/44.2M [00:00, ?B/s]"
167 | ]
168 | },
169 | "metadata": {},
170 | "output_type": "display_data"
171 | },
172 | {
173 | "data": {
174 | "application/vnd.jupyter.widget-view+json": {
175 | "model_id": "b2a14fe582f84edea4e4c1a115499f62",
176 | "version_major": 2,
177 | "version_minor": 0
178 | },
179 | "text/plain": [
180 | "tokenizer_config.json: 0%| | 0.00/1.43k [00:00, ?B/s]"
181 | ]
182 | },
183 | "metadata": {},
184 | "output_type": "display_data"
185 | },
186 | {
187 | "data": {
188 | "application/vnd.jupyter.widget-view+json": {
189 | "model_id": "796a2707d1c04b389f94794d7520e6a1",
190 | "version_major": 2,
191 | "version_minor": 0
192 | },
193 | "text/plain": [
194 | "vocab.txt: 0%| | 0.00/241k [00:00, ?B/s]"
195 | ]
196 | },
197 | "metadata": {},
198 | "output_type": "display_data"
199 | },
200 | {
201 | "data": {
202 | "application/vnd.jupyter.widget-view+json": {
203 | "model_id": "6d988421c1754a66aa8ac3166e16f24f",
204 | "version_major": 2,
205 | "version_minor": 0
206 | },
207 | "text/plain": [
208 | "tokenizer.json: 0%| | 0.00/706k [00:00, ?B/s]"
209 | ]
210 | },
211 | "metadata": {},
212 | "output_type": "display_data"
213 | },
214 | {
215 | "data": {
216 | "application/vnd.jupyter.widget-view+json": {
217 | "model_id": "6dfa1938ee51431ab620d6da918995a8",
218 | "version_major": 2,
219 | "version_minor": 0
220 | },
221 | "text/plain": [
222 | "special_tokens_map.json: 0%| | 0.00/695 [00:00, ?B/s]"
223 | ]
224 | },
225 | "metadata": {},
226 | "output_type": "display_data"
227 | },
228 | {
229 | "data": {
230 | "application/vnd.jupyter.widget-view+json": {
231 | "model_id": "e61fffdeeaf145eaa3990ddd9880af4f",
232 | "version_major": 2,
233 | "version_minor": 0
234 | },
235 | "text/plain": [
236 | "1_Pooling/config.json: 0%| | 0.00/190 [00:00, ?B/s]"
237 | ]
238 | },
239 | "metadata": {},
240 | "output_type": "display_data"
241 | },
242 | {
243 | "data": {
244 | "application/vnd.jupyter.widget-view+json": {
245 | "model_id": "8142abc640e44f71baba33124e5628a5",
246 | "version_major": 2,
247 | "version_minor": 0
248 | },
249 | "text/plain": [
250 | "2_Dense/config.json: 0%| | 0.00/114 [00:00, ?B/s]"
251 | ]
252 | },
253 | "metadata": {},
254 | "output_type": "display_data"
255 | },
256 | {
257 | "data": {
258 | "application/vnd.jupyter.widget-view+json": {
259 | "model_id": "38174a4d726f4159a03c537444db18dc",
260 | "version_major": 2,
261 | "version_minor": 0
262 | },
263 | "text/plain": [
264 | "2_Dense/pytorch_model.bin: 0%| | 0.00/482k [00:00, ?B/s]"
265 | ]
266 | },
267 | "metadata": {},
268 | "output_type": "display_data"
269 | },
270 | {
271 | "data": {
272 | "application/vnd.jupyter.widget-view+json": {
273 | "model_id": "55ced2284704490fbe6516f7b1d4fcb3",
274 | "version_major": 2,
275 | "version_minor": 0
276 | },
277 | "text/plain": [
278 | "tokenizer_config.json: 0%| | 0.00/20.1k [00:00, ?B/s]"
279 | ]
280 | },
281 | "metadata": {},
282 | "output_type": "display_data"
283 | },
284 | {
285 | "data": {
286 | "application/vnd.jupyter.widget-view+json": {
287 | "model_id": "1a52fb01b8434634949a34bd1f1858fb",
288 | "version_major": 2,
289 | "version_minor": 0
290 | },
291 | "text/plain": [
292 | "vocab.json: 0%| | 0.00/1.61M [00:00, ?B/s]"
293 | ]
294 | },
295 | "metadata": {},
296 | "output_type": "display_data"
297 | },
298 | {
299 | "data": {
300 | "application/vnd.jupyter.widget-view+json": {
301 | "model_id": "4ccd4f98317445c19d84d3be3f8c066e",
302 | "version_major": 2,
303 | "version_minor": 0
304 | },
305 | "text/plain": [
306 | "merges.txt: 0%| | 0.00/1.27M [00:00, ?B/s]"
307 | ]
308 | },
309 | "metadata": {},
310 | "output_type": "display_data"
311 | },
312 | {
313 | "data": {
314 | "application/vnd.jupyter.widget-view+json": {
315 | "model_id": "ca4bb386acb94b73bd82f9c2505eb3da",
316 | "version_major": 2,
317 | "version_minor": 0
318 | },
319 | "text/plain": [
320 | "tokenizer.json: 0%| | 0.00/3.76M [00:00, ?B/s]"
321 | ]
322 | },
323 | "metadata": {},
324 | "output_type": "display_data"
325 | },
326 | {
327 | "data": {
328 | "application/vnd.jupyter.widget-view+json": {
329 | "model_id": "8a000c7eb1d64a959475498373a84e74",
330 | "version_major": 2,
331 | "version_minor": 0
332 | },
333 | "text/plain": [
334 | "added_tokens.json: 0%| | 0.00/2.74k [00:00, ?B/s]"
335 | ]
336 | },
337 | "metadata": {},
338 | "output_type": "display_data"
339 | },
340 | {
341 | "data": {
342 | "application/vnd.jupyter.widget-view+json": {
343 | "model_id": "32e44969314249ea8238f386c33ef492",
344 | "version_major": 2,
345 | "version_minor": 0
346 | },
347 | "text/plain": [
348 | "special_tokens_map.json: 0%| | 0.00/217 [00:00, ?B/s]"
349 | ]
350 | },
351 | "metadata": {},
352 | "output_type": "display_data"
353 | },
354 | {
355 | "data": {
356 | "application/vnd.jupyter.widget-view+json": {
357 | "model_id": "76c9ce0a77d94cb0ac6ebccaf3b8f353",
358 | "version_major": 2,
359 | "version_minor": 0
360 | },
361 | "text/plain": [
362 | "config.json: 0%| | 0.00/846 [00:00, ?B/s]"
363 | ]
364 | },
365 | "metadata": {},
366 | "output_type": "display_data"
367 | },
368 | {
369 | "data": {
370 | "application/vnd.jupyter.widget-view+json": {
371 | "model_id": "d2afca77bb7f4e2aa1325cd16a685610",
372 | "version_major": 2,
373 | "version_minor": 0
374 | },
375 | "text/plain": [
376 | "model.safetensors.index.json: 0%| | 0.00/50.6k [00:00, ?B/s]"
377 | ]
378 | },
379 | "metadata": {},
380 | "output_type": "display_data"
381 | },
382 | {
383 | "data": {
384 | "application/vnd.jupyter.widget-view+json": {
385 | "model_id": "f769ce39c5ae46699c621ad1fc68934d",
386 | "version_major": 2,
387 | "version_minor": 0
388 | },
389 | "text/plain": [
390 | "Downloading shards: 0%| | 0/2 [00:00, ?it/s]"
391 | ]
392 | },
393 | "metadata": {},
394 | "output_type": "display_data"
395 | },
396 | {
397 | "data": {
398 | "application/vnd.jupyter.widget-view+json": {
399 | "model_id": "5b9d8183a922446c8e18a4fc9cce178f",
400 | "version_major": 2,
401 | "version_minor": 0
402 | },
403 | "text/plain": [
404 | "model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00, ?B/s]"
405 | ]
406 | },
407 | "metadata": {},
408 | "output_type": "display_data"
409 | },
410 | {
411 | "data": {
412 | "application/vnd.jupyter.widget-view+json": {
413 | "model_id": "68d5da5084fe42239305c56794fc3942",
414 | "version_major": 2,
415 | "version_minor": 0
416 | },
417 | "text/plain": [
418 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
419 | ]
420 | },
421 | "metadata": {},
422 | "output_type": "display_data"
423 | }
424 | ],
425 | "source": [
426 | "interface = DemoInterface()"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 5,
432 | "id": "b8cfafe0",
433 | "metadata": {
434 | "ExecuteTime": {
435 | "end_time": "2024-03-31T23:35:38.643995Z",
436 | "start_time": "2024-03-31T23:35:38.554009Z"
437 | },
438 | "colab": {
439 | "base_uri": "https://localhost:8080/",
440 | "height": 626
441 | },
442 | "id": "b8cfafe0",
443 | "outputId": "58dc6e55-4f15-4966-97f9-d54a2eb98237"
444 | },
445 | "outputs": [
446 | {
447 | "name": "stdout",
448 | "output_type": "stream",
449 | "text": [
450 | "Running on local URL: http://127.0.0.1:7860\n",
451 | "Running on public URL: https://fcf2d7aa43b89548cd.gradio.live\n",
452 | "\n",
453 | "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
454 | ]
455 | },
456 | {
457 | "data": {
458 | "text/html": [
459 |         ""
460 | ],
461 | "text/plain": [
462 | ""
463 | ]
464 | },
465 | "metadata": {},
466 | "output_type": "display_data"
467 | }
468 | ],
469 | "source": [
470 | "interface.launch()"
471 | ]
472 | },
473 | {
474 | "cell_type": "markdown",
475 | "id": "b6sQtIG7rK7t",
476 | "metadata": {
477 | "id": "b6sQtIG7rK7t"
478 | },
479 | "source": [
480 | "interface.launch()\n",
481 | "### То же самое можно сделать это кодом (ниже пример)"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": 9,
487 | "id": "f1794464",
488 | "metadata": {
489 | "ExecuteTime": {
490 | "end_time": "2024-03-31T23:37:08.072609Z",
491 | "start_time": "2024-03-31T23:37:08.068446Z"
492 | },
493 | "id": "f1794464"
494 | },
495 | "outputs": [],
496 | "source": [
497 | "interface.context_ = 'у кошки грипп и аллергия на антибиотбиотики вопрос: чем лечить кошку? ответ:'\n",
498 | "interface.generated_text_ = 'лечите ее уколами'\n",
499 | "\n",
500 | "interface.clusters_ = [\n",
501 | " {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},\n",
502 | " {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},\n",
503 | " {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},\n",
504 | " {'name': 'Аллергия', 'centroid': ['аллергия'], 'top_k': 20}\n",
505 | "]"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 10,
511 | "id": "1d12205d",
512 | "metadata": {
513 | "ExecuteTime": {
514 | "end_time": "2024-03-31T23:37:09.287562Z",
515 | "start_time": "2024-03-31T23:37:08.409795Z"
516 | },
517 | "id": "1d12205d"
518 | },
519 | "outputs": [
520 | {
521 | "data": {
522 | "text/plain": [
523 | "(True, True)"
524 | ]
525 | },
526 | "execution_count": 10,
527 | "metadata": {},
528 | "output_type": "execute_result"
529 | }
530 | ],
531 | "source": [
532 | "interface.load_nlp_model_('http://vectors.nlpl.eu/repository/20/180.zip')"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 11,
538 | "id": "3fe69def",
539 | "metadata": {
540 | "ExecuteTime": {
541 | "end_time": "2024-03-31T23:37:12.280943Z",
542 | "start_time": "2024-03-31T23:37:09.288664Z"
543 | },
544 | "id": "3fe69def"
545 | },
546 | "outputs": [
547 | {
548 | "data": {
549 | "text/plain": [
550 | "(True, True)"
551 | ]
552 | },
553 | "execution_count": 11,
554 | "metadata": {},
555 | "output_type": "execute_result"
556 | }
557 | ],
558 | "source": [
559 | "interface.load_nn_model_(\"sberbank-ai/rugpt3small_based_on_gpt2\")"
560 | ]
561 | }
562 | ],
563 | "metadata": {
564 | "colab": {
565 | "provenance": []
566 | },
567 | "kernelspec": {
568 | "display_name": "Python 3 (ipykernel)",
569 | "language": "python",
570 | "name": "python3"
571 | },
572 | "language_info": {
573 | "codemirror_mode": {
574 | "name": "ipython",
575 | "version": 3
576 | },
577 | "file_extension": ".py",
578 | "mimetype": "text/x-python",
579 | "name": "python",
580 | "nbconvert_exporter": "python",
581 | "pygments_lexer": "ipython3",
582 | "version": "3.10.13"
583 | }
584 | },
585 | "nbformat": 4,
586 | "nbformat_minor": 5
587 | }
588 |
--------------------------------------------------------------------------------
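
For reference, here is the GUI flow from the notebook above collapsed into a plain script. The attribute and method names (`context_`, `generated_text_`, `clusters_`, `load_nlp_model_`, `load_nn_model_`, `launch`) come verbatim from the cells; everything else is illustrative, so treat this as a sketch rather than a documented API:

```python
from explainitall.gui.interface import DemoInterface, set_verbosity_error

set_verbosity_error()        # silence library chatter, as in the notebook
interface = DemoInterface()  # downloads the default Sbert and FRED-T5 models on first run

# The same inputs the notebook sets before interpretation
interface.context_ = 'у кошки грипп и аллергия на антибиотики вопрос: чем лечить кошку? ответ:'
interface.generated_text_ = 'лечите ее уколами'
interface.clusters_ = [
    {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
    {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},
]

interface.load_nlp_model_('http://vectors.nlpl.eu/repository/20/180.zip')  # word2vec archive
interface.load_nn_model_('sberbank-ai/rugpt3small_based_on_gpt2')          # GPT-like model

interface.launch()  # serves the Gradio demo locally (http://127.0.0.1:7860)
```
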
/examples/KNN_RAG_interp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2024-03-31T23:37:41.442738Z",
9 | "start_time": "2024-03-31T23:37:41.440452Z"
10 | },
11 | "collapsed": false,
12 | "jupyter": {
13 | "outputs_hidden": false
14 | }
15 | },
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "\u001b[33mWARNING: typer 0.12.0 does not provide the extra 'all'\u001b[0m\u001b[33m\n",
22 | "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
23 | "\u001b[0m"
24 | ]
25 | }
26 | ],
27 | "source": [
28 | "!pip install git+https://github.com/Bots-Avatar/ExplainitAll -q"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 10,
34 | "metadata": {
35 | "ExecuteTime": {
36 | "end_time": "2024-03-31T23:37:38.390677Z",
37 | "start_time": "2024-03-31T23:37:38.384997Z"
38 | },
39 | "id": "aXqbhzKyh3I7"
40 | },
41 | "outputs": [],
42 | "source": [
43 | "import os\n",
44 | "# os.environ['TEST_MODE_ON_LOW_SPEC_PC'] = 'True'"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 1,
50 | "metadata": {
51 | "ExecuteTime": {
52 | "end_time": "2024-04-05T18:27:05.551805Z",
53 | "start_time": "2024-04-05T18:27:00.941860Z"
54 | },
55 | "id": "aba3N6uRpmy0"
56 | },
57 | "outputs": [
58 | {
59 | "name": "stderr",
60 | "output_type": "stream",
61 | "text": [
62 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
63 | "[nltk_data] Package punkt is already up-to-date!\n"
64 | ]
65 | }
66 | ],
67 | "source": [
68 | "from explainitall.QA.interp_qa.KNNWithGenerative import FredStruct, PromptBot\n",
69 | "from explainitall.QA.extractive_qa_sbert.SVDBert import SVDBertModel\n",
70 | "from explainitall.QA.extractive_qa_sbert.QABotsBase import cos_dist\n",
71 | "from explainitall.gpt_like_interp.downloader import DownloadManager\n",
72 | "\n",
73 | "from sklearn.neighbors import KNeighborsClassifier\n",
74 | "from sentence_transformers import SentenceTransformer\n",
75 | "import gensim\n",
76 | "from inseq import load_model\n",
77 | "from explainitall.gpt_like_interp import interp\n",
78 | "from explainitall.gui.interface import set_verbosity_error\n",
79 | "set_verbosity_error()"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 2,
85 | "metadata": {
86 | "ExecuteTime": {
87 | "end_time": "2024-04-05T18:28:06.598120Z",
88 | "start_time": "2024-04-05T18:28:05.680831Z"
89 | },
90 | "id": "IPrg310VtVNy"
91 | },
92 | "outputs": [
93 | {
94 | "name": "stderr",
95 | "output_type": "stream",
96 | "text": [
97 | "Downloading: /root/.cache/180_zip: 100%|██████████| 462M/462M [00:26<00:00, 18.0MiB/s] \n",
98 | "Extracting: /root/.cache/180_zip_data: 100%|██████████| 4/4 [00:05<00:00, 1.26s/it]\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "def load_nlp_model(nlp_model_url):\n",
104 | " nlp_model_path = DownloadManager.load_zip(nlp_model_url)\n",
105 | " return gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)\n",
106 | "\n",
107 |     "# 'ID': 180\n",
108 |     "# 'Vector size': 300\n",
109 |     "# 'Corpus': 'Russian National Corpus'\n",
110 |     "# 'Vocabulary size': 189193\n",
111 |     "# 'Algorithm': 'Gensim Continuous Bag-of-Words'\n",
112 |     "# 'Lemmatization': True\n",
113 |     "\n",
114 |     "nlp_model = load_nlp_model('http://vectors.nlpl.eu/repository/20/180.zip')"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 4,
120 | "metadata": {
121 | "ExecuteTime": {
122 | "end_time": "2024-04-05T18:28:07.239780Z",
123 | "start_time": "2024-04-05T18:28:07.237071Z"
124 | },
125 | "id": "Z_WDkdpKtcd8"
126 | },
127 | "outputs": [],
128 | "source": [
129 | "model_path = \"sberbank-ai/rugpt3small_based_on_gpt2\""
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 5,
135 | "metadata": {
136 | "ExecuteTime": {
137 | "end_time": "2024-04-05T18:28:12.419667Z",
138 | "start_time": "2024-04-05T18:28:08.710629Z"
139 | },
140 | "id": "E7jniFg2thv4"
141 | },
142 | "outputs": [
143 | {
144 | "data": {
145 | "application/vnd.jupyter.widget-view+json": {
146 | "model_id": "652f3a84abc04530b5e87140d7366f31",
147 | "version_major": 2,
148 | "version_minor": 0
149 | },
150 | "text/plain": [
151 | "config.json: 0%| | 0.00/720 [00:00, ?B/s]"
152 | ]
153 | },
154 | "metadata": {},
155 | "output_type": "display_data"
156 | },
157 | {
158 | "data": {
159 | "application/vnd.jupyter.widget-view+json": {
160 | "model_id": "2cd0fe2b22e9405b829fe0311d4997a8",
161 | "version_major": 2,
162 | "version_minor": 0
163 | },
164 | "text/plain": [
165 | "pytorch_model.bin: 0%| | 0.00/551M [00:00, ?B/s]"
166 | ]
167 | },
168 | "metadata": {},
169 | "output_type": "display_data"
170 | },
171 | {
172 | "data": {
173 | "application/vnd.jupyter.widget-view+json": {
174 | "model_id": "3cab3fff6c4441b284824c1af00a1f17",
175 | "version_major": 2,
176 | "version_minor": 0
177 | },
178 | "text/plain": [
179 | "tokenizer_config.json: 0%| | 0.00/1.25k [00:00, ?B/s]"
180 | ]
181 | },
182 | "metadata": {},
183 | "output_type": "display_data"
184 | },
185 | {
186 | "data": {
187 | "application/vnd.jupyter.widget-view+json": {
188 | "model_id": "d5de045bf14e4d45a9403bab7c711a2b",
189 | "version_major": 2,
190 | "version_minor": 0
191 | },
192 | "text/plain": [
193 | "vocab.json: 0%| | 0.00/1.71M [00:00, ?B/s]"
194 | ]
195 | },
196 | "metadata": {},
197 | "output_type": "display_data"
198 | },
199 | {
200 | "data": {
201 | "application/vnd.jupyter.widget-view+json": {
202 | "model_id": "c82bff84bea243c6b04cd4be6653aa1f",
203 | "version_major": 2,
204 | "version_minor": 0
205 | },
206 | "text/plain": [
207 | "merges.txt: 0%| | 0.00/1.27M [00:00, ?B/s]"
208 | ]
209 | },
210 | "metadata": {},
211 | "output_type": "display_data"
212 | },
213 | {
214 | "data": {
215 | "application/vnd.jupyter.widget-view+json": {
216 | "model_id": "4ce80a25f9574e3bb6067190d56013ae",
217 | "version_major": 2,
218 | "version_minor": 0
219 | },
220 | "text/plain": [
221 | "special_tokens_map.json: 0%| | 0.00/574 [00:00, ?B/s]"
222 | ]
223 | },
224 | "metadata": {},
225 | "output_type": "display_data"
226 | }
227 | ],
228 | "source": [
229 | "def load_gpt_model(gpt_model_name):\n",
230 | " return load_model(model=gpt_model_name,\n",
231 | " attribution_method=\"integrated_gradients\")\n",
232 | "\n",
233 |     "# 'Framework': 'transformers'\n",
234 |     "# 'Training tokens': '80 billion'\n",
235 |     "# 'Context size': 2048\n",
236 | "\n",
237 | "gpt_model = load_gpt_model(model_path)"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 6,
243 | "metadata": {
244 | "ExecuteTime": {
245 | "end_time": "2024-04-05T18:28:17.901542Z",
246 | "start_time": "2024-04-05T18:28:17.894641Z"
247 | },
248 | "id": "RPz6kPJTtzND"
249 | },
250 | "outputs": [],
251 | "source": [
252 | "import re\n",
253 | "\n",
254 | "def clean_string(text):\n",
255 | " \"\"\"\n",
256 |     "    Clean up a string\n",
257 | " \"\"\"\n",
258 | " seq = text.replace('\\n',' ')\n",
259 |     "    r_char = re.compile('[^A-Za-zА-яЁё0-9\": ]')\n",
260 | " r_spaces = re.compile(r\"\\s+\")\n",
261 | " seq = r_char.sub(' ', seq)\n",
262 | " seq = r_spaces.sub(' ', seq).strip()\n",
263 | " return seq.lower()\n",
264 | "\n",
265 | "def value_interp(v):\n",
266 | " if str(v) == 'nan':\n",
267 | " return 'нулевой'\n",
268 | " if v < 0.1:\n",
269 | " return 'незначительной'\n",
270 | " if v < 0.3:\n",
271 | " return 'очень малой'\n",
272 | " if v < 0.45:\n",
273 | " return 'малой'\n",
274 | " if v < 0.65:\n",
275 | " return 'средней'\n",
276 | " if v < 0.85:\n",
277 | " return 'выше средней'\n",
278 | " else:\n",
279 | " return 'очень большой'\n",
280 | "\n",
281 | "def interp_cl(df):\n",
282 | " ret = []\n",
283 | " for index, row in df.iterrows():\n",
284 | " for num_col, col in enumerate(df.columns):\n",
285 | " if num_col != 0:\n",
286 | " value = row[col]\n",
287 | "\n",
288 | " description = f'Кластер \"{row[df.columns[0]]}\" влияет на генерацию кластера \"{col}\" с {value_interp(value)} силой.'\n",
289 | " ret += [description]\n",
290 | "\n",
291 | " return ret\n"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 7,
297 | "metadata": {
298 | "ExecuteTime": {
299 | "end_time": "2024-04-05T18:29:15.085325Z",
300 | "start_time": "2024-04-05T18:28:27.400858Z"
301 | },
302 | "colab": {
303 | "base_uri": "https://localhost:8080/"
304 | },
305 | "id": "ul803--vuCsS",
306 | "outputId": "249a5409-bab2-4107-c874-2df4ab4359d3"
307 | },
308 | "outputs": [
309 | {
310 | "name": "stderr",
311 | "output_type": "stream",
312 | "text": [
313 | "Attributing with integrated_gradients...: 100%|██████████| 26/26 [00:01<00:00, 3.64it/s]\n"
314 | ]
315 | }
316 | ],
317 | "source": [
318 | "clusters_discr = [\n",
319 | " {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},\n",
320 | " {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},\n",
321 | " {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},\n",
322 | " {'name': 'Аллергия', 'centroid': ['аллергия'], 'top_k': 20}\n",
323 | "]\n",
324 | "\n",
325 | "explainer = interp.ExplainerGPT2(gpt_model=gpt_model, nlp_model=nlp_model)\n",
326 | "\n",
327 | "\n",
328 | "input_text = 'у кошки грипп и аллергия на антибиотбиотики вопрос: чем лечить кошку? ответ:'\n",
329 | "generated_text = 'лечите ее уколами'\n",
330 | "\n",
331 | "expl_data = explainer.interpret(\n",
332 | " input_texts=input_text,\n",
333 | " generated_texts=generated_text,\n",
334 | " clusters_description=clusters_discr,\n",
335 | " batch_size=50,\n",
336 | " steps=34,\n",
337 | ")"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 8,
343 | "metadata": {
344 | "ExecuteTime": {
345 | "end_time": "2024-04-05T18:29:26.635880Z",
346 | "start_time": "2024-04-05T18:29:23.917315Z"
347 | },
348 | "id": "XxJyiEP9vVDv"
349 | },
350 | "outputs": [],
351 | "source": [
352 |     "# Interpretation result\n",
353 | "imp_df_cl = expl_data.cluster_imp_aggr_df\n",
354 | "cl_desc = interp_cl(imp_df_cl)"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": 11,
360 | "metadata": {
361 | "ExecuteTime": {
362 | "end_time": "2024-03-31T23:38:07.420913Z",
363 | "start_time": "2024-03-31T23:38:02.557049Z"
364 | },
365 | "id": "KoBhDgazwMBQ"
366 | },
367 | "outputs": [
368 | {
369 | "data": {
370 | "application/vnd.jupyter.widget-view+json": {
371 | "model_id": "090dad482d6540f7b1e009dc85020832",
372 | "version_major": 2,
373 | "version_minor": 0
374 | },
375 | "text/plain": [
376 | "tokenizer_config.json: 0%| | 0.00/20.1k [00:00, ?B/s]"
377 | ]
378 | },
379 | "metadata": {},
380 | "output_type": "display_data"
381 | },
382 | {
383 | "data": {
384 | "application/vnd.jupyter.widget-view+json": {
385 | "model_id": "064141bcb3584cb39ff590e5ce484570",
386 | "version_major": 2,
387 | "version_minor": 0
388 | },
389 | "text/plain": [
390 | "vocab.json: 0%| | 0.00/1.61M [00:00, ?B/s]"
391 | ]
392 | },
393 | "metadata": {},
394 | "output_type": "display_data"
395 | },
396 | {
397 | "data": {
398 | "application/vnd.jupyter.widget-view+json": {
399 | "model_id": "e873897b1ccd485b9d0876ca71195b93",
400 | "version_major": 2,
401 | "version_minor": 0
402 | },
403 | "text/plain": [
404 | "merges.txt: 0%| | 0.00/1.27M [00:00, ?B/s]"
405 | ]
406 | },
407 | "metadata": {},
408 | "output_type": "display_data"
409 | },
410 | {
411 | "data": {
412 | "application/vnd.jupyter.widget-view+json": {
413 | "model_id": "482cc40c4cfa4b89a23ba96da504b685",
414 | "version_major": 2,
415 | "version_minor": 0
416 | },
417 | "text/plain": [
418 | "tokenizer.json: 0%| | 0.00/3.76M [00:00, ?B/s]"
419 | ]
420 | },
421 | "metadata": {},
422 | "output_type": "display_data"
423 | },
424 | {
425 | "data": {
426 | "application/vnd.jupyter.widget-view+json": {
427 | "model_id": "cd049c6490bd409685597d5bc656d5d6",
428 | "version_major": 2,
429 | "version_minor": 0
430 | },
431 | "text/plain": [
432 | "added_tokens.json: 0%| | 0.00/2.74k [00:00, ?B/s]"
433 | ]
434 | },
435 | "metadata": {},
436 | "output_type": "display_data"
437 | },
438 | {
439 | "data": {
440 | "application/vnd.jupyter.widget-view+json": {
441 | "model_id": "441592f0ae92449c99cbcf7889e9a1f7",
442 | "version_major": 2,
443 | "version_minor": 0
444 | },
445 | "text/plain": [
446 | "special_tokens_map.json: 0%| | 0.00/217 [00:00, ?B/s]"
447 | ]
448 | },
449 | "metadata": {},
450 | "output_type": "display_data"
451 | },
452 | {
453 | "data": {
454 | "application/vnd.jupyter.widget-view+json": {
455 | "model_id": "4d50a3e9821842a89461e171ec11f631",
456 | "version_major": 2,
457 | "version_minor": 0
458 | },
459 | "text/plain": [
460 | "config.json: 0%| | 0.00/846 [00:00, ?B/s]"
461 | ]
462 | },
463 | "metadata": {},
464 | "output_type": "display_data"
465 | },
466 | {
467 | "data": {
468 | "application/vnd.jupyter.widget-view+json": {
469 | "model_id": "bd953b28340c4af4b1089c933da603ac",
470 | "version_major": 2,
471 | "version_minor": 0
472 | },
473 | "text/plain": [
474 | "model.safetensors.index.json: 0%| | 0.00/50.6k [00:00, ?B/s]"
475 | ]
476 | },
477 | "metadata": {},
478 | "output_type": "display_data"
479 | },
480 | {
481 | "data": {
482 | "application/vnd.jupyter.widget-view+json": {
483 | "model_id": "e6bcc4399d0b40b3ab48b7bd5a697ef9",
484 | "version_major": 2,
485 | "version_minor": 0
486 | },
487 | "text/plain": [
488 | "Downloading shards: 0%| | 0/2 [00:00, ?it/s]"
489 | ]
490 | },
491 | "metadata": {},
492 | "output_type": "display_data"
493 | },
494 | {
495 | "data": {
496 | "application/vnd.jupyter.widget-view+json": {
497 | "model_id": "87fff118497f489cab05454efccf172c",
498 | "version_major": 2,
499 | "version_minor": 0
500 | },
501 | "text/plain": [
502 | "model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00, ?B/s]"
503 | ]
504 | },
505 | "metadata": {},
506 | "output_type": "display_data"
507 | },
508 | {
509 | "data": {
510 | "application/vnd.jupyter.widget-view+json": {
511 | "model_id": "0ea68b0773e84198bbde13968794c43b",
512 | "version_major": 2,
513 | "version_minor": 0
514 | },
515 | "text/plain": [
516 | "model-00002-of-00002.safetensors: 0%| | 0.00/1.97G [00:00, ?B/s]"
517 | ]
518 | },
519 | "metadata": {},
520 | "output_type": "display_data"
521 | },
522 | {
523 | "data": {
524 | "application/vnd.jupyter.widget-view+json": {
525 | "model_id": "dfefff8bbf694201a35ba66fcaca5314",
526 | "version_major": 2,
527 | "version_minor": 0
528 | },
529 | "text/plain": [
530 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
531 | ]
532 | },
533 | "metadata": {},
534 | "output_type": "display_data"
535 | },
536 | {
537 | "data": {
538 | "application/vnd.jupyter.widget-view+json": {
539 | "model_id": "175e575e657f4a4fa9f68c94d3812b96",
540 | "version_major": 2,
541 | "version_minor": 0
542 | },
543 | "text/plain": [
544 | "generation_config.json: 0%| | 0.00/142 [00:00, ?B/s]"
545 | ]
546 | },
547 | "metadata": {},
548 | "output_type": "display_data"
549 | }
550 | ],
551 | "source": [
552 | "path_sbert = 'FractalGPT/SbertSVDDistil'\n",
553 | "sbert = SentenceTransformer(path_sbert)\n",
554 | "sbert[0].auto_model = SVDBertModel.from_pretrained(path_sbert)\n",
555 | "\n",
556 | "if os.getenv('TEST_MODE_ON_LOW_SPEC_PC') == 'True':\n",
557 | " fred = FredStruct('t5-small' )\n",
558 | "else:\n",
559 | " fred = FredStruct('FractalGPT/FRED-T5-Interp')"
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": 12,
565 | "metadata": {
566 | "ExecuteTime": {
567 | "end_time": "2024-03-31T23:38:09.020464Z",
568 | "start_time": "2024-03-31T23:38:09.013877Z"
569 | },
570 | "colab": {
571 | "base_uri": "https://localhost:8080/"
572 | },
573 | "id": "iaZDK_iWwxLL",
574 | "outputId": "032a1eff-43e8-4ce6-a0fe-f49d749e43a2"
575 | },
576 | "outputs": [
577 | {
578 | "data": {
579 | "text/plain": [
580 | "['Кластер \"Аллергия\" влияет на генерацию кластера \"Лекарства\" с выше средней силой.',\n",
581 | " 'Кластер \"Болезни\" влияет на генерацию кластера \"Лекарства\" с выше средней силой.',\n",
582 | " 'Кластер \"Животные\" влияет на генерацию кластера \"Лекарства\" с средней силой.',\n",
583 | " 'Кластер \"Лекарства\" влияет на генерацию кластера \"Лекарства\" с нулевой силой.']"
584 | ]
585 | },
586 | "execution_count": 12,
587 | "metadata": {},
588 | "output_type": "execute_result"
589 | }
590 | ],
591 | "source": [
592 | "cl_desc"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 13,
598 | "metadata": {
599 | "ExecuteTime": {
600 | "end_time": "2024-03-31T23:38:12.308493Z",
601 | "start_time": "2024-03-31T23:38:12.268185Z"
602 | },
603 | "id": "OT7tJxeVwpRX"
604 | },
605 | "outputs": [],
606 | "source": [
607 | "clean = [clean_string(cl_data) for cl_data in cl_desc]\n",
608 | "vects_x = sbert.encode(clean)\n",
609 | "m = vects_x.mean(axis=0)\n",
610 | "s = vects_x.std(axis=0)\n",
611 | "knn_vects_x = (vects_x - m)/s\n",
612 | "knn = KNeighborsClassifier(metric=cos_dist)\n",
613 | "knn.fit(knn_vects_x, cl_desc)\n",
614 | "\n",
615 | "interp_bot = PromptBot(knn, sbert, fred, cl_desc, device='cpu')"
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": 14,
621 | "metadata": {
622 | "ExecuteTime": {
623 | "end_time": "2024-03-31T23:38:17.352373Z",
624 | "start_time": "2024-03-31T23:38:13.512955Z"
625 | },
626 | "colab": {
627 | "base_uri": "https://localhost:8080/",
628 | "height": 36
629 | },
630 | "id": "kl0pEqje0qPy",
631 | "outputId": "ce8a27a2-6d37-4287-f9b9-284db74e3e4b"
632 | },
633 | "outputs": [
634 | {
635 | "data": {
636 | "text/plain": [
637 | "'Кластер \"Аллергия\" влияет на генерацию кластера \"Лекарства\" с выше средней силой'"
638 | ]
639 | },
640 | "execution_count": 14,
641 | "metadata": {},
642 | "output_type": "execute_result"
643 | }
644 | ],
645 | "source": [
646 | "ans = interp_bot.get_answers('Как влияет аллергия на назначение лекарства', top_k=3)\n",
647 | "ans.split('.')[0]"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": 15,
653 | "metadata": {
654 | "ExecuteTime": {
655 | "end_time": "2024-03-31T23:38:18.961225Z",
656 | "start_time": "2024-03-31T23:38:17.354105Z"
657 | },
658 | "colab": {
659 | "base_uri": "https://localhost:8080/",
660 | "height": 36
661 | },
662 | "id": "0CHTOM0V3FT6",
663 | "outputId": "d58af991-a2af-4a27-b9b6-8949197ed02e"
664 | },
665 | "outputs": [
666 | {
667 | "data": {
668 | "text/plain": [
669 | "'Кластер \"Болезни\" влияет на кластер \"Лекарства\" с выше средней силой Кластер \"Аллергия\" влияет кластером \"Лекаря\" с нулевой силой'"
670 | ]
671 | },
672 | "execution_count": 15,
673 | "metadata": {},
674 | "output_type": "execute_result"
675 | }
676 | ],
677 | "source": [
678 | "ans = interp_bot.get_answers('Как влияет кластер болезни на кластер лекарства', top_k=3)\n",
679 | "ans.split('.')[0]"
680 | ]
681 | }
682 | ],
683 | "metadata": {
684 | "accelerator": "GPU",
685 | "colab": {
686 | "gpuType": "T4",
687 | "provenance": []
688 | },
689 | "kernelspec": {
690 | "display_name": "Python 3 (ipykernel)",
691 | "language": "python",
692 | "name": "python3"
693 | },
694 | "language_info": {
695 | "codemirror_mode": {
696 | "name": "ipython",
697 | "version": 3
698 | },
699 | "file_extension": ".py",
700 | "mimetype": "text/x-python",
701 | "name": "python",
702 | "nbconvert_exporter": "python",
703 | "pygments_lexer": "ipython3",
704 | "version": "3.10.13"
705 | }
706 | },
707 | "nbformat": 4,
708 | "nbformat_minor": 4
709 | }
710 |
--------------------------------------------------------------------------------
/examples/Metric_calculation_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "2aa5a4c2eea7f6a5",
7 | "metadata": {
8 | "ExecuteTime": {
9 | "end_time": "2024-03-31T23:38:42.574605Z",
10 | "start_time": "2024-03-31T23:38:42.569784Z"
11 | },
12 | "collapsed": false,
13 | "jupyter": {
14 | "outputs_hidden": false
15 | }
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "# os.environ['TEST_MODE_ON_LOW_SPEC_PC'] = 'True'"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "id": "6e6be105",
27 | "metadata": {
28 | "ExecuteTime": {
29 | "end_time": "2024-03-31T23:38:42.578042Z",
30 | "start_time": "2024-03-31T23:38:42.575973Z"
31 | }
32 | },
33 | "outputs": [
34 | {
35 | "name": "stdout",
36 | "output_type": "stream",
37 | "text": [
38 | "\u001b[33mWARNING: typer 0.12.1 does not provide the extra 'all'\u001b[0m\u001b[33m\n",
39 | "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
40 | "\u001b[0m"
41 | ]
42 | }
43 | ],
44 | "source": [
45 | "!pip install git+https://github.com/Bots-Avatar/ExplainitAll -q"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 5,
51 | "id": "60df7ce2",
52 | "metadata": {
53 | "ExecuteTime": {
54 | "end_time": "2024-04-05T18:49:59.907480Z",
55 | "start_time": "2024-04-05T18:49:57.102190Z"
56 | }
57 | },
58 | "outputs": [],
59 | "source": [
60 | "from explainitall.metrics.RougeAndPPL.metric_calculation_interface import MetricCalculationInterface"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 6,
66 | "id": "ff329a8a",
67 | "metadata": {
68 | "ExecuteTime": {
69 | "end_time": "2024-04-05T18:50:00.973275Z",
70 | "start_time": "2024-04-05T18:50:00.794970Z"
71 | }
72 | },
73 | "outputs": [],
74 | "source": [
75 | "metric_calculation_interface = MetricCalculationInterface()"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 7,
81 | "id": "377e2671",
82 | "metadata": {
83 | "ExecuteTime": {
84 | "end_time": "2024-04-05T18:50:23.788943Z",
85 | "start_time": "2024-04-05T18:50:03.363796Z"
86 | },
87 | "scrolled": true
88 | },
89 | "outputs": [
90 | {
91 | "name": "stdout",
92 | "output_type": "stream",
93 | "text": [
94 | "Running on local URL: http://127.0.0.1:7861\n",
95 | "Running on public URL: https://5c4ca8ac6d53b93c20.gradio.live\n",
96 | "\n",
97 | "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
98 | ]
99 | },
100 | {
101 | "data": {
102 | "text/html": [
103 | ""
104 | ],
105 | "text/plain": [
106 | ""
107 | ]
108 | },
109 | "metadata": {},
110 | "output_type": "display_data"
111 | }
112 | ],
113 | "source": [
114 | "metric_calculation_interface.launch()"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 8,
120 | "id": "d6c66fd6",
121 | "metadata": {
122 | "ExecuteTime": {
123 | "end_time": "2024-03-31T23:39:37.429481Z",
124 | "start_time": "2024-03-31T23:39:35.850815Z"
125 | }
126 | },
127 | "outputs": [
128 | {
129 | "data": {
130 | "application/vnd.jupyter.widget-view+json": {
131 | "model_id": "2ce462cc87e240c5b51d09f4cb40108d",
132 | "version_major": 2,
133 | "version_minor": 0
134 | },
135 | "text/plain": [
136 | "tokenizer_config.json: 0%| | 0.00/26.0 [00:00, ?B/s]"
137 | ]
138 | },
139 | "metadata": {},
140 | "output_type": "display_data"
141 | },
142 | {
143 | "data": {
144 | "application/vnd.jupyter.widget-view+json": {
145 | "model_id": "fcd02212255f4df78a0cd143f8eb683c",
146 | "version_major": 2,
147 | "version_minor": 0
148 | },
149 | "text/plain": [
150 | "vocab.json: 0%| | 0.00/1.04M [00:00, ?B/s]"
151 | ]
152 | },
153 | "metadata": {},
154 | "output_type": "display_data"
155 | },
156 | {
157 | "data": {
158 | "application/vnd.jupyter.widget-view+json": {
159 | "model_id": "d8c0a2205c484ea1b6d0350f8ca2c35f",
160 | "version_major": 2,
161 | "version_minor": 0
162 | },
163 | "text/plain": [
164 | "merges.txt: 0%| | 0.00/456k [00:00, ?B/s]"
165 | ]
166 | },
167 | "metadata": {},
168 | "output_type": "display_data"
169 | },
170 | {
171 | "data": {
172 | "application/vnd.jupyter.widget-view+json": {
173 | "model_id": "d5059c9261854d3c9821926d6892b2a5",
174 | "version_major": 2,
175 | "version_minor": 0
176 | },
177 | "text/plain": [
178 | "tokenizer.json: 0%| | 0.00/1.36M [00:00, ?B/s]"
179 | ]
180 | },
181 | "metadata": {},
182 | "output_type": "display_data"
183 | },
184 | {
185 | "data": {
186 | "application/vnd.jupyter.widget-view+json": {
187 | "model_id": "8a8a68da87174bf3b1333c3191ccb352",
188 | "version_major": 2,
189 | "version_minor": 0
190 | },
191 | "text/plain": [
192 | "config.json: 0%| | 0.00/762 [00:00, ?B/s]"
193 | ]
194 | },
195 | "metadata": {},
196 | "output_type": "display_data"
197 | },
198 | {
199 | "data": {
200 | "application/vnd.jupyter.widget-view+json": {
201 | "model_id": "1c39eb13bf4a40ec827a1d594cbfaf48",
202 | "version_major": 2,
203 | "version_minor": 0
204 | },
205 | "text/plain": [
206 | "model.safetensors: 0%| | 0.00/353M [00:00, ?B/s]"
207 | ]
208 | },
209 | "metadata": {},
210 | "output_type": "display_data"
211 | },
212 | {
213 | "data": {
214 | "application/vnd.jupyter.widget-view+json": {
215 | "model_id": "e3c2c88a3dbd4c6b96f0111651eeea43",
216 | "version_major": 2,
217 | "version_minor": 0
218 | },
219 | "text/plain": [
220 | "generation_config.json: 0%| | 0.00/124 [00:00, ?B/s]"
221 | ]
222 | },
223 | "metadata": {},
224 | "output_type": "display_data"
225 | },
226 | {
227 | "data": {
228 | "text/plain": [
229 | "True"
230 | ]
231 | },
232 | "execution_count": 8,
233 | "metadata": {},
234 | "output_type": "execute_result"
235 | },
236 | {
237 | "name": "stdout",
238 | "output_type": "stream",
239 | "text": [
240 | "csvfile /tmp/gradio/a22c5fe5805096466e4b24533c25eb8077375399/metrix.csv\n",
241 | "title Dataset_metrix\n",
242 | "2024-04-05 20:45:18,022 INFO sqlalchemy.engine.Engine BEGIN (implicit)\n",
243 | "2024-04-05 20:45:18,025 INFO sqlalchemy.engine.Engine SELECT max(data.dataset_version) AS max_1 \n",
244 | "FROM data \n",
245 | "WHERE data.dataset_name = ?\n",
246 | "2024-04-05 20:45:18,026 INFO sqlalchemy.engine.Engine [generated in 0.00389s] ('Dataset_metrix',)\n"
247 | ]
248 | },
249 | {
250 | "name": "stderr",
251 | "output_type": "stream",
252 | "text": [
253 | "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
254 | "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
255 | "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
256 | "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
257 | ]
258 | },
259 | {
260 | "name": "stdout",
261 | "output_type": "stream",
262 | "text": [
263 | "2024-04-05 20:46:33,539 INFO sqlalchemy.engine.Engine INSERT INTO data (model_name, timestamp, dataset_name, dataset_version, \"PPL\", \"R3\", \"R5\", \"R-L\") VALUES (?, ?, ?, ?, ?, ?, ?, ?)\n",
264 | "2024-04-05 20:46:33,541 INFO sqlalchemy.engine.Engine [generated in 0.00232s] ('distilgpt2', 1712349993, 'Dataset_metrix', 0, 185.80456829071045, 0.0, 0.0, 0.0)\n",
265 | "2024-04-05 20:46:33,620 INFO sqlalchemy.engine.Engine COMMIT\n",
266 | "2024-04-05 20:46:44,718 INFO sqlalchemy.engine.Engine BEGIN (implicit)\n",
267 | "2024-04-05 20:46:44,722 INFO sqlalchemy.engine.Engine SELECT data.id, data.model_name, data.timestamp, data.dataset_name, data.dataset_version, data.\"PPL\", data.\"R3\", data.\"R5\", data.\"R-L\" \n",
268 | "FROM data\n",
269 | "2024-04-05 20:46:44,724 INFO sqlalchemy.engine.Engine [generated in 0.00576s] ()\n"
270 | ]
271 | }
272 | ],
273 | "source": [
274 | "model_name = \"distilgpt2\"\n",
275 | "# model_name = \"gpt2\"\n",
276 | "# model_name = \"sberbank-ai/rugpt3small_based_on_gpt2\"\n",
277 | "\n",
278 | "\n",
279 | "metric_calculation_interface.load_model_(model_name)"
280 | ]
281 | }
282 | ],
283 | "metadata": {
284 | "kernelspec": {
285 | "display_name": "Python 3 (ipykernel)",
286 | "language": "python",
287 | "name": "python3"
288 | },
289 | "language_info": {
290 | "codemirror_mode": {
291 | "name": "ipython",
292 | "version": 3
293 | },
294 | "file_extension": ".py",
295 | "mimetype": "text/x-python",
296 | "name": "python",
297 | "nbconvert_exporter": "python",
298 | "pygments_lexer": "ipython3",
299 | "version": "3.10.13"
300 | }
301 | },
302 | "nbformat": 4,
303 | "nbformat_minor": 5
304 | }
305 |
--------------------------------------------------------------------------------
/examples/example_data/clusters.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "name": "Животные",
4 | "centroid": [
5 | "собака",
6 | "кошка",
7 | "заяц"
8 | ],
9 | "top_k": 140
10 | },
11 | {
12 | "name": "Лекарства",
13 | "centroid": [
14 | "уколы",
15 | "таблетки",
16 | "микстуры"
17 | ],
18 | "top_k": 160
19 | },
20 | {
21 | "name": "Болезни",
22 | "centroid": [
23 | "простуда",
24 | "орви",
25 | "орз",
26 | "грипп"
27 | ],
28 | "top_k": 20
29 | },
30 | {
31 | "name": "Аллергия",
32 | "centroid": [
33 | "аллергия"
34 | ],
35 | "top_k": 20
36 | }
37 | ]
--------------------------------------------------------------------------------
/examples/example_data/metrix.csv:
--------------------------------------------------------------------------------
1 | context,reference
2 | Joe waited for the ,Joe waited for the train
3 | We understand that sometimes the best way to truly understand a new concept is ,We understand that sometimes the best way to truly understand a new concept is to see it used in an example
4 | Simply type the word into the sentence generator and ,Simply type the word into the sentence generator and we’ll do the rest
5 | Sometimes to understand a word's meaning you need ,Sometimes to understand a word's meaning you need more than a definition
6 |
--------------------------------------------------------------------------------
/explainitall/QA/extractive_qa_sbert/QABotsBase.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import re
3 | import pandas as pd
4 | import numpy as np
5 | from sklearn.neighbors import NearestNeighbors
6 |
7 |
8 |
9 | class RetriBotStruct:
10 | """Ретривел бот"""
11 |
12 | def __init__(self, path=None, knn=None, embedder=None, answers=None):
13 |
14 | if path is None:
15 | if not (knn is None or embedder is None or answers is None):
16 | self.knn = knn
17 | self.embedder = embedder
18 | self.answers = answers
19 |
20 | else:
21 |             print('Error! Provide either a path, or a nearest-neighbors model, answers, and an embedder')
22 |
23 | else:
24 | self.load(path)
25 |
26 | def to(self, device='cpu'):
27 | self.embedder.to(device)
28 |
29 | def get_knn(self):
30 | """Возвращает knn"""
31 | return self.knn
32 |
33 | def get_embedder(self):
34 | """Возвращает эбеддер"""
35 | return self.embedder
36 |
37 | def get_ans_list(self):
38 | """Возвращает список ответов"""
39 | return self.answers
40 |
41 | def save(self, path):
42 | """Сохранить"""
43 | self.to()
44 | with open(path, 'wb') as f:
45 | pickle.dump(self, f)
46 |
47 | def load(self, path):
48 | """Загрузка"""
49 | with open(path, 'rb') as f:
50 | retri_bot = pickle.load(f)
51 | self.knn = retri_bot.knn
52 | self.embedder = retri_bot.embedder
53 | self.answers = retri_bot.answers
54 |
55 |
56 | class QABotStruct:
57 | """Вопрос-ответный бот"""
58 |
59 | def __init__(self, path=None, retri_bot=None, qa=None):
60 |
61 |         if path is None:
62 |             if not (retri_bot is None or qa is None):
63 | self.retri_bot = retri_bot
64 | self.qa = qa
65 |
66 | else:
67 |                 print('Error! Provide either a path, or a retrieval bot and a QA system')
68 |
69 | else:
70 | self.load(path)
71 |
72 | def get_retri_bot(self):
73 | """Возвращает ретривел бот"""
74 | return self.retri_bot
75 |
76 | def get_qa(self):
77 | """Возвращает QA систему"""
78 | return self.qa
79 |
80 | def save(self, path):
81 | """Сохранить"""
82 | self.qa.model.to('cpu')
83 | self.retri_bot.to('cpu')
84 | with open(path, 'wb') as f:
85 | pickle.dump(self, f)
86 |
87 | def load(self, path):
88 | """Загрузка"""
89 | with open(path, 'rb') as f:
90 | qa_bot = pickle.load(f)
91 |
92 | self.qa = qa_bot.qa
93 | self.retri_bot = qa_bot.retri_bot
94 |
95 |
96 | class RetriBot:
97 |     """Retrieval bot"""
98 |
99 | def __init__(self, bot, max_words=50, device='cpu'):
100 |
101 |         if isinstance(bot, str):
102 | rBot = RetriBotStruct(bot)
103 | else:
104 | rBot = bot
105 |
106 | self.main_knn = rBot.get_knn()
107 | self.sModel = rBot.get_embedder().to(device)
108 | self.max_words = max_words
109 | self.texts = rBot.get_ans_list()
110 |
111 | def _get_vect(self, q):
112 | return self.sModel.encode(q, convert_to_tensor=False)
113 |
114 | @staticmethod
115 | def cut(text, max_len=15):
116 | words = text.split(' ')[:max_len]
117 | ret_text = ''
118 | for word in words:
119 | ret_text += word + ' '
120 |
121 | return ret_text
122 |
123 | def get_answers(self, q, top_k=7):
124 | vect_q = self._get_vect(q)
125 | ans = self.main_knn.kneighbors([vect_q], top_k)
126 | support = []
127 |
128 | for i in range(ans[0].shape[1]):
129 | support.append(self.texts[ans[1][0][i]])
130 |
131 | support = list(set(support))
132 |
133 | ret_line = ''
134 |
135 | for doc in support:
136 | ret_line += RetriBot.cut(doc, self.max_words) + '. '
137 |
138 | return ret_line
139 |
140 |
141 | class QABot:
142 | """QA"""
143 |
144 | def __init__(self, bot, max_words=50, device='cpu'):
145 |
146 |         if isinstance(bot, str):
147 | qBot = QABotStruct(bot)
148 | else:
149 | qBot = bot
150 |
151 | self.retr = RetriBot(qBot.get_retri_bot(), max_words=max_words, device=device)
152 | self.qa = qBot.get_qa()
153 | self.qa.model.to(device)
154 |
155 | def qa_get_answer(self, text, q, top_k=3):
156 | ans = self.qa(context=text, question=q, top_k=top_k)
157 | answers = []
158 |
159 | if top_k == 1:
160 | return [{'answer': ans['answer'], 'score': ans['score']}]
161 |
162 | else:
163 | for a in ans:
164 | answers.append({'answer': a['answer'], 'score': a['score']})
165 |
166 | return answers
167 |
168 | def search(self, text, q, confidence=0.2):
169 | answer = self.qa_get_answer(text, q, 1)[0]
170 | if answer['score'] >= confidence:
171 | return answer['answer']
172 | else:
173 | return ''
174 |
175 | def get_prompt(self, q, confidence=0.3, top_k_search=7):
176 | text = self.retr.get_answers(q, top_k_search)
177 | return self.search(text, q, confidence)
178 |
179 |
180 | class SimpleTransformer:
181 |     """Identity stub transformer"""
182 |
183 | def __init__(self):
184 | pass
185 |
186 | def transform(self, vects):
187 | return vects
188 |
189 |
190 | def cos(x, y):
191 | return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
192 |
193 |
194 | def cos_dist(x, y):
195 | return -cos(x, y)
196 |
197 |
198 | class KnnBot:
199 | """
200 |     Retrieval bot built on a vectorization model and nearest-neighbors search,
201 |     with a maximum-radius cap for anomaly detection.
202 | """
203 |
204 | def __init__(self, knn=None, sbert=None, mean=None, std=None, vect_transformer=None, dim=None, n_neighbors=3, eps=1e-200):
205 | self.knn = knn
206 | self.model = sbert
207 | self.mean = np.zeros((dim,)) if mean is None else mean
208 | self.std = np.ones((dim,)) if std is None else std + eps
209 | self.vect_transformer = SimpleTransformer() if vect_transformer is None else vect_transformer
210 |
211 | self.r_char = re.compile('[^A-zА-яЁё0-9": ]')
212 | self.r_spaces = re.compile(r"\s+")
213 |
214 | def clean_string(self, text):
215 | """
216 |         Clean and normalize a string.
217 | """
218 | seq = self.r_char.sub(' ', text.replace('\n', ' '))
219 | seq = self.r_spaces.sub(' ', seq).strip()
220 | return seq.lower()
221 |
222 | def get_vect(self, q):
223 | """
224 |         Convert text to a vector and normalize it.
225 | """
226 | vect_q = self.model.encode(q, convert_to_tensor=False)
227 | vect_q = self.vect_transformer.transform([vect_q])[0]
228 | return (vect_q - self.mean) / self.std
229 |
230 | def __get_answer_text(self, text_q, n_neighbors=1):
231 | """
232 |         Get answers for a cleaned query string.
233 | 
234 |         :param text_q: Cleaned query text (vectorized internally).
235 |         :param n_neighbors: Number of nearest neighbors to search.
236 |         :return: List of answers.
237 | """
238 | vect = self.get_vect(text_q)
239 |
240 |         # Get the nearest-neighbor indices
241 |         distances, indices = self.knn.kneighbors([vect], n_neighbors=n_neighbors)
242 | 
243 |         # Return the answers that correspond to those indices
244 |         return list(set([self.answers[idx] for idx in indices[0]]))
245 |
246 | def get_answer(self, q, n_neighbors=1):
247 | """
248 |         Answer an incoming query.
249 | 
250 |         :param q: Question.
251 |         :param n_neighbors: Number of nearest neighbors to return.
252 |         :return: List of answers.
253 | """
254 | text_q = self.clean_string(q)
255 | return self.__get_answer_text(text_q, n_neighbors)
256 |
257 | def train(self, csv_path, embedder, knn_neighbors=5):
258 | """
259 |         Trains the bot on a CSV of questions and answers, computing the mean and standard deviation.
260 | 
261 |         :param csv_path: Path to the CSV file with questions and answers.
262 |         :param embedder: Embedding model used to turn text into vectors.
263 |         :param knn_neighbors: Number of neighbors for the KNN method.
264 |         """
265 |         # Step 1: Load the data from CSV
266 |         data = pd.read_csv(csv_path)
267 |         if 'question' not in data.columns or 'answer' not in data.columns:
268 |             raise ValueError("The CSV file must contain 'question' and 'answer' columns")
269 |
270 | questions = data['question'].tolist()
271 | answers = data['answer'].tolist()
272 |
273 |         # Step 2: Encode the questions as vectors
274 |         question_vectors = np.array([embedder.encode(q, convert_to_tensor=False) for q in questions])
275 | 
276 |         # Step 3: Compute the mean and standard deviation over the question vectors
277 |         self.mean = np.mean(question_vectors, axis=0)
278 |         self.std = np.std(question_vectors, axis=0)
279 | 
280 |         # Add a small eps to avoid division by zero
281 |         eps = 1e-10
282 |         self.std += eps
283 | 
284 |         # Step 4: Normalize the question vectors
285 |         normalized_vectors = (question_vectors - self.mean) / self.std
286 | 
287 |         # Step 5: Initialize and fit the KNN index
288 |         self.knn = NearestNeighbors(n_neighbors=knn_neighbors, metric='cosine')
289 |         self.knn.fit(normalized_vectors)
290 | 
291 |         # Store the answers
292 |         self.answers = answers
293 |
294 | def get_normalized_vector(self, question):
295 | """
296 |         Returns the normalized question vector using the stored mean and std.
297 | 
298 |         :param question: Question to encode.
299 |         :return: Normalized vector.
300 | """
301 | vect_q = self.model.encode(question, convert_to_tensor=False)
302 | return (vect_q - self.mean) / self.std
303 |
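304 | 
305 | # --- Usage sketch (illustrative addition, not part of the original file) ---
306 | # Assumes a CSV with 'question' and 'answer' columns; 'faq.csv' and the
307 | # SentenceTransformer checkpoint below are placeholders.
308 | if __name__ == '__main__':
309 |     from sentence_transformers import SentenceTransformer
310 | 
311 |     embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
312 |     bot = KnnBot(sbert=embedder, dim=384)  # 384 = embedding size of this checkpoint
313 |     bot.train('faq.csv', embedder, knn_neighbors=5)  # fits mean/std and the KNN index
314 |     print(bot.get_answer('How do I reset my password?', n_neighbors=2))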
--------------------------------------------------------------------------------
/explainitall/QA/extractive_qa_sbert/SVDBert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import BertModel
3 |
4 |
5 | class SVDLinearLayer(nn.Module):
6 | """
7 |     A linear layer that works through a reduced-dimension space, following the SVD factorization.
8 | """
9 |
10 | def __init__(self, in_features, out_features, h_dim, bias=True):
11 | super(SVDLinearLayer, self).__init__()
12 | self.encoder = nn.Linear(in_features, h_dim, bias=False)
13 | self.decoder = nn.Linear(h_dim, out_features, bias=bias)
14 |
15 | def forward(self, x):
16 | x = self.encoder(x)
17 | x = self.decoder(x)
18 | return x
19 |
20 |
21 | class SVDBertModel(BertModel):
22 | """
23 |     A BERT model with reduced dimensionality in selected layers, using the SVD approach.
24 | """
25 |
26 | def __init__(self, config, svd_dim=5):
27 | super(SVDBertModel, self).__init__(config)
28 |
29 | for i, layer in enumerate(self.encoder.layer):
30 | intermediate_size = layer.intermediate.dense.out_features
31 | output_size = layer.output.dense.out_features
32 |
33 | if i > 0:
34 | layer.intermediate.dense = SVDLinearLayer(
35 | layer.intermediate.dense.in_features,
36 | intermediate_size,
37 | svd_dim
38 | )
39 | layer.output.dense = SVDLinearLayer(
40 | layer.output.dense.in_features,
41 | output_size,
42 | svd_dim
43 | )
44 | else:
45 | layer.intermediate.dense = nn.Linear(
46 | layer.intermediate.dense.in_features,
47 | intermediate_size,
48 | bias=True
49 | )
50 | layer.output.dense = nn.Linear(
51 | layer.output.dense.in_features,
52 | output_size,
53 | bias=True
54 | )
55 |
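56 | 
57 | # --- Usage sketch (illustrative addition, not part of the original file) ---
58 | # Builds the rank-reduced BERT from a standard config; the replaced layers are
59 | # randomly initialized here, so real use loads a distilled checkpoint instead
60 | # (e.g. SVDBertModel.from_pretrained('FractalGPT/SbertSVDDistil')).
61 | if __name__ == '__main__':
62 |     from transformers import BertConfig
63 | 
64 |     config = BertConfig.from_pretrained('bert-base-uncased')  # placeholder checkpoint
65 |     model = SVDBertModel(config, svd_dim=5)
66 |     print(f'parameters: {sum(p.numel() for p in model.parameters()):,}')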
--------------------------------------------------------------------------------
/explainitall/QA/interp_qa/KNNWithGenerative.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import nltk
4 | import numpy as np
5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6 |
7 | from explainitall.QA.extractive_qa_sbert.QABotsBase import SimpleTransformer
8 |
9 | nltk.download('punkt')
10 |
11 |
12 | class FredStruct:
13 |
14 | def __init__(self, path='Ponimash/FredInterpreter', device='cpu'):
15 | self.tokenizer = AutoTokenizer.from_pretrained(path)
16 | self.fred_model = AutoModelForSeq2SeqLM.from_pretrained(path)
17 | self.fred_model = self.fred_model.to(device)
18 | self.fred_model.eval()
19 |
20 | def get_model(self):
21 | return self.fred_model
22 |
23 | def get_tokenizer(self):
24 | return self.tokenizer
25 |
26 |
27 | class PromptBot:
28 | """
29 |     Retrieval bot built on a vectorization model and nearest-neighbor search,
30 |     with a maximum-radius cap for anomaly detection
31 | """
32 |
33 | def __init__(self, knn, sbert, fred, texts, max_words=50, mean=None, std=None, vect_transformer=None, dim=None,
34 | n_neighbors=3, eps=1e-200, device='cuda'):
35 | self.max_words = max_words
36 | self.knn = knn
37 | self.knn.n_neighbors = n_neighbors
38 | self.texts = texts
39 | self.sbert = sbert
40 |         self.mean = np.zeros((dim,)) if mean is None else mean
41 |         self.std = np.ones((dim,)) if std is None else std + eps
42 |         self.vect_transformer = SimpleTransformer() if vect_transformer is None else vect_transformer
43 | self.fred = fred
44 | self.device = device
45 |
46 | def __clean_string(self, text):
47 | """
48 |         Clean up a string
49 | """
50 | seq = text.replace('\n', ' ')
51 | r_char = re.compile('[^A-zА-яЁё0-9": ]')
52 | r_spaces = re.compile(r"\s+")
53 | seq = r_char.sub(' ', seq)
54 | seq = r_spaces.sub(' ', seq).strip()
55 | return seq.lower()
56 |
57 | 
58 |
59 | def __qa__(self, doc, q):
60 | doc = doc.replace('\n', ' ')
61 | q = q.replace('\n', ' ')
62 | q_pr = f"Опираясь только на информацию: {doc}.\n Подробно и полно ответь на вопрос указав все детали и числовые значения: \"{q}\" "
63 | data_inp = self.fred.get_tokenizer()(q_pr, return_tensors="pt").to(self.device)
64 | return data_inp
65 |
66 | def __generate__(self, doc, q):
67 | t = self.__qa__(doc, q)
68 | output_ids = self.fred.get_model().generate(
69 | **t, do_sample=True, temperature=0.2, max_new_tokens=256, top_p=0.95, top_k=15, repetition_penalty=1.03,
70 | no_repeat_ngram_size=3
71 | )[0]
72 | out = self.fred.get_tokenizer().decode(output_ids.tolist(), skip_special_tokens=True)
73 | return out.replace("", "")
74 |
75 | def get_vect(self, q):
76 | """вектор из текста"""
77 | vect_q = self.sbert.encode(q, convert_to_tensor=False)
78 | vect_q = self.vect_transformer.transform([vect_q])[0]
79 | return (vect_q - self.mean) / self.std
80 |
81 | @staticmethod
82 | def cut(text, max_len=15):
83 | words = text.split(' ')[:max_len]
84 | ret_text = ''
85 | for word in words:
86 | ret_text += word + ' '
87 |
88 | return ret_text
89 |
90 | def get_answers(self, q, top_k=7):
91 | """ответ по преобразованному тексту"""
92 | vect_q = self.get_vect(q)
93 | answer_ids = self.knn.kneighbors([vect_q], top_k)
94 | doc = ''
95 | for i in range(answer_ids[0].shape[1]):
96 | doc += self.texts[answer_ids[1][0][i]] + '\n'
97 | return self.__generate__(doc, q)
98 |
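99 | 
100 | # --- Usage sketch (illustrative addition), mirroring examples/KNN_RAG_interp.ipynb ---
101 | # Assumes `sbert` is a SentenceTransformer and `knn` a KNeighborsClassifier
102 | # fitted on normalized embeddings of `texts` (the retrieval corpus):
103 | #
104 | #     fred = FredStruct('FractalGPT/FRED-T5-Interp', device='cpu')
105 | #     bot = PromptBot(knn, sbert, fred, texts, device='cpu')
106 | #     print(bot.get_answers('Как влияет аллергия на назначение лекарства', top_k=3))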
--------------------------------------------------------------------------------
/explainitall/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/explainitall/__init__.py
--------------------------------------------------------------------------------
/explainitall/clusters.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List, Callable, Dict, Literal
3 |
4 | import gensim
5 | from gensim.models import KeyedVectors
6 | from pandas import DataFrame
7 |
8 | from explainitall.gpt_like_interp.inseq_helpers import AttrObj
9 | from . import nlp
10 | from .gpt_like_interp import viz, inseq_helpers
11 |
12 |
13 | class Cluster:
14 | def __init__(self, name: str,
15 | sim_thresh: float,
16 | words: list, matching_func: Callable):
17 |         self.name = name  # Cluster name
18 |         self.sim_thresh = sim_thresh  # Similarity threshold
19 |         self.words = words  # Words
20 |         self.matching_func = matching_func  # Matching function
21 |
22 |
23 | class ClusterBuilder:
24 | def __init__(self, name: str, seed_words: list,
25 | embeddings: KeyedVectors,
26 | word_processor,
27 | num_similar_words: int = 200):
28 | """
29 |         name: Cluster name
30 |         seed_words: Initial list of words
31 |         embeddings: word vector representations
32 |         num_similar_words: Number of most-similar words to retrieve
33 | """
34 | self.name = name
35 | self.embeddings = embeddings
36 | self.num_similar_words = num_similar_words
37 | self.word_processor = word_processor
38 |
39 | self.seed_words = seed_words
40 |
41 | self.embeddable_seed_words = self.word_processor.get_embeddable_words_batch(self.seed_words)
42 |
43 | def build(self, matching_func: Callable) -> Cluster:
44 | """
45 |         Creates and returns a Cluster object
46 |         matching_func: Function that matches a named entity against a word
47 | """
48 | return Cluster(
49 | name=self.name,
50 | sim_thresh=0,
51 | words=self.find_similar_words(),
52 | matching_func=matching_func
53 | )
54 |
55 | def get_embeddable_word_from_most_similar(self, value) -> str:
56 | word_and_postfix, likelihood = value
57 | word, postfix = word_and_postfix.split("_")
58 | return self.word_processor.get_embeddable_word_or_none(word)
59 |
60 | def filter_and_clean_words_postfix(self, word_list: list, postfix: str = "_NOUN") -> list:
61 | filtered_w = [w for w in word_list if w and w.endswith(postfix)]
62 | return [w[:-len(postfix)] for w in filtered_w]
63 |
64 | def find_similar_words(self) -> list:
65 | if not self.embeddable_seed_words:
66 | return []
67 |
68 | try:
69 | similar_words = self.embeddings.most_similar(
70 | positive=self.embeddable_seed_words,
71 | topn=self.num_similar_words
72 | )
73 | except KeyError as e:
74 | raise Exception("embeddable_seed_words должны быть вида ['cat_NOUN','run_VERB']") from e
75 |
76 | extracted_words = [
77 | self.get_embeddable_word_from_most_similar(result) for result in similar_words
78 | ]
79 | extracted_words = self.embeddable_seed_words + extracted_words
80 | extracted_words = self.filter_and_clean_words_postfix(extracted_words)
81 |
82 | return extracted_words
83 |
84 |
85 | class ClusterManager:
86 | def __init__(self, embeddings: gensim.models.keyedvectors.KeyedVectors):
87 | self.embeddings = embeddings
88 | self.word_processor = nlp.WordProcessor(embeddings)
89 |
90 | def _is_same_normalized_word(self, word1: str, word2: str) -> bool:
91 | """
92 |         Checks whether the normalized forms of two words are identical.
93 | """
94 | return self.word_processor.get_normal_form_or_none(word1) == self.word_processor.get_normal_form_or_none(word2)
95 |
96 | def find_cluster_name(self, word: str, clusters: List[Cluster]) -> str:
97 | """
98 |         Maps a word to a cluster name.
99 | """
100 | normalized_word = self.word_processor.get_normal_form_or_none(word)
101 | if normalized_word:
102 | for cluster in clusters:
103 | if normalized_word in cluster.words:
104 | return cluster.name
105 | return "unnamed"
106 |
107 | def create_clusters(self, clusters_descr: List[Dict]):
108 | """
109 | Создает кластеры на основе их описаний.
110 | """
111 | clusters = []
112 |
113 | for descr in clusters_descr:
114 | clusters.append(ClusterBuilder(
115 | name=descr['name'],
116 | seed_words=descr['centroid'],
117 | embeddings=self.embeddings,
118 | num_similar_words=descr['top_k'],
119 | word_processor=self.word_processor
120 | ).build(matching_func=self._is_same_normalized_word))
121 | return clusters
122 |
123 |
124 | class ClusterInterpreter:
125 | """Интерпретация Кластеров"""
126 |
127 | def __init__(self,
128 | clusters_discr: List[Dict[str, object]],
129 | cluster_manager: ClusterManager
130 | ):
131 | self.cluster_manager = cluster_manager
132 | self.clusters = cluster_manager.create_clusters(clusters_discr)
133 |
134 | def set_link_with_clusters(self, attribute):
135 | """Устанавливает связь между семантическими кластерами."""
136 | grouped_attribute = inseq_helpers.group_by(attribute, gmm_norm=True)
137 | return self._create_parsed_attribution(grouped_attribute)
138 |
139 | def get_cluster_importance_df(self, attribute):
140 | """Преобразует атрибуты в dataframe."""
141 | attribute_with_clusters = self.set_link_with_clusters(attribute)
142 | return inseq_helpers.attr_to_df(attribute_with_clusters)
143 |
144 | def display_attr(self, attribute):
145 | """Отображает атрибуты в виде тепловой карты."""
146 | attribute_with_clusters = self.set_link_with_clusters(attribute)
147 | return viz.attr_to_heatmap(attribute_with_clusters)
148 |
149 | def _create_parsed_attribution(self, grouped_attribute: AttrObj):
150 | """Создает разобранные атрибуции с сгенерированными метками."""
151 | tokens_generated_cl = [self.cluster_manager.find_cluster_name(word=x, clusters=self.clusters)
152 | for x in grouped_attribute.tokens_generated]
153 |
154 | tokens_input_cl = [self.cluster_manager.find_cluster_name(word=x, clusters=self.clusters)
155 | for x in grouped_attribute.tokens_input]
156 |
157 | grouped_attribute = copy.deepcopy(grouped_attribute)
158 | grouped_attribute.tokens_input = tokens_input_cl
159 | grouped_attribute.tokens_generated = tokens_generated_cl
160 | return grouped_attribute
161 |
162 |
163 | def aggregate_cluster_df(cluster_df: DataFrame,
164 |                          aggr_f: Literal['max', 'min', 'sum', 'mean', 'median', 'std', 'var', 'sem', 'skew'],
165 | drop_condition: str = 'unnamed',
166 | cl_names_col: str = 'Tokens',
167 | ) -> DataFrame:
168 | cluster_df = cluster_df[~cluster_df[cl_names_col].str.contains(drop_condition)]
169 | cluster_df = cluster_df.loc[:, ~cluster_df.columns.str.contains(drop_condition)]
170 | aggregation_functions = {col: aggr_f for col in cluster_df.columns if col != cl_names_col}
171 | return cluster_df.groupby(cl_names_col).agg(aggregation_functions).reset_index()
172 |
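173 | 
174 | # --- Usage sketch (illustrative addition, not part of the original file) ---
175 | # Aggregates a cluster-importance frame with a pandas reduction, dropping the
176 | # 'unnamed' rows/columns; the toy DataFrame below is made up for illustration.
177 | if __name__ == '__main__':
178 |     toy_df = DataFrame({'Tokens': ['Болезни', 'Болезни', 'unnamed'],
179 |                         'Лекарства': [0.4, 0.6, 0.1],
180 |                         'unnamed': [0.2, 0.3, 0.5]})
181 |     print(aggregate_cluster_df(toy_df, aggr_f='mean'))  # -> Болезни, 0.5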
--------------------------------------------------------------------------------
/explainitall/embedder_interp/embd_interpret.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import norm
3 | from sentence_transformers import util
4 |
5 |
6 | def sen_a(model, texts):
7 | embeddings = model.encode(texts)
8 | s = util.pytorch_cos_sim(embeddings, embeddings)
9 | s[s < 0] = 0
10 | matrix = (1 - s).to('cpu').detach().numpy()
11 | return matrix[matrix > 1e-7].mean()
12 |
13 |
14 | class CosRelu:
15 | @staticmethod
16 | def dot_(v1, v2):
17 | """Скалярное произведение (у numpy не точно считает)"""
18 | sum_ = 0
19 | for i, v in enumerate(v2):
20 | sum_ += v1[i] * v
21 | return sum_
22 |
23 | @staticmethod
24 | def cos(v1, v2):
25 | """Косинус"""
26 | return CosRelu.dot_(v1, v2) / (norm(v1) * norm(v2))
27 |
28 | @staticmethod
29 | def cos_relu(y_orig, y_without_ner) -> float:
30 | """Рассчет ReLU от косинуса между векторам эмбеддингов"""
31 | r = CosRelu.cos(y_orig, y_without_ner)
32 | r = r if r >= 0 else 0
33 | return r
34 |
35 |
36 | class ModelInterp:
37 |
38 | def __init__(self, model):
39 | self.model = model
40 | self.mask_token = model.tokenizer.mask_token
41 |
42 | def seq_interp(self, sent):
43 | refer = self.model.encode(sent)
44 |
45 | words = sent.split(' ')
46 | imp = []
47 | y_mod = []
48 |
49 | for k in range(len(words)):
50 | repl_word = ' '.join([word if i != k else self.mask_token for i, word in enumerate(words)])
51 | y_mod.append(repl_word)
52 |
53 | targ = self.model.encode(y_mod)
54 |
55 | for vect in targ:
56 | imp.append(1 - CosRelu.cos_relu(refer, vect))
57 |
58 | imp = np.array(imp)
59 | imp_sum = imp.sum()
60 | imp /= imp_sum
61 |
62 | return {'imp': imp, 'words': words}
63 |
64 | def dataset_interp(self, texts):
65 | ret_data = []
66 | for text in texts:
67 | interp_seq = self.seq_interp(text)
68 | ret_data.append(interp_seq)
69 |
70 | return ret_data
71 |
72 |     def __cluster_energy(self, cluster_data):
73 |         elements = cluster_data['elements']
74 |         vectors = self.model.encode(elements)
75 |         cl_energy = np.array([sum(vector ** 2) for vector in vectors])
76 |         return {'name': cluster_data['name'], 'sensitivity': sen_a(self.model, elements), 'mean': cl_energy.mean()}
77 |
78 | def clusters_interp(self, clusters_data):
79 |         return [self.__cluster_energy(cluster_data) for cluster_data in clusters_data]
80 |
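81 | 
82 | # --- Usage sketch (illustrative addition, not part of the original file) ---
83 | # Scores each word by how much masking it moves the sentence embedding; any
84 | # SentenceTransformer whose tokenizer defines a mask token works (the name
85 | # below is a placeholder).
86 | if __name__ == '__main__':
87 |     from sentence_transformers import SentenceTransformer
88 | 
89 |     interp = ModelInterp(SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2'))
90 |     result = interp.seq_interp('the cat sat on the mat')
91 |     for word, weight in zip(result['words'], result['imp']):
92 |         print(f'{word}\t{weight:.3f}')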
--------------------------------------------------------------------------------
/explainitall/fast_tuning/Embedder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from transformers import GPT2Model, GPT2Config
4 |
5 |
6 | class GPTEmbedder:
7 | """Эмбеддер для gpt"""
8 |
9 | def __init__(self, tokenizer_gpt, gpt, device=None):
10 |
11 | if device is None:
12 | device = 'cpu'
13 | if torch.cuda.is_available():
14 | device = 'cuda:0'
15 |
16 | self.model = gpt
17 | self.model.to(device)
18 | self.tokenizer = tokenizer_gpt
19 | self.device = device
20 |
21 | #
22 | def get_emb_from_gpt(self, inp_str, n_layer_index=-1, token=-1, is_attention=False):
23 | """
24 |         Data (embeddings) from the hidden layers
25 |         Returns either the hidden state or the attention data
26 | """
27 | att_hidden = 0
28 | if is_attention:
29 | att_hidden = 1
30 |
31 | inp_tokens = self.tokenizer.encode(inp_str)
32 | context = torch.tensor(inp_tokens, dtype=torch.long, device=self.device)
33 | generated = context.unsqueeze(0)
34 | inputs = {'input_ids': generated}
35 |
36 | with torch.no_grad():
37 | outputs = self.model(**inputs)
38 |
39 | if n_layer_index == 'all':
40 | return outputs.last_hidden_state[0, -1, :].reshape(self.model.config.n_embd).to('cpu').detach().numpy()
41 | else:
42 |             return_obj = outputs[1][n_layer_index][att_hidden][0, :, token, :].to('cpu')  # 0 - the batch size is expected to be 1
43 |
44 | shape = return_obj.shape
45 | return return_obj.reshape((shape[0] * shape[1])).detach().numpy()
46 |
47 | def get_embs_from_gpt(self, inp_str, n_layer_index=-1, head_index=0, is_attention=False):
48 | """
49 |         Data (embeddings) from the hidden layers (for all tokens)
50 |         Returns either the hidden state or the attention data
51 | """
52 |
53 | att_hidden = 0
54 | if is_attention:
55 | att_hidden = 1
56 |
57 | inp_tokens = self.tokenizer.encode(inp_str)
58 | context = torch.tensor(inp_tokens, dtype=torch.long, device=self.device)
59 | generated = context.unsqueeze(0)
60 | inputs = {'input_ids': generated}
61 |
62 | with torch.no_grad():
63 | outputs = self.model(**inputs)
64 |
65 |         # Run the whole network, including the normalization layer
66 | if n_layer_index == 'all':
67 | return outputs.last_hidden_state[0, :, :].reshape((len(inp_tokens), self.model.config.n_embd)).to(
68 | 'cpu').detach().numpy()
69 |
70 |         # Run only the given number of GPT blocks
71 |         else:
72 |             # Return all attention heads or just one
73 | if head_index == 'all':
74 | return_obj = outputs[1][n_layer_index][att_hidden][0, :, :, :].to('cpu').detach().numpy()
75 | out_len = return_obj.shape[0] * return_obj.shape[2]
76 | return_obj = np.transpose(return_obj, (1, 0, 2)).reshape(return_obj.shape[1], out_len)
77 | else:
78 | return_obj = outputs[1][n_layer_index][att_hidden][0, head_index, :, :].to(
79 |                     'cpu').detach().numpy()  # 0 - the batch size is expected to be 1
80 |
81 | return return_obj
82 |
83 | def _get_k_layer(self, num_layers=3, name='gpt-embeder'):
84 | """Создание модели на базе первых слоев gpt2 донора"""
85 |         base_model = self.model.base_model  # donor model
86 |         config_base = base_model.config  # donor model config
87 |         config = GPT2Config.from_dict(config_base.to_dict())  # copy the config
88 |
89 | if num_layers < 0:
90 | num_layers = config.n_layer + num_layers + 1
91 |
92 |         config.name_or_path = name  # network name
93 |         config.n_layer = num_layers  # set the required number of layers
94 |
95 |         gpt_emb = GPT2Model(config)  # create the model
96 |         gpt_emb.wte.weight = self.model.transformer.wte.weight  # word embeddings
97 |         gpt_emb.wpe.weight = self.model.transformer.wpe.weight  # position embeddings
98 |         gpt_emb.ln_f.weight = self.model.transformer.ln_f.weight  # normalization layer
99 |
100 | for n_layer in range(num_layers):
101 |             gpt_emb.base_model.h[n_layer] = base_model.h[n_layer]  # copy the layers
102 |
103 | return gpt_emb
104 |
105 | def get_new_model(self, num_layers=3, name='gpt-embeder', save_path='gpt/model_emb'):
106 | """Создание модели на базе первых слоев gpt2 донора с перезаписью"""
107 | m = self._get_k_layer(num_layers, name)
108 | m.save_pretrained(save_path)
109 | return GPT2Model.from_pretrained(save_path)
110 |
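111 | 
112 | # --- Usage sketch (illustrative addition, not part of the original file) ---
113 | # Pulls the final hidden state of the last token from a donor network;
114 | # 'distilgpt2' is a placeholder checkpoint.
115 | if __name__ == '__main__':
116 |     from transformers import GPT2Tokenizer
117 | 
118 |     tok = GPT2Tokenizer.from_pretrained('distilgpt2')
119 |     gpt = GPT2Model.from_pretrained('distilgpt2')
120 |     embedder = GPTEmbedder(tok, gpt)
121 |     vec = embedder.get_emb_from_gpt('hello world', n_layer_index='all')
122 |     print(vec.shape)  # (n_embd,)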
--------------------------------------------------------------------------------
/explainitall/fast_tuning/ExpertBase.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | class ExpertModel(abc.ABC):
4 | @abc.abstractmethod
5 | def get_bias(self, tokens):
6 | """Вычисление bias из вероятностной модели"""
7 |
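8 | 
9 | # --- Illustrative example (added sketch, not part of the original file) ---
10 | # A minimal concrete expert: a zero bias leaves the base model's logits
11 | # untouched; real experts (e.g. generators.MCExpert) score the generated tail.
12 | import numpy as np
13 | 
14 | 
15 | class UniformExpert(ExpertModel):
16 |     def __init__(self, vocab_size=50257):  # 50257 = GPT-2 vocabulary size
17 |         self.vocab_size = vocab_size
18 | 
19 |     def get_bias(self, tokens):
20 |         return np.zeros(self.vocab_size)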
--------------------------------------------------------------------------------
/explainitall/fast_tuning/SimpleModelCreator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model
5 |
6 | from .Embedder import GPTEmbedder
7 | from .trainers.DenceKerasTrainer import GPTFastTrainer
8 |
9 |
10 | def get_dataset_dense(txts, embedder: GPTEmbedder, tokenizer: GPT2Tokenizer, n_layer_index='all'):
11 | """
12 |     Builds a dataset from texts using the given embedder and tokenizer.
13 | """
14 | list_x, list_y = [], []
15 | for txt in txts:
16 | words = txt.split(' ')
17 | for i in range(0, len(words), 25):
18 | text = ' '.join(words[i:])
19 | emb = embedder.get_embs_from_gpt(text, n_layer_index=n_layer_index)[:-1][:1024]
20 | ids = np.array(tokenizer(text)['input_ids'])[1:]
21 | list_x.append(emb)
22 | list_y.append(ids)
23 |
24 | return np.concatenate(list_x), np.concatenate(list_y)
25 |
26 |
27 | def gpt_build(trainer: GPTFastTrainer, gpt_emb: GPT2Model, tokenizer: GPT2Tokenizer, y_set,
28 | path_to_save='gpt_model_new'):
29 | """
30 |     Builds and saves a new GPT model.
31 | """
32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
33 | gpt_emb.to(device)
34 |
35 | w_out_matr = trainer.adapter_layer.get_weights()[0] @ trainer.keras_out_weight.T
36 | config_gpt = gpt_emb.config
37 | new_model = GPT2LMHeadModel(config_gpt)
38 | new_model.transformer = gpt_emb
39 | new_model.lm_head = nn.Linear(config_gpt.n_embd, config_gpt.vocab_size, bias=False)
40 | new_model.lm_head.weight = nn.Parameter(torch.tensor(w_out_matr, device=device))
41 |
42 | new_model.save_pretrained(path_to_save)
43 | tokenizer.save_pretrained(path_to_save)
44 | np.save(f'{path_to_save}/set.data', np.array(y_set))
45 |
46 |
47 | class SimpleCreator:
48 | def __init__(self, model: GPT2Model, tokenizer: GPT2Tokenizer):
49 | self.tokenizer = tokenizer
50 | main_embedder = GPTEmbedder(tokenizer, model)
51 | self.gpt_emb = main_embedder.get_new_model(num_layers=-1)
52 | self.cut_embedder = GPTEmbedder(self.tokenizer, self.gpt_emb)
53 | self.trainer = GPTFastTrainer(model)
54 |
55 | def train(self, data, lr=0.0003, bs=64, epochs=6, val_split=0.0, save_path='new_model'):
56 | x, y = get_dataset_dense(data, self.cut_embedder, self.tokenizer)
57 | net = self.trainer.create_net()
58 | self.trainer.train(net, x, y, lr=lr, bs=bs, epochs=epochs, val_split=val_split)
59 | y_set = list(set(y))
60 | gpt_build(self.trainer, self.gpt_emb, self.tokenizer, y_set, save_path)
61 | return net
62 |
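63 | 
64 | # --- Usage sketch (illustrative addition, not part of the original file) ---
65 | # Fast-tunes only the adapter/output head on a small corpus; the checkpoint
66 | # name, corpus, and save path are placeholders:
67 | #
68 | #     from transformers import GPT2LMHeadModel, GPT2Tokenizer
69 | #     tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
70 | #     model = GPT2LMHeadModel.from_pretrained('distilgpt2')
71 | #     creator = SimpleCreator(model, tokenizer)
72 | #     creator.train(['a small training corpus, one text per item'], save_path='new_model')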
--------------------------------------------------------------------------------
/explainitall/fast_tuning/generators/GeneratorWithExpert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | class GPTGenerator:
7 |
8 |     def __init__(self, model, tokenizer, expert, device=None):
9 | self.infer_device = device
10 | self.model_gpt = model
11 | self.model_gpt.to(device)
12 | self.tokenizer_gpt = tokenizer
13 | self.expert = expert
14 | self.model_gpt.eval()
15 |
16 |     # Repetition penalty
17 | def __repeat_penalty(self, generated, logits, rp):
18 | hist_tokens = generated[0].cpu().numpy()
19 | unique_tokens, counts = np.unique(hist_tokens, return_counts=True)
20 | for token, count in zip(unique_tokens, counts):
21 | logits[token] -= rp * count
22 |
23 |     # top-p / top-k filtering
24 |     def _get_token(self, logist, top_k=30, top_p=0.9, del_simbols=None):
25 |
26 | filter_value = float('-inf')
27 |
28 |         # Filter out unwanted tokens
29 |         if del_simbols is not None:
30 | for del_s in del_simbols:
31 | logist[del_s] = filter_value
32 |
33 |         # top_k filtering
34 |         topk = torch.topk(logist, top_k)
35 |         topk_v = topk[0]  # top_k values
36 |         topk_i = topk[1]  # top_k indices
37 | 
38 |         # top_p filtering
39 | probs = F.softmax(topk_v, dim = -1)
40 | cumulative_probs = torch.cumsum(probs, dim = -1)
41 | probs[top_p < cumulative_probs] = 0
42 |
43 | if sum(probs) == 0:
44 | probs[0] = 1
45 |
46 | token_ind = torch.multinomial(probs, 1)
47 | token = topk_i[token_ind]
48 | return token
49 |
50 |     # Sequence generation
51 |     def _sample_sequence(self, length, context_tokens, temperature=1, top_k=30, expert_w=0.2, rp=0.1, del_simbols=None):
52 | inp_len = len(context_tokens)
53 | context = torch.tensor(context_tokens, dtype=torch.long, device=self.infer_device)
54 | generated = context.unsqueeze(0)
55 |
56 | with torch.no_grad():
57 | decoded = ''
58 | for _ in range(length):
59 |                 inputs = {'input_ids': generated[:, -1023:]}  # Inputs
60 |                 outputs = self.model_gpt(**inputs)  # GPT forward pass
61 | 
62 |                 g_with_start = list(generated[0].cpu().numpy())  # Prompt for the expert
63 |                 bias_expert = self.expert.get_bias(g_with_start)  # Bias from the expert (or a mixture of experts)
64 |                 bias_expert = torch.tensor(bias_expert).to(self.infer_device)  # Bias as a tensor
65 |                 next_token_logits = ((1 - expert_w) * outputs[0][0, -1, :] + expert_w * bias_expert) / temperature
66 |                 self.__repeat_penalty(generated, next_token_logits, rp)
67 |                 next_tokens = self._get_token(next_token_logits, top_k=top_k, del_simbols=del_simbols).to(self.infer_device)  # Sample from the distribution
68 |
69 | if next_tokens == 0:
70 | break
71 |
72 | generated = torch.cat((generated, next_tokens.unsqueeze(-1)), dim=1)
73 |
74 | out = generated[0, len(context_tokens):].tolist()
75 | new_decoded = self.tokenizer_gpt.decode(out)
76 | if len(new_decoded) > len(decoded):
77 | decoded = new_decoded
78 | return decoded
79 |
80 |
81 |     # Text-to-text generation
82 |     def _generate(self, raw_text, length=250, temperature=1., top_k=30, rp=0.03, expert_w=0.2, del_simbols=None):
83 | context_tokens = self.tokenizer_gpt.encode(raw_text)
84 | out = self._sample_sequence(
85 | length, context_tokens,
86 |             rp=rp,
87 |             expert_w=expert_w,
88 |             temperature=temperature,
89 |             top_k=top_k,
90 |             del_simbols=del_simbols,
91 | )
92 | return out
93 |
94 |     # Generate several sequences from a starting text
95 |     def Generate(self, text, max_len=200, num_seq=2, temperature=1.0, topk=30, rp=0.03, expert_w=0.2, del_simbols=None):
96 | ret = []
97 | for i in range(num_seq):
98 | ret.append(
99 | self._generate(text, max_len, temperature=temperature, top_k=topk, rp=rp, expert_w = expert_w, del_simbols=del_simbols)
100 | )
101 | return ret
102 |
103 | # Generation with a bias mask
104 | class GenerationWithProbs:
105 |
106 |     def __init__(self, model, tokenizer, bias_mask, device='cpu'):
107 | self.model = model
108 | self.b_mask = bias_mask
109 | self.tokenizer = tokenizer
110 | self.model.to(device)
111 |
112 | def generate(self, text, top_p = 0.9, max_len = 30, rp = 1.15, temperature=0.7, variety = 0.3):
113 | self.__set_variety(variety)
114 | do_sample = temperature > 0
115 |
116 | input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.model.device)
117 | output_sequences = self.model.generate(input_ids=input_ids,
118 | max_length=max_len,
119 | temperature= None if temperature == 0 else temperature,
120 | top_p= None if temperature == 0 else top_p,
121 | repetition_penalty=rp,
122 | num_return_sequences=1,
123 | do_sample=do_sample)
124 |
125 | generated_text = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
126 | return generated_text
127 |
128 | def __set_variety(self, variety=0.0):
129 | bias = np.zeros((self.model.lm_head.out_features,))
130 | coef_mask = np.log2(variety + 3e-3) / np.log2(np.e)
131 | bias += coef_mask
132 |
133 | for token in self.b_mask:
134 | bias[token] = 0
135 |
136 | b_tensor = torch.tensor(bias, dtype=torch.float32)
137 | out_gpt_layer = torch.nn.Linear(in_features=self.model.lm_head.in_features, out_features=self.model.lm_head.out_features, bias=True)
138 | out_gpt_layer.weight = self.model.lm_head.weight
139 | out_gpt_layer.bias.data.copy_(b_tensor)
140 | self.model.lm_head = out_gpt_layer
141 |
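142 | 
143 | # --- Usage sketch (illustrative addition, not part of the original file) ---
144 | # GPTGenerator mixes GPT logits with an expert bias at every step (expert_w is
145 | # the mixing weight); `model`, `tokenizer` and an ExpertModel `expert` are assumed:
146 | #
147 | #     gen = GPTGenerator(model, tokenizer, expert, device='cpu')
148 | #     variants = gen.Generate('Once upon a time', max_len=50, num_seq=2, expert_w=0.2)
149 | #     print(variants[0])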
--------------------------------------------------------------------------------
/explainitall/fast_tuning/generators/MCExpert.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 |
3 | from explainitall.fast_tuning.ExpertBase import ExpertModel
4 |
5 |
6 | class MCExpert(ExpertModel, metaclass=ABCMeta):
7 |
8 | def __init__(self, mc):
9 | self.mc = mc
10 |
11 | def get_bias(self, tokens):
12 | start = [1, 1] + tokens
13 | return self.mc.get_bias(*start[-2:])
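14 | 
15 | 
16 | # --- Usage sketch (illustrative addition, not part of the original file) ---
17 | # Wraps a trained MarkovModel (see MCModel.py) for GPTGenerator; the [1, 1]
18 | # padding guarantees a two-token context even at the first generation step:
19 | #
20 | #     mc = MarkovModel(len_vect=50257, path='markov/')  # pickled model files (placeholder path)
21 | #     expert = MCExpert(mc)
22 | #     bias = expert.get_bias([101, 2023])  # ndarray of shape (50257,)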
--------------------------------------------------------------------------------
/explainitall/fast_tuning/generators/MCModel.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 |
4 |
5 | # Convert a pair of tokens to a string key
6 | def token_pair_2_str(tokens):
7 | list_x = list(tokens)
8 | return str(list_x)
9 |
10 | # Encode the tail (last two tokens)
11 | def encode_end(tokens_gpt, x_encode):
12 | str_x = token_pair_2_str(tokens_gpt[-2:])
13 | return x_encode[str_x] if str_x in x_encode else -1
14 |
15 | # Markov chain model
16 | class MarkovModel:
17 |     def __init__(self, len_vect, x_e=None, y_d=None, model=None, path=None, dep=4):
18 |         if path is None:
19 | self.x_encoder = x_e
20 | self.y_decoder = y_d
21 | else:
22 | with open(path + 'x_enc.dat', 'rb') as f:
23 | self.x_encoder = pickle.load(f)
24 | with open(path + 'y_dec.dat', 'rb') as f:
25 | self.y_decoder = pickle.load(f)
26 | with open(path + 'model.dat', 'rb') as f:
27 | model = pickle.load(f)
28 |
29 | self.dep = dep
30 | self.model_log = self.__get_log_model(model)
31 | self.len_vect = len_vect
32 |
33 |     # ----------------------- Bias generation ----------------------- #
34 |     def get_bias(self, token_1=1, token_2=1):
35 |         """Bias generation"""
36 | b = [token_1, token_2]
37 | key = encode_end(b, self.x_encoder)
38 | bias = np.zeros((self.len_vect))
39 |
40 | if key != -1:
41 | tokens, logs = self.model_log[key]['tokens'], self.model_log[key]['logists']
42 | tokens = self.__get_tokens(tokens)
43 | bias -= self.dep
44 |
45 | for i, token in enumerate(tokens):
46 | bias[token] = logs[i]
47 |
48 | return bias
49 |
50 |     # Normalized log-probabilities
51 | def __get_logist(self, key, model):
52 | logist = np.log(model[key]['probs'])
53 | logist-=max(logist)
54 | return logist
55 |
56 | # Получение модели с логарифмами вероятностей
57 | def __get_log_model(self, model):
58 | m_l ={}
59 | for key in range(len(model)):
60 | m_l_semple = {}
61 | m_l_semple.update({'tokens': model[key]['tokens']})
62 | m_l_semple.update({'logists':self.__get_logist(key, model)})
63 | m_l.update({key:m_l_semple})
64 | return m_l
65 |
66 |
67 | def __get_tokens(self, tokens):
68 | true_tokens = []
69 | for t in tokens:
70 | true_tokens.append(self.y_decoder[t])
71 | return true_tokens
72 |
73 |
74 |
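To make the expected data layout concrete, here is a minimal sketch that drives `MarkovModel.get_bias` with a hand-built single-state model; the dict structure mirrors what `GPT2HMMDataProcessor.train` (in `trainers/HMMTrainer.py`) produces, and every value here is illustrative:

```python
# Hypothetical toy model: one state for the pair [1, 1], predicting vocab tokens 5 and 7.
x_encoder = {'[1, 1]': 0}                                  # token-pair string -> state id
y_decoder = [5, 7]                                         # compressed id -> vocab token id
toy_model = {0: {'tokens': [0, 1], 'probs': [0.75, 0.25]}}

mm = MarkovModel(len_vect=10, x_e=x_encoder, y_d=y_decoder, model=toy_model)
bias = mm.get_bias(1, 1)
print(bias[5], bias[7])   # 0.0 and log(0.25/0.75): normalized log-probs of seen tokens
print(bias[0])            # unseen tokens are pushed down by -dep (default 4)
```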
--------------------------------------------------------------------------------
/explainitall/fast_tuning/generators/SimpleGenerator.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
2 | import numpy as np
3 | import torch
4 |
5 | class TextGenerator:
6 | def __init__(self, path):
7 | self.tokenizer = AutoTokenizer.from_pretrained(path)
8 | self.model = AutoModelForCausalLM.from_pretrained(path)
9 |
10 | self.y_set = np.load(f'{path}/set.data.npy')
11 | self.set_variety_of_answers(0.0)
12 |
13 | self.pipeline = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer)
14 |
15 | def set_variety_of_answers(self, variety=0, min_prob=3e-3):
16 | bias = np.zeros((self.model.lm_head.out_features,))
17 |         coef_mask = np.log2(variety + min_prob) / np.log2(np.e)  # ln(variety + min_prob)
18 | bias += coef_mask
19 |
20 | for token in self.y_set:
21 | bias[token] = 0
22 |
23 | b_tensor = torch.tensor(bias, dtype=torch.float32)
24 |         self.model.lm_head.bias = torch.nn.Parameter(b_tensor)  # lm_head is usually built without a bias, so attach one
25 |
26 | def generate(self, start_text, args=None):
27 | if args is None:
28 | args = {
29 | "temperature": 0.7,
30 | "no_repeat_ngram_size": 2,
31 | "num_beams": 12,
32 | "top_k": 30,
33 | }
34 |         return self.pipeline(start_text, **args)[0]["generated_text"]
35 |
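A minimal usage sketch, assuming a directory produced by the fast-tuning pipeline that contains the saved HF model, tokenizer, and the `set.data.npy` token mask (the path and prompt are illustrative):

```python
generator = TextGenerator('models/my_tuned_gpt')  # hypothetical model directory

# variety close to 0 keeps generation inside the tokens seen during tuning;
# values closer to 1 relax the mask.
generator.set_variety_of_answers(variety=0.1)

print(generator.generate('Вопрос: чем лечить кошку? Ответ:'))
```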
--------------------------------------------------------------------------------
/explainitall/fast_tuning/trainers/DenceKerasTrainer.py:
--------------------------------------------------------------------------------
1 | import keras.layers as L
2 | import numpy as np
3 | import torch
4 | from keras.layers import Dense
5 | from keras.models import Sequential
6 | from tensorflow import keras as K
7 |
8 |
9 | class GPTFastTrainer:
10 |     """Builds a trainable final layer on top of the frozen GPT output projection"""
11 |
12 | def __init__(self, gpt_model):
13 | self.inp_dim = gpt_model.config.n_embd
14 | self.outp_dim = gpt_model.config.vocab_size
15 |
16 | if torch.cuda.is_available():
17 | gpt_model.to('cpu')
18 |
19 |         self.keras_out_weight = gpt_model.lm_head.weight.detach().numpy().transpose()  # output-layer weights
20 |         self.keras_adapter_weight = np.eye(self.inp_dim, dtype=float)  # identity matrix (adapter-layer weights)
21 |
22 |         self.adapter_layer = Dense(self.inp_dim, use_bias=False, activation='linear')  # adapter layer
23 |         self.out_layer = Dense(use_bias=True, units=self.outp_dim, activation='linear',
24 |                                trainable=False)  # output layer (frozen)
25 |
26 | if torch.cuda.is_available():
27 | gpt_model.to('cuda:0')
28 |
29 |     # Build the tuning network #
30 | def creat_net(self):
31 | net = Sequential()
32 | net.add(L.Input(self.inp_dim))
33 | net.add(self.adapter_layer)
34 | net.add(self.out_layer)
35 | net.add(L.Activation(activation='softmax'))
36 | net.compile()
37 |         # Load weights into the output layer
38 |         self.out_layer.set_weights([self.keras_out_weight, np.zeros(self.outp_dim)])
39 |         # Load weights into the adapter layer
40 |         self.adapter_layer.set_weights([self.keras_adapter_weight])
41 | return net
42 |
43 | def set_variety_of_answers(self, y, variety=0, min_prob=1e-300):
44 |         """Rebuild the output bias to set generation variety"""
45 | set_tokens = set(y)
46 | bias = np.zeros(self.outp_dim)
47 | coef_mask = np.log2(variety + min_prob) / np.log2(np.e)
48 | bias += coef_mask
49 |
50 | for token in set_tokens:
51 | bias[token] = 0
52 |
53 | self.out_layer.set_weights([self.keras_out_weight, bias])
54 |
55 | def train(self, net, x, y, lr=0.0003, bs=64, epochs=3, val_split=0.0):
56 |         """Train the network"""
57 | self.set_variety_of_answers(y)
58 | opt = K.optimizers.Adamax(learning_rate=lr)
59 | net.compile(loss='sparse_categorical_crossentropy', optimizer=opt)
60 | net.fit(x, y, batch_size=bs, epochs=epochs, validation_split=val_split)
61 |
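Since only the identity-initialized adapter is trainable here, the learned matrix can in principle be folded back into the torch model after `train`. A minimal merge sketch under that assumption (not part of this file's API; `trainer` is a `GPTFastTrainer`, `gpt_model` the original torch model): the Keras forward pass is `h @ A @ W_out` with `W_out` the transposed lm_head weight, so the merged torch weight is `W @ A.T`.

```python
import torch

adapter_w = trainer.adapter_layer.get_weights()[0]        # (n_embd, n_embd) matrix A
w = gpt_model.lm_head.weight.detach().cpu().numpy()       # (vocab, n_embd) matrix W
gpt_model.lm_head.weight.data.copy_(
    torch.tensor(w @ adapter_w.T, dtype=torch.float32))   # W' = W @ A.T
```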
--------------------------------------------------------------------------------
/explainitall/fast_tuning/trainers/HMMTrainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class GPT2HMMDataProcessor:
5 | def __init__(self, tokenizer):
6 | self.tokenizer = tokenizer
7 |
8 | def get_data_1(self, texts):
9 | list_y = []
10 | for text in texts:
11 |             list_tok = [1, 1] + self.tokenizer(text)['input_ids'] + [2]  # 1,1 pads the start; 2 marks the end
12 | list_y.append(np.array(list_tok))
13 | return np.concatenate(list_y)
14 |
15 | @staticmethod
16 | def token_pair_2_str(tokens):
17 | list_x = list(tokens)
18 | return str(list_x)
19 |
20 | def createXY(self, tokens_gpt):
21 | x = []
22 | y = []
23 |
24 | for i in range(len(tokens_gpt) - 2):
25 | str_x = self.token_pair_2_str(tokens_gpt[i:i + 2])
26 | x.append(str_x)
27 | y.append(tokens_gpt[i + 2])
28 |
29 | return x, y
30 |
31 | @staticmethod
32 | def encode_samples_x(x, y_encode):
33 | return [y_encode[i] for i in x]
34 |
35 | @staticmethod
36 | def encode_samples_y(y, y_encode):
37 | return [y_encode[i] for i in y]
38 |
39 | @staticmethod
40 | def encode_end(tokens_gpt, x_encode):
41 | str_x = GPT2HMMDataProcessor.token_pair_2_str(tokens_gpt[-2:])
42 | return x_encode[str_x] if str_x in x_encode else -1
43 |
44 | def create_data(self, tokens_gpt):
45 | x, y = self.createXY(tokens_gpt)
46 | x_set = list(set(x))
47 | mask_y = list(set(y))
48 |
49 | x_encode = {}
50 | x_decode = []
51 |
52 | y_encode = {}
53 | y_decode = []
54 |
55 | for i, xi in enumerate(x_set):
56 | x_encode.update({xi: i})
57 | x_decode.append(xi)
58 |
59 | for i, yi in enumerate(mask_y):
60 | y_encode.update({yi: i})
61 | y_decode.append(yi)
62 |
63 | x_enc = self.encode_samples_x(x, x_encode)
64 | y_enc = self.encode_samples_y(y, y_encode)
65 |
66 | return {'x': x_enc, 'y': y_enc, 'x_encoder': x_encode, 'y_encoder': y_encode, 'x_decoder': x_decode,
67 | 'y_decoder': y_decode}
68 |
69 | @staticmethod
70 | def train(data):
71 | states = {}
72 | n = len(data['y'])
73 | x = data['x']
74 | y = data['y']
75 |
76 | for i in range(n):
77 | if x[i] in states:
78 | if y[i] in states[x[i]]['tokens']:
79 | states[x[i]]['probs'][y[i]] += 1
80 | else:
81 | states[x[i]]['probs'].update({y[i]: 1})
82 | states[x[i]]['tokens'].update({y[i]: 0})
83 | else:
84 | states.update({x[i]: {'tokens': {}, 'probs': {}}})
85 | states[x[i]]['probs'].update({y[i]: 1})
86 | states[x[i]]['tokens'].update({y[i]: 0})
87 |
88 | n = max(states.keys()) + 1
89 | n_states = []
90 |
91 | for i in range(n):
92 | tokens = list(states[i]['tokens'].keys())
93 | total_count = 0
94 | probs = []
95 |
96 | for t in tokens:
97 | count = states[i]['probs'][t]
98 | total_count += count
99 | probs.append(count)
100 |
101 | for j, p in enumerate(probs):
102 | probs[j] = p / total_count
103 |
104 | n_states.append({'tokens': tokens, 'probs': probs})
105 |
106 | return n_states
107 |
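End to end, this processor produces the structures that `MarkovModel` (in `generators/MCModel.py`) consumes. A minimal sketch, assuming an HF tokenizer and a tiny corpus (model name and texts illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
processor = GPT2HMMDataProcessor(tokenizer)

tokens = processor.get_data_1(['первый текст', 'второй текст'])  # flat padded token stream
data = processor.create_data(tokens)       # encoded (token pair -> next token) samples
states = GPT2HMMDataProcessor.train(data)  # per-state {'tokens': [...], 'probs': [...]}
```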
--------------------------------------------------------------------------------
/explainitall/fast_tuning/trainers/ProjectionTrainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from transformers import Trainer, TrainingArguments, PreTrainedTokenizer
4 | from transformers import DataCollatorForLanguageModeling
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class StringDataset(Dataset):
9 | def __init__(self, tokenizer: PreTrainedTokenizer, texts: list, block_size=256):
10 | block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)
11 | self.examples = []
12 |
13 | for text in texts:
14 |             if len(text) == 0:
15 | continue
16 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
17 | if len(tokenized_text) >= block_size:
18 | for i in range(0, len(tokenized_text) - block_size + 1, block_size):
19 | self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))
20 | else:
21 | self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text))
22 |
23 | def __len__(self):
24 | return len(self.examples)
25 |
26 | def __getitem__(self, i):
27 | return torch.tensor(self.examples[i], dtype=torch.long)
28 |
29 | class GPTProjectionTrainer:
30 | def __init__(self, model, tokenizer):
31 | self.model = model
32 | self.tokenizer = tokenizer
33 |
34 | def load_dataset(self, texts, block_size=256):
35 | dataset = StringDataset(
36 | tokenizer=self.tokenizer,
37 | texts=texts,
38 | block_size=block_size,
39 | )
40 | return dataset
41 |
42 | def create_data_collator(self, mlm=False):
43 | data_collator = DataCollatorForLanguageModeling(
44 | tokenizer=self.tokenizer,
45 | mlm=mlm,
46 | )
47 | return data_collator
48 |
49 | def set_variety(self, bias_mask, variety=0., min_prob=3e-3):
50 | bias = np.zeros((self.model.lm_head.out_features,))
51 | coef_mask = np.log2(variety + min_prob) / np.log2(np.e)
52 | bias += coef_mask
53 |
54 | for token in bias_mask:
55 | bias[token] = 0
56 |
57 | b_tensor = torch.tensor(bias, dtype=torch.float32)
58 | out_gpt_layer = torch.nn.Linear(in_features=self.model.lm_head.in_features,
59 | out_features=self.model.lm_head.out_features, bias=True)
60 | out_gpt_layer.weight = self.model.lm_head.weight
61 | out_gpt_layer.bias.data.copy_(b_tensor)
62 | self.model.lm_head = out_gpt_layer
63 |
64 | def train(self, train_texts, output_dir="new_gpt", last_k=10,
65 |               per_device_train_batch_size=2, num_train_epochs=3, save_steps=1000, device=None, lr=2e-4):
66 |
67 | train_dataset = self.load_dataset(train_texts)
68 |
69 | self.train_with_dataset(train_dataset, output_dir, last_k,
70 | per_device_train_batch_size, num_train_epochs, save_steps, device, lr)
71 |
72 |
73 | def train_with_dataset(self, train_dataset, output_dir="new_gpt", last_k='all',
74 |                            per_device_train_batch_size=2, num_train_epochs=1, save_steps=10000, device=None, lr=2e-4):
75 | if device is None:
76 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77 |
78 | self.model.to(device)
79 |
80 | params = []
81 |
82 |         # Train all layers
83 | if last_k == 'all':
84 | for name, param in self.model.named_parameters():
85 | param.requires_grad = True
86 |
87 |         # Projection layers only
88 | else:
89 | for name, param in self.model.named_parameters():
90 | param.requires_grad = False
91 | if "c_proj.weight" in name: # and "mlp" in name
92 | params.append(param)
93 |
94 | for param in params[-last_k:]:
95 | param.requires_grad = True
96 |
97 |
98 | if len(train_dataset) == 0:
99 | raise ValueError("Dataset is empty. Ensure that the input texts are not empty and of sufficient length.")
100 |
101 | data_collator = self.create_data_collator()
102 |
103 | training_args = TrainingArguments(
104 | output_dir=output_dir,
105 | overwrite_output_dir=True,
106 | per_device_train_batch_size=per_device_train_batch_size,
107 | num_train_epochs=num_train_epochs,
108 | save_steps=save_steps,
109 | learning_rate=lr,
110 | )
111 |
112 | trainer = Trainer(
113 | model=self.model,
114 | args=training_args,
115 | data_collator=data_collator,
116 | train_dataset=train_dataset,
117 | )
118 |
119 | trainer.train()
120 |
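A minimal usage sketch, assuming an HF causal LM and tokenizer (the model name is illustrative). Training only the last `last_k` `c_proj` matrices keeps fast tuning cheap; `last_k='all'` falls back to full fine-tuning.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
tokenizer = AutoTokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

trainer = GPTProjectionTrainer(model, tokenizer)
# Tune only the four deepest projection matrices on a small corpus.
trainer.train(['текст для дообучения', 'ещё один текст'],
              output_dir='new_gpt', last_k=4, num_train_epochs=1)
```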
--------------------------------------------------------------------------------
/explainitall/gpt_like_interp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/explainitall/gpt_like_interp/__init__.py
--------------------------------------------------------------------------------
/explainitall/gpt_like_interp/downloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import shutil
4 | import zipfile
5 |
6 | import requests
7 | from tqdm import tqdm
8 |
9 |
10 | class DownloadManager:
11 | base_directory = os.path.join(os.path.expanduser("~"), ".cache")
12 |
13 | def _create_directory(self, directory_name):
14 | directory_path = os.path.join(self.base_directory, directory_name)
15 | os.makedirs(directory_path, exist_ok=True)
16 | return directory_path
17 |
18 | @staticmethod
19 | def _clean_string(text):
20 | return re.sub(r'[^a-zA-Z0-9]', '_', text)
21 |
22 | @staticmethod
23 | def _delete_existing_file(file_path):
24 | if os.path.exists(file_path):
25 | os.unlink(file_path)
26 |
27 | @staticmethod
28 | def _download_file(url: str, filepath: str, verbose: bool = True):
29 | if os.path.exists(filepath):
30 | return
31 |
32 | response = requests.get(url, stream=True)
33 | total_size = int(response.headers.get('content-length', 0))
34 | with open(filepath, 'wb') as file:
35 | if verbose:
36 | with tqdm(
37 | desc=f"Downloading: {filepath}",
38 | total=total_size,
39 | unit='iB',
40 | unit_scale=True,
41 | unit_divisor=1024,
42 | ) as progress_bar:
43 | for data in response.iter_content(chunk_size=1024):
44 | written_size = file.write(data)
45 | progress_bar.update(written_size)
46 | else:
47 | for data in response.iter_content(chunk_size=1024):
48 | file.write(data)
49 |
50 | @staticmethod
51 | def _delete_existing_folder(path: str):
52 | if os.path.exists(path):
53 | shutil.rmtree(path)
54 |
55 | @staticmethod
56 | def _extract_zip_file(file_path, destination_path):
57 | if not os.path.exists(destination_path):
58 |
59 | with zipfile.ZipFile(file_path, "r") as zip_file:
60 | for file in tqdm(iterable=zip_file.namelist(),
61 | total=len(zip_file.namelist()),
62 | desc=f"Extracting: {destination_path}"):
63 | zip_file.extract(member=file, path=destination_path)
64 |
65 | @classmethod
66 | def load_zip(cls, url, remove_existing=False, model_file_name='model.bin', verbose=True):
67 | zip_filename = cls._clean_string(url.split("/")[-1])
68 | zip_file_path = os.path.join(cls.base_directory, zip_filename)
69 |
70 | if remove_existing:
71 | cls._delete_existing_file(zip_file_path)
72 |
73 | cls._download_file(url, zip_file_path, verbose)
74 |
75 | extracted_data_directory_path = os.path.join(cls.base_directory, cls._clean_string(zip_filename) + "_data")
76 | if remove_existing:
77 | cls._delete_existing_folder(extracted_data_directory_path)
78 |
79 | cls._extract_zip_file(zip_file_path, extracted_data_directory_path)
80 |
81 | return os.path.join(extracted_data_directory_path, model_file_name)
82 |
83 |
84 | if __name__ == "__main__":
85 | download_path = DownloadManager.load_zip(
86 | 'http://vectors.nlpl.eu/repository/20/180.zip',
87 | remove_existing=True)
88 | print("Download 1 path:", download_path)
89 |
90 | download_path = DownloadManager.load_zip(
91 | 'http://vectors.nlpl.eu/repository/20/180.zip',
92 | remove_existing=False)
93 | print("Download 2 path:", download_path)
94 |
95 | download_path = DownloadManager.load_zip(
96 | 'http://vectors.nlpl.eu/repository/20/180.zip',
97 | remove_existing=True, verbose=False)
98 |     print("Download 3 path:", download_path)
99 |
--------------------------------------------------------------------------------
/explainitall/gpt_like_interp/inseq_helpers.py:
--------------------------------------------------------------------------------
1 | import re
2 | from copy import deepcopy
3 | from typing import Tuple, Union, Optional
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from inseq import FeatureAttributionOutput
9 | from inseq.utils.typing import GranularSequenceAttributionTensor as Gast
10 | from inseq.utils.typing import TokenSequenceAttributionTensor as Tsat
11 | from torch.linalg import vector_norm
12 |
13 | from explainitall import stat_helpers
14 |
15 |
16 | def sum_normalize_attributions(
17 | attributions: Union[Gast, Tuple[Gast, Gast]],
18 | cat_dim: int = 0,
19 | norm_dim: Optional[int] = 0,
20 | ) -> Tsat:
21 |     """
22 |     Sum and normalize attribution tensors along cat_dim / norm_dim.
23 |     Result: a matrix of row vectors.
24 |     """
25 | concat = False
26 | if isinstance(attributions, tuple):
27 | concat = True
28 | orig_sizes = [a.shape[cat_dim] for a in attributions]
29 | attributions = torch.cat(attributions, dim=cat_dim)
30 | else:
31 | orig_sizes = [attributions.shape[cat_dim]]
32 | attributions = vector_norm(attributions, ord=2, dim=-1)
33 | if norm_dim is not None:
34 | attributions = attributions / attributions.nansum(dim=norm_dim, keepdim=True)
35 | if len(attributions.shape) == 1:
36 | attributions = attributions.unsqueeze(0)
37 | if concat:
38 | attributions = attributions.split(orig_sizes, dim=cat_dim)
39 | return attributions[0], attributions[1]
40 | return attributions
41 |
42 |
43 | def fix_ig_tokens(feature_attr: FeatureAttributionOutput):
44 | from transformers import GPT2Tokenizer
45 |
46 | feature_attr_conv = deepcopy(feature_attr)
47 | dt = GPT2Tokenizer.from_pretrained(feature_attr_conv.info['model_name'])
48 |
49 | for attr in feature_attr_conv.sequence_attributions:
50 | for token_holder in (attr.source, attr.target):
51 | for s in token_holder:
52 | try:
53 | s.token = dt.convert_tokens_to_string(s.token)
54 | except KeyError:
55 | pass
56 | return feature_attr_conv
57 |
58 |
59 | def get_ig_tokens(feature_attr: FeatureAttributionOutput):
60 | rez = []
61 | for attr in feature_attr.sequence_attributions:
62 | rez.append((tuple(s.token for s in attr.source),
63 | tuple(t.token for t in attr.target)))
64 | return rez
65 |
66 |
67 | def get_ig_phrases(feature_attr: FeatureAttributionOutput):
68 | return tuple(zip(feature_attr.info['input_texts'],
69 | feature_attr.info['generated_texts']))
70 |
71 |
72 | def get_g_arrays(feature_attr: FeatureAttributionOutput):
73 | target_arrays = []
74 | for attr in feature_attr.sequence_attributions:
75 | ta = attr.target_attributions
76 | ta2 = sum_normalize_attributions(ta)
77 | ta2 = np.array(ta2, dtype=float)
78 |
79 | target_arrays.append(np.array(ta2, dtype=float))
80 | return target_arrays
81 |
82 |
83 | class AttrObj:
84 | def __init__(self,
85 | phrase_input: str,
86 | phrase_generated: str,
87 | tokens_input: Tuple[str, ...],
88 | tokens_generated: Tuple[str, ...],
89 | array: np.ndarray):
90 | self.phrase_input = phrase_input
91 | self.phrase_generated = phrase_generated
92 | self.tokens_input = tokens_input
93 | self.tokens_generated = tokens_generated
94 | self.array = array
95 |
96 | def __repr__(self):
97 | return (f"AttrObj({self.phrase_input=}, "
98 | f"{self.phrase_generated=}, "
99 | f"{self.tokens_input=}, "
100 | f"{self.tokens_generated=}, "
101 | f"{self.array.shape=})")
102 |
103 |
104 | def get_first_attribute(feature_attr: FeatureAttributionOutput):
105 | fixed_attr = fix_ig_tokens(feature_attr)
106 | phrase_input, phrase_generated_full = get_ig_phrases(fixed_attr)[0]
107 | phrase_generated = phrase_generated_full[len(phrase_input):]
108 | tokens_input, tokens_generated_full = get_ig_tokens(fixed_attr)[0]
109 | tokens_generated = tokens_generated_full[len(tokens_input):]
110 | array = get_g_arrays(fixed_attr)[0]
111 |
112 | return AttrObj(phrase_input=phrase_input,
113 | phrase_generated=phrase_generated,
114 | tokens_input=tokens_input,
115 | tokens_generated=tokens_generated,
116 | array=array)
117 |
118 |
119 | def attr_to_df(attr: AttrObj):
120 |     """Convert an attribution object into a DataFrame"""
121 | df = pd.DataFrame(attr.array)
122 | df.columns = attr.tokens_generated
123 | df = df.sort_index()
124 | df.insert(0, 'Tokens', attr.tokens_input + attr.tokens_generated)
125 | return df
126 |
127 |
128 | def squash_arr(arr, squash_row_mask, squash_col_mask, aggr_f=np.max):
129 | # Apply the mask to the rows
130 | row_result = np.array([aggr_f(arr[start:end], axis=0)
131 | for start, end in squash_row_mask])
132 | # Apply the mask to the columns
133 | col_result = np.array([aggr_f(row_result[:, start:end], axis=1)
134 | for start, end in squash_col_mask]).T
135 | return col_result
136 |
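`squash_arr` collapses token-level attribution cells into word-level cells: each `(start, end)` pair selects a run of rows (then columns) that is aggregated with `aggr_f`. A toy run with illustrative masks:

```python
import numpy as np

arr = np.arange(12, dtype=float).reshape(4, 3)
row_mask = [(0, 2), (2, 4)]  # merge rows 0-1 and rows 2-3
col_mask = [(0, 1), (1, 3)]  # keep column 0, merge columns 1-2

print(squash_arr(arr, row_mask, col_mask))  # default aggr_f=np.max -> [[3. 5.] [9. 11.]]
```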
137 |
138 | class Detokenizer:
139 |     """
140 |     Detokenizer: groups sub-word tokens back into whole words.
141 |     """
142 |     # dash characters
143 |     dash_chars = list(map(chr, (45, 8211, 8212, 8722, 9472, 9473, 9476)))
144 |     # regex for a dash between letters
145 |     dash_regex = re.compile(r'(?
208 | def group_by(attr: AttrObj, gmm_norm=False) -> AttrObj:
209 | tokens_input_grouped = Detokenizer(attr.phrase_input, attr.tokens_input).group_text()
210 | tokens_generated_grouped = Detokenizer(attr.phrase_generated, attr.tokens_generated).group_text()
211 | tokens_input_generated_mask = calculate_mask(tokens_input_grouped + tokens_generated_grouped)
212 | tokens_generated_mask = calculate_mask(tokens_generated_grouped)
213 |
214 | squashed_array = squash_arr(attr.array,
215 | squash_col_mask=tokens_generated_mask,
216 | squash_row_mask=tokens_input_generated_mask)
217 |
218 | tokens_input_grouped_flatten = tuple(["".join(x) for x in tokens_input_grouped])
219 | tokens_generated_grouped_flatten = tuple("".join(x) for x in tokens_generated_grouped)
220 |
221 | if gmm_norm:
222 | squashed_array = stat_helpers.calc_gmm_stat_params(squashed_array)['new_arr']
223 |
224 | return AttrObj(phrase_input=attr.phrase_input,
225 | phrase_generated=attr.phrase_generated,
226 | tokens_input=tokens_input_grouped_flatten,
227 | tokens_generated=tokens_generated_grouped_flatten,
228 | array=squashed_array)
229 |
--------------------------------------------------------------------------------
/explainitall/gpt_like_interp/interp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | import gensim
4 | import pandas as pd
5 | from inseq import AttributionModel
6 |
7 | from explainitall import clusters
8 | from explainitall.clusters import ClusterManager, aggregate_cluster_df
9 | from explainitall.gpt_like_interp import viz
10 | from . import inseq_helpers
11 | from .inseq_helpers import AttrObj
12 |
13 |
14 | class ExplainerGPT2Output:
15 | def __init__(self, attributions: AttrObj,
16 | attributions_grouped: AttrObj,
17 | attributions_grouped_norm: AttrObj,
18 |
19 | cluster_imp_df: pd.DataFrame,
20 | cluster_imp_aggr_df: pd.DataFrame,
21 | word_imp_df: pd.DataFrame,
22 | word_imp_norm_df: pd.DataFrame):
23 | self.attributions = attributions
24 | self.attributions_grouped = attributions_grouped
25 | self.attributions_grouped_norm = attributions_grouped_norm
26 |
27 | self.cluster_imp_df = cluster_imp_df
28 | self.cluster_imp_aggr_df = cluster_imp_aggr_df
29 | self.word_imp_df = word_imp_df
30 | self.word_imp_norm_df = word_imp_norm_df
31 |
32 | def show_word_imp_heatmap(self):
33 | viz.df_to_heatmap(self.word_imp_df, title="Карта важности слов")
34 |
35 | def show_word_imp_norm_heatmap(self):
36 | viz.df_to_heatmap(self.word_imp_norm_df, title="Карта важности слов, нормированная")
37 |
38 | def show_cluster_imp_heatmap(self):
39 |         viz.df_to_heatmap(self.cluster_imp_df, title="Карта важности кластеров")
40 |
41 | def show_cluster_imp_aggr_heatmap(self):
42 |         viz.df_to_heatmap(self.cluster_imp_aggr_df, title="Карта важности кластеров, группированная")
43 |
44 |
45 | class ExplainerGPT2:
46 | def __init__(self, gpt_model: AttributionModel,
47 | nlp_model: gensim.models.keyedvectors.KeyedVectors):
48 | self.gpt_model = gpt_model
49 | self.nlp_model = nlp_model
50 | self._cluster_manager = ClusterManager(embeddings=self.nlp_model)
51 | self.attributions = None
52 |
53 | def interpret(self,
54 | input_texts: str,
55 | generated_texts: str,
56 | clusters_description: List[Dict[str, object]],
57 | batch_size: int = 50,
58 | steps: int = 34,
59 | max_new_tokens: int = None,
60 | aggr_f='mean') -> ExplainerGPT2Output:
61 | self._attribute(input_texts, generated_texts, max_new_tokens, steps, batch_size)
62 | return self._run_pipeline(clusters_description, aggr_f)
63 |
64 | @staticmethod
65 |     def calc_max_tokens(input_texts, generated_texts):
66 | from gensim.utils import tokenize
67 | tokens = list(tokenize(input_texts + generated_texts, lowercase=True))
68 | num_tokens = len(tokens)
69 | return num_tokens + 10 # buffer
70 |
71 | def _attribute(self, input_texts: str,
72 | generated_texts: str,
73 | max_new_tokens: int,
74 | steps: int, batch_size: int):
75 |
76 | generation_args = None
77 | if not generated_texts:
78 | if max_new_tokens is None:
79 |                 max_new_tokens = self.calc_max_tokens(input_texts, generated_texts)
80 | generation_args = {"max_new_tokens": max_new_tokens}
81 |
82 | out = self.gpt_model.attribute(
83 | input_texts=input_texts, generated_texts=input_texts + generated_texts,
84 | n_steps=steps, generation_args=generation_args,
85 | show_progress=True, pretty_progress=False, internal_batch_size=batch_size
86 | )
87 | self.attributions = inseq_helpers.get_first_attribute(out)
88 |
89 | def _run_pipeline(self, cluster_desc, aggr_f):
90 | group_attr = inseq_helpers.group_by(self.attributions)
91 | norm_attr = inseq_helpers.group_by(self.attributions, gmm_norm=True)
92 |
93 | word_imp_df = inseq_helpers.attr_to_df(group_attr)
94 | word_imp_norm_df = inseq_helpers.attr_to_df(norm_attr)
95 |
96 | cluster_interpreter = clusters.ClusterInterpreter(clusters_discr=cluster_desc,
97 | cluster_manager=self._cluster_manager)
98 |
99 | cluster_imp_df = cluster_interpreter.get_cluster_importance_df(self.attributions)
100 | try:
101 | cluster_imp_aggr_df = aggregate_cluster_df(cluster_imp_df, aggr_f=aggr_f)
102 | except Exception as e:
103 | print(f"Неверно заданы выходные кластеры: {e}")
104 | cluster_imp_aggr_df = pd.DataFrame()
105 |
106 | return ExplainerGPT2Output(
107 | attributions=self.attributions, attributions_grouped=group_attr, attributions_grouped_norm=norm_attr,
108 | cluster_imp_df=cluster_imp_df, cluster_imp_aggr_df=cluster_imp_aggr_df,
109 | word_imp_df=word_imp_df, word_imp_norm_df=word_imp_norm_df)
110 |
--------------------------------------------------------------------------------
/explainitall/gpt_like_interp/viz.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import seaborn as sns
6 | from statsmodels.graphics.gofplots import qqplot
7 |
8 | from explainitall import stat_helpers
9 |
10 |
11 | def show_distribution_histogram(data):
12 | """
13 |     Plot a cumulative distribution histogram with Gaussian and Rayleigh fits
14 | """
15 | stat_params = stat_helpers.calc_gmm_stat_params(data)
16 | data = data[~np.isnan(data)]
17 | data.sort()
18 |
19 | gaussian_prob = stat_helpers.compute_gaussian_integral(data, stat_params["mean"], stat_params["std"])
20 | rayleigh_prob = stat_helpers.rayleigh_integral(data, data.var())
21 |
22 | plt.hist(data, bins=15, cumulative=True, density=True, label='Распределение реальных данных')
23 | plt.plot(data, gaussian_prob, color='red', label='Функция Гаусса')
24 | plt.plot(data, rayleigh_prob, label='Функция Рэлея')
25 |
26 | plt.legend(title='Функции распределения')
27 |
28 | ax = plt.gca()
29 | ax.set_xlabel("Значения", fontsize=10, color='black', labelpad=10)
30 |     ax.set_ylabel("Распределение P(X<x)", fontsize=10, color='black', labelpad=10)
--------------------------------------------------------------------------------
/explainitall/gui/interface.py:
--------------------------------------------------------------------------------
126 | "Контекст и сгенерированный текст - это входные данные для анализа. ")
127 | context_text = gr.Text(label='Context',
128 | info="**Контекст** это ваш исходный текст, который служит отправной точкой для генерации ответа "
129 |                                        "пример: 'у кошки грипп и аллергия на антибиотики вопрос: чем лечить кошку? ответ:'",
130 | lines=1, placeholder="Enter context here...")
131 |
132 | output_text = gr.Text(label='Generated text',
133 | info="**Сгенерированный текст** - это результат работы системы, ответ на ваш контекст "
134 | " 'лечите ее уколами'",
135 | lines=1, placeholder="Enter generated text here...")
136 |
137 | with gr.Row():
138 | texts_load_button = gr.Button("Set texts")
139 | texts_set_checkbox = gr.Checkbox(label='Texts are set', interactive=False,
140 | info="Если тексты установлены, можно переходить к шагу Clusters.")
141 |
142 | with gr.TabItem("Clusters"):
143 | with gr.Tabs():
144 | with gr.TabItem("Load clusters from file"):
145 | gr.Markdown("""
146 | Загрузите файл с описанием кластеров. Файл должен быть в формате JSON и содержать список кластеров, где каждый кластер описывается следующими ключами:
147 | - `name`: Название кластера (строка).
148 | - `centroid`: Центроид кластера, представленный списком строк (например, список терминов, характеризующих кластер).
149 | - `top_k`: Количество топовых элементов кластера (целое число).
150 |
151 |
152 | пример файла: example_data/clusters.json
153 |
154 | Пример структуры файла:
155 | ```json
156 | [
157 | {
158 | "name": "Кластер 1",
159 | "centroid": ["термин1", "термин2", "термин3"],
160 | "top_k": 5
161 | },
162 | {
163 | "name": "Кластер 2",
164 | "centroid": ["термин4", "термин5", "термин6"],
165 | "top_k": 3
166 | }
167 | ]
168 | ```
169 |
170 | Этот файл используется для анализа и визуализации влияния различных групп слов (кластеров) на сгенерированный текст. Загрузка подходящего файла позволит провести более глубокий анализ и понять, какие тематические группы наиболее важны в контексте генерации текста.
171 | """)
172 | with gr.Column():
173 | clusters_file = gr.File(label='Clusters\' file')
174 | with gr.Row():
175 | clusters_load_from_file_button = gr.Button("Load clusters")
176 | with gr.TabItem("Set clusters manually"):
177 | with gr.Column():
178 | set_clusters_table = gr.Dataframe(label='Clusters\' table',
179 | headers=['name', 'centroid', 'top_k'],
180 | # max_rows=None,
181 | height=None,
182 | # overflow_row_behaviour='paginate',
183 | wrap=False,
184 | interactive=True)
185 | with gr.Row():
186 | set_clusters_from_dataframe_button = gr.Button("Set clusters")
187 |
188 | clusters_set_checkbox = gr.Checkbox(
189 | label='Clusters are set Если кластеры и NLP модель установлены, можно переходить к шагу LLM model',
190 | interactive=False)
191 |
192 | cluster_table = gr.Dataframe(label='Clusters',
193 | headers=['name', 'centroid', 'top_k'],
194 | # max_rows=None,
195 | height=None,
196 | # overflow_row_behaviour='paginate',
197 | wrap=False,
198 | interactive=False)
199 |
200 | clusters_file_path_text = gr.Text(label='Save clusters to file')
201 | with gr.Row():
202 | clusters_save_button = gr.Button("Save clusters")
203 | clusters_save_checkbox = gr.Checkbox(label='Clusters are saved', interactive=False)
204 |
205 | nlp_model_url = gr.Text(label='Model url',
206 | info="Введите URL для загрузки предварительно обученной NLP модели. Модель должна быть в формате, совместимом с библиотекой gensim, например,"
207 | " Word2Vec, FastText или любой другой векторной модели слов. например http://vectors.nlpl.eu/repository/20/180.zip",
208 | placeholder="http://vectors.nlpl.eu/repository/20/180.zip",
209 | lines=1)
210 | with gr.Row():
211 | load_nlp_model_button = gr.Button("Load model")
212 | nlp_model_set_checkbox = gr.Checkbox(label='NLP model loaded', interactive=False)
213 | with gr.TabItem("LLM interpretation model"):
214 | nn_model_name_or_path = gr.Text(label='Model name or path',
215 | info="Введите название модели или путь к ней для использования в качестве модели интерпретации."
216 | " Это должна быть модель на основе GPT или другой современной трансформерной модели, например sberbank-ai/rugpt3small_based_on_gpt2",
217 | placeholder="sberbank-ai/rugpt3small_based_on_gpt2",
218 | lines=1)
219 | with gr.Row():
220 | load_nn_model_button = gr.Button("Load model")
221 | nn_model_set_checkbox = gr.Checkbox(
222 | label='NN model loaded, если модель загружена можно переходить к шагу Results',
223 | interactive=False)
224 | with gr.TabItem("Results"):
225 | with gr.Row():
226 | with gr.Column():
227 | result_texts_set_checkbox = gr.Checkbox(label='Texts are set', interactive=False)
228 | result_clusters_set_checkbox = gr.Checkbox(label='Clusters are set', interactive=False)
229 | result_nlp_model_set_checkbox = gr.Checkbox(label='NLP model loaded', interactive=False)
230 | result_nn_model_set_checkbox = gr.Checkbox(label='NN model loaded', interactive=False)
231 | launch_button = gr.Button("Launch")
232 | with gr.Column():
233 | with gr.Tabs():
234 | with gr.TabItem("Word importance"):
235 | # word_importance_image = gr.Image().style(height=600)
236 | word_importance_image = gr.Image(height=600)
237 | with gr.TabItem("Word importance normalized"):
238 | # word_importance_norm_image = gr.Image().style(height=600)
239 | word_importance_norm_image = gr.Image(height=600)
240 | with gr.TabItem("Cluster importance"):
241 | # cluster_importance_image = gr.Image().style(height=600)
242 | cluster_importance_image = gr.Image(height=600)
243 | with gr.TabItem("Cluster importance grouped"):
244 | # cluster_importance_norm_image = gr.Image().style(height=600)
245 | cluster_importance_norm_image = gr.Image(height=600)
246 | with gr.TabItem("Chatbot"):
247 | chat = gr.Chatbot(label="Chatbot", layout="bubble")
248 | msg = gr.Text(label="Message", interactive=True)
249 | send_message = gr.Button("Send", interactive=True)
250 | clear = gr.ClearButton([msg, chat])
251 | # chat.change(fn=chatbot_function, inputs=msg, outputs=chat)
252 |
253 | texts_load_button.click(self.load_context_and_generated_text_,
254 | inputs=[context_text, output_text],
255 | outputs=[texts_set_checkbox, result_texts_set_checkbox])
256 |
257 | clusters_load_from_file_button.click(self.load_clusters_from_file_,
258 | inputs=[clusters_file],
259 | outputs=[clusters_set_checkbox, result_clusters_set_checkbox,
260 | cluster_table, set_clusters_table])
261 | set_clusters_from_dataframe_button.click(self.set_clusters_from_dataframe_,
262 | inputs=[set_clusters_table],
263 | outputs=[clusters_set_checkbox, result_clusters_set_checkbox,
264 | cluster_table])
265 | clusters_save_button.click(self.save_new_clusters_to_file_,
266 | inputs=[cluster_table, clusters_file_path_text],
267 | outputs=[clusters_save_checkbox])
268 |
269 | load_nn_model_button.click(self.load_nn_model_,
270 | inputs=[nn_model_name_or_path],
271 | outputs=[nn_model_set_checkbox, result_nn_model_set_checkbox])
272 |
273 | load_nlp_model_button.click(self.load_nlp_model_,
274 | inputs=[nlp_model_url],
275 | outputs=[nlp_model_set_checkbox, result_nlp_model_set_checkbox])
276 |
277 | launch_button.click(self.show_results, inputs=[], outputs=[word_importance_image,
278 | word_importance_norm_image,
279 | cluster_importance_image,
280 | cluster_importance_norm_image])
281 |
282 | send_message.click(self.respond_, inputs=[msg, chat], outputs=[msg, chat])
283 |
284 | def __del__(self):
285 | pass
286 |
287 | # FUNCTIONALITY:
288 |
289 | def launch(self):
290 | self.demo_.launch(share=True, debug=False, server_name="127.0.0.1", inbrowser=True)
291 |
292 | def show_results(self):
293 | self.explainer_ = interp.ExplainerGPT2(gpt_model=self.nn_model_, nlp_model=self.nlp_model_)
294 | expl_data = self.explainer_.interpret(input_texts=self.context_,
295 | generated_texts=self.generated_text_,
296 | clusters_description=self.clusters_,
297 | batch_size=50,
298 | steps=34,
299 | # max_new_tokens=19
300 | )
301 |
302 | # Результат интерпретации
303 | imp_df_cl = expl_data.cluster_imp_aggr_df
304 | cl_desc = interp_cl(imp_df_cl)
305 |
306 | clean = [clean_string(cl_data) for cl_data in cl_desc]
307 | vects_x = self.sbert_.encode(clean)
308 | m = vects_x.mean(axis=0)
309 | s = vects_x.std(axis=0)
310 | try:
311 | knn_vects_x = (vects_x - m) / s
312 | knn = KNeighborsClassifier(metric=cos_dist)
313 | knn.fit(knn_vects_x, cl_desc)
314 |
315 | self.interp_bot_ = PromptBot(knn, self.sbert_, self.fred_, cl_desc, device='cpu')
316 |         except Exception as e:
317 |             print(f"Failed to build the interpretation bot: {e}")
318 | self.interp_bot_ = None
319 | word_importance_plt = df_to_heatmap_plot(expl_data.word_imp_df, title="Карта важности слов")
320 | word_importance_norm_plt = df_to_heatmap_plot(expl_data.word_imp_norm_df,
321 | title="Карта важности слов, нормированная")
322 |
323 | cluster_importance_plt = df_to_heatmap_plot(expl_data.cluster_imp_df, title="Карта важности кластеров")
324 | cluster_importance_norm_plt = df_to_heatmap_plot(expl_data.cluster_imp_aggr_df,
325 | title="Карта важности кластеров, группированная")
326 |
327 | return word_importance_plt, word_importance_norm_plt, cluster_importance_plt, cluster_importance_norm_plt
328 |
329 | # PRIVATE FUNCTIONS:
330 |
331 | def respond_(self, message, chat_history):
332 | ans = self.interp_bot_.get_answers(message, top_k=3)
333 | bot_reply = 'Кластер' + ans.split('.')[0].split('Кластер')[1]
334 | chat_history.append((message, bot_reply))
335 |
336 | return "", chat_history
337 |
338 | def load_context_and_generated_text_(self, context, generated_text):
339 | self.context_ = context
340 | self.generated_text_ = generated_text
341 |
342 | return True, True
343 |
344 | def load_clusters_from_file_(self, jsonfile: tempfile._TemporaryFileWrapper):
345 | with open(jsonfile.name, 'r') as fp:
346 | self.clusters_ = json.load(fp)
347 | df = make_dataframe_from_clusters(self.clusters_)
348 |
349 | return True, True, df, df
350 |
351 | def set_clusters_from_dataframe_(self, df):
352 | self.clusters_ = make_clusters_from_dataframe(df)
353 |
354 | return True, True, df
355 |
356 | def save_new_clusters_to_file_(self, df, filename):
357 | clusters = make_clusters_from_dataframe(df)
358 | with open(filename, 'w') as fp:
359 | json.dump(clusters, fp)
360 |
361 | return True
362 |
363 | def load_nlp_model_(self, url):
364 |         self.nlp_model_url_ = url
365 | nlp_model_path = DownloadManager.load_zip(url)
366 | self.nlp_model_ = gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)
367 |
368 | return True, True
369 |
370 | def load_nn_model_(self, model_name_or_path):
371 | path = os.path.normpath(model_name_or_path)
372 | path_list = path.split(os.sep)
373 | self.nn_model_name_ = path_list[-1]
374 |
375 | self.nn_model_ = load_model(model=model_name_or_path,
376 | attribution_method="integrated_gradients")
377 |
378 | return True, True
379 |
--------------------------------------------------------------------------------
/explainitall/metrics/CheckingForHallucinations.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import numpy as np
3 | from nltk.tokenize import sent_tokenize
4 | from sentence_transformers import SentenceTransformer, util
5 |
6 |
7 | def sim_cosine(embedding_1, embedding_2):
8 | s = util.pytorch_cos_sim(embedding_1, embedding_2)
9 | return s.detach().numpy()
10 |
11 | def sim_euclidean(embedding_1, embedding_2):
12 |     s = np.linalg.norm(embedding_1 - embedding_2, axis=1)  # NB: a distance, lower = more similar
13 | return s
14 |
15 | def sim_manhattan(embedding_1, embedding_2):
16 |     s = np.sum(np.abs(embedding_1 - embedding_2), axis=1)  # NB: a distance, lower = more similar
17 | return s
18 |
19 | def sim_cross_encoder(encoder, sentences1, sentences2):
20 | scores = encoder.predict([sentences1, sentences2])
21 | return scores
22 |
23 | class RAGHallucinationsChecker:
24 |     def __init__(self, sbert_model: SentenceTransformer, cross_encoder=None, language='russian'):
25 | nltk.download('punkt')
26 | self.sbert_model = sbert_model
27 | self.cross_encoder = cross_encoder
28 | self.language = language
29 |
30 | def load_doc(self, text, block_size=3):
31 | seqs = sent_tokenize(text, language=self.language)
32 | count_block = max([1, len(seqs) // block_size])
33 | seqs = [list(x) for x in np.array_split(seqs, count_block)]
34 | snippets = [' '.join(x) for x in seqs]
35 |
36 | return snippets
37 |
38 | def get_support_seq(self, doc_snp, ans, prob=0.6, top_k=1, sim_metric='cosine'):
39 |
40 | docs = []
41 | sn_ans = self.load_doc(ans, 1)
42 |
43 | for d in doc_snp:
44 | docs += self.load_doc(d, 1)
45 |
46 | top_k = min(top_k, len(docs))
47 |
48 | ans_v = self.sbert_model.encode(sn_ans)
49 | doc_v = self.sbert_model.encode(docs)
50 |
51 | if sim_metric == 'cosine':
52 | matrix_ = sim_cosine(doc_v, ans_v)
53 | elif sim_metric == 'euclidean':
54 | matrix_ = sim_euclidean(doc_v, ans_v)
55 | elif sim_metric == 'manhattan':
56 | matrix_ = sim_manhattan(doc_v, ans_v)
57 | elif sim_metric == 'cross_encoder' and self.cross_encoder is not None:
58 | matrix_ = sim_cross_encoder(self.cross_encoder, docs, sn_ans)
59 | else:
60 | raise ValueError("Unsupported similarity metric or cross-encoder is not provided.")
61 |
62 | res = []
63 | for i in range(matrix_.shape[1]):
64 | slice_ = matrix_[:, i]
65 | top_indexes = np.argpartition(slice_, -top_k)[-top_k:]
66 | top_probs = matrix_[top_indexes, i]
67 | top_indexes[top_probs < prob] = -1
68 |
69 | reference_texts = []
70 | indexes = []
71 | for j, d in enumerate(top_indexes):
72 | if d >= 0:
73 | reference_texts += [docs[d]]
74 | indexes += [d]
75 |
76 | if len(indexes) > 0:
77 | res.append({'answer': sn_ans[i], 'reference_texts': reference_texts, 'indexes': indexes})
78 |
79 | return res
80 |
81 | def get_conf(self, doc_snp, ans, prob=0.6, sim_metric='cosine'):
82 | answer_a = self.get_support_seq(ans=ans, doc_snp=doc_snp, prob=prob, sim_metric=sim_metric)
83 | len_all = len(ans)
84 |         len_with_out_h = len(answer_a) - 1  # account for the spaces joining supported snippets
85 |
86 | for s in answer_a:
87 | len_with_out_h += len(s['answer'])
88 |
89 | return len_with_out_h / len_all
90 |
91 | def get_hallucinations_prob(self, doc_snp, ans, prob=0.6, sim_metric='cosine'):
92 | return 1 - self.get_conf(doc_snp, ans, prob, sim_metric)
93 |
94 |
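A minimal usage sketch, assuming a multilingual SBERT checkpoint (the model name and texts are illustrative):

```python
from sentence_transformers import SentenceTransformer

sbert = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
checker = RAGHallucinationsChecker(sbert)

docs = ['Кошкам с гриппом назначают поддерживающую терапию и покой.']
answer = 'Лечите кошку покоем и поддерживающей терапией.'

# Share of the answer that is NOT supported by the documents.
print(checker.get_hallucinations_prob(docs, answer, prob=0.6))
```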
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/Metrics.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from transformers import PreTrainedModel, PreTrainedTokenizer
3 | from typing import List, Dict
4 |
5 | import torch
6 |
7 | from explainitall.metrics.RougeAndPPL.helpers import get_all_words_from_text
8 | from explainitall.metrics.RougeAndPPL.rouge_L import rouge_L
9 | from explainitall.metrics.RougeAndPPL.rouge_N import rouge_N
10 |
11 |
12 | class Metric(ABC):
13 | @staticmethod
14 | def preprocess(contexts, references, candidates, tokenizer):
15 | raise NotImplementedError
16 |
17 | @abstractmethod
18 | def calculate(self, references_encodings, candidates_encodings):
19 | raise NotImplementedError
20 |
21 |
22 | class Metric_rouge(Metric):
23 |
24 | def __init__(self, n):
25 | self.n = n
26 |
27 | @staticmethod
28 | def preprocess(contexts, references, candidates, tokenizer):
29 | res = {'references': [], 'candidates': []}
30 | for i, context in enumerate(contexts):
31 | res['references'].append(get_all_words_from_text(references[i][len(contexts[i]):]))
32 | res['candidates'].append(get_all_words_from_text(candidates[i][len(contexts[i]):]))
33 | return res
34 |
35 |
36 | class MetricRougeL(Metric_rouge):
37 |
38 | def calculate(self, references_encodings, candidates_encodings):
39 | res = [rouge_L(reference, candidate) for reference, candidate in
40 | zip(references_encodings, candidates_encodings)]
41 | return res
42 |
43 |
44 | class MetricRougeN(Metric_rouge):
45 |
46 | def calculate(self, references_encodings, candidates_encodings):
47 | res = [rouge_N(reference, candidate, self.n) for reference, candidate in
48 | zip(references_encodings, candidates_encodings)]
49 | return res
50 |
51 |
52 | class MetricStandard(Metric):
53 |
54 | def __init__(self, metric_function):
55 | self.metric_function = metric_function
56 |
57 | def calculate(self, references_encodings, candidates_encodings):
58 | res = [self.metric_function(reference, candidate) for reference, candidate in
59 | zip(references_encodings, candidates_encodings)]
60 | return res
61 |
62 |
63 | class Metric_ppl(Metric):
64 |
65 | def __init__(self, model: PreTrainedModel, stride: int):
66 |
67 | self.model = model
68 | self.stride = stride
69 |
70 | @staticmethod
71 | def preprocess(contexts: List[str], references: List[str], candidates: List[str], tokenizer: PreTrainedTokenizer) -> Dict[str, List[List[int]]]:
72 | """
73 |         Preprocess the data
74 | """
75 | tokenized_data = {'references': [], 'candidates': []}
76 |
77 | for context, reference in zip(contexts, references):
78 | encoded_reference = tokenizer(reference)
79 | tokenized_data['references'].append(encoded_reference.input_ids)
80 |
81 | return tokenized_data
82 |
83 | def calculate(self, reference_encodings: List[List[int]], candidate_encodings: List[List[int]]) -> List[Dict[str, float]]:
84 | """
85 |         Compute PPL for a list of tokenized texts
86 | """
87 | perplexities = [self._calculate_perplexity(encodings) for encodings in reference_encodings]
88 | return perplexities
89 |
90 | def _calculate_perplexity(self, encodings: List[int]) -> Dict[str, float]:
91 | """
92 |         Compute perplexity for a single tokenized text
93 | """
94 | max_length = self.model.config.n_positions
95 | sequence_length = len(encodings)
96 |
97 | neg_log_likelihoods = []
98 | prev_end_loc = 0
99 |
100 | for start_loc in range(0, sequence_length, self.stride):
101 | end_loc = min(start_loc + max_length, sequence_length)
102 | target_length = end_loc - prev_end_loc
103 |
104 | input_ids = torch.tensor(encodings[start_loc:end_loc], device=self.model.device).unsqueeze(0)
105 | target_ids = input_ids.clone()
106 |         target_ids[:, :-target_length] = -100  # HF ignore index, so overlap tokens are not scored twice
107 |
108 | with torch.no_grad():
109 | outputs = self.model(input_ids, labels=target_ids)
110 | neg_log_likelihood = outputs.loss * target_length
111 |
112 | neg_log_likelihoods.append(neg_log_likelihood)
113 | prev_end_loc = end_loc
114 |
115 | if end_loc == sequence_length:
116 | break
117 |
118 | total_neg_log_likelihood = torch.stack(neg_log_likelihoods).sum()
119 | perplexity = torch.exp(total_neg_log_likelihood / sequence_length)
120 | return {'value': perplexity.item()}
121 |
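For reference, the sliding-window scheme above follows the standard recipe: each window contributes `loss * target_length` (the summed negative log-likelihood of its freshly scored tokens, the overlap being masked with the `-100` ignore index), and the perplexity is `exp` of the token-averaged total. A toy check of the aggregation (numbers illustrative):

```python
import torch

# Two windows scoring 3 and 2 new tokens with mean losses 1.2 and 0.8.
nlls = [torch.tensor(1.2) * 3, torch.tensor(0.8) * 2]  # loss_w * n_w per window
ppl = torch.exp(torch.stack(nlls).sum() / 5)           # N = 5 tokens in total
print(round(ppl.item(), 2))                            # exp((3.6 + 1.6) / 5) ~ 2.83
```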
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/Metrics_calculator.py:
--------------------------------------------------------------------------------
1 | class Metrics_calculator:
2 | metrics_ = None
3 | preprocessing_functions_ = None
4 | tokenizer_ = None
5 |
6 | def __init__(self, tokenizer):
7 | self.metrics_ = {}
8 | self.preprocessing_functions_ = {}
9 | self.tokenizer_ = tokenizer
10 |
11 | def __del__(self):
12 | del self.metrics_
13 | del self.tokenizer_
14 |
15 | def calculate(self, contexts, references, candidates):
16 | res = {}
17 |
18 | assert len(contexts) == len(references) == len(candidates)
19 |
20 | preprocessed_sentences = {}
21 | for preproc in self.preprocessing_functions_:
22 | preprocessed = preproc(contexts, references, candidates, self.tokenizer_)
23 | preprocessed_sentences[frozenset(self.preprocessing_functions_[preproc])] = preprocessed
24 |
25 | for metric_name in self.metrics_:
26 | for metric_group in preprocessed_sentences:
27 | if metric_name in metric_group:
28 | prep = preprocessed_sentences[metric_group]
29 | prep_references = prep['references']
30 | prep_candidates = prep['candidates']
31 | res[metric_name] = self.metrics_[metric_name].calculate(prep_references, prep_candidates)
32 | break
33 | return res
34 |
35 | def add_metric(self, name, metric):
36 | if name in self.metrics_:
37 | raise Exception('Metric name ' + name + ' already exists!')
38 | self.metrics_[name] = metric
39 |
40 | preproc = metric.preprocess
41 | if preproc not in self.preprocessing_functions_:
42 | self.preprocessing_functions_[preproc] = set()
43 | self.preprocessing_functions_[preproc].add(name)
44 |
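A minimal wiring sketch for the calculator, mirroring `metric_calculation_interface.py` below (the model name and texts are illustrative). Metrics that share a `preprocess` function are grouped, so each preprocessing pass runs only once per `calculate` call:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from explainitall.metrics.RougeAndPPL.Metrics import Metric_ppl, MetricRougeL, MetricRougeN

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

calc = Metrics_calculator(tokenizer)
calc.add_metric('PPL', Metric_ppl(model, stride=512))
calc.add_metric('R3', MetricRougeN(3))
calc.add_metric('R-L', MetricRougeL(3))

# Maps each metric name to a list of per-sample scores.
res = calc.calculate(contexts=['Вопрос:'], references=['Вопрос: ответ'], candidates=['Вопрос: вариант'])
```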
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/create_database.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import create_engine, Table, Column, Float, Integer, String, MetaData
2 |
3 |
4 | ENGINE = create_engine('sqlite:///database.sqlite', echo=True)
5 |
6 |
7 | metadata = MetaData()
8 |
9 |
10 | data_table = Table('data', metadata,
11 | Column('id', Integer, primary_key=True, autoincrement=True),
12 | Column('model_name', String),
13 | Column('timestamp', Integer),
14 | Column('dataset_name', String),
15 | Column('dataset_version', Integer),
16 | Column('PPL', Float),
17 | Column('R3', Float),
18 | Column('R5', Float),
19 | Column('R-L', Float))
20 |
21 |
22 | metadata.create_all(ENGINE)
23 |
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/helpers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | from datetime import datetime
4 |
5 | import pandas as pd
6 | import sqlalchemy
7 |
8 |
9 | def calculate_average_metric_values(calculated_metrics):
10 | res = {}
11 |
12 | for metric in calculated_metrics:
13 | sub_metric_average_values = {}
14 | for text_evaluation_result in calculated_metrics[metric]:
15 | for sub_metric in text_evaluation_result:
16 | if sub_metric not in sub_metric_average_values:
17 | sub_metric_average_values[sub_metric] = 0.0
18 | sub_metric_average_values[sub_metric] += text_evaluation_result[sub_metric]
19 | for sub in sub_metric_average_values:
20 | sub_metric_average_values[sub] = sub_metric_average_values[sub] / len(calculated_metrics[metric])
21 | if 'f1' in sub_metric_average_values:
22 | res[metric] = sub_metric_average_values['f1']
23 | elif 'value' in sub_metric_average_values:
24 | res[metric] = sub_metric_average_values['value']
25 | else:
26 | res[metric] = 0.0
27 |
28 | return res
29 |
30 |
31 | def fbeta_score(precision, recall, beta=1):
32 | if precision == 0.0 and recall == 0.0:
33 | return 0.0
34 |
35 | if beta == 1:
36 | return 2 * precision * recall / (precision + recall)
37 | return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
38 |
39 |
40 | def generate_candidates(model, tokenizer, sentences, max_length=128, max_new_tokens=100):
41 | candidates = []
42 | for sentence in sentences:
43 | encoded_input = tokenizer(sentence, truncation=True, max_length=max_length, return_tensors='pt')
44 | encoded_input = encoded_input.to(model.device)
45 | res = model.generate(**encoded_input, max_new_tokens=max_new_tokens)
46 | candidate = tokenizer.decode(res[0], skip_special_tokens=True)
47 | candidates.append(candidate)
48 | return candidates
49 |
50 |
51 | def get_all_words_from_text(text):
52 | words = re.findall(r"[\w]+", text)
53 | return words
54 |
55 |
56 | def get_max_dataset_version(dataset_name, conn, data_table):
57 | statement = sqlalchemy.select(sqlalchemy.func.max(data_table.c.dataset_version)).where(
58 | data_table.c.dataset_name == dataset_name)
59 |
60 | records = []
61 |
62 | for row in conn.execute(statement):
63 | records.append(row)
64 |
65 | res = records[0][0]
66 |
67 | if res is None:
68 | return -1
69 |
70 | return records[0][0]
71 |
72 |
73 | def get_records_from_database(data_table, conn, specific_column_value=None):
74 | records = []
75 |
76 | statement = data_table.select()
77 |
78 | if specific_column_value is not None:
79 | for k in specific_column_value:
80 | col = sqlalchemy.sql.column(k)
81 | statement = statement.where(col == specific_column_value[k])
82 |
83 | for row in conn.execute(statement):
84 | records.append(row)
85 |
86 | return records
87 |
88 |
89 | def insert_new_record(data_table, conn, model_name, dataset_name, dataset_version, metric_values):
90 | record = {'model_name': model_name,
91 | 'timestamp': int(time.time()),
92 | 'dataset_name': dataset_name,
93 | 'dataset_version': dataset_version}
94 | for m in metric_values:
95 | record[m] = metric_values[m]
96 |
97 | statement = data_table.insert().values(**record)
98 | conn.execute(statement)
99 | conn.commit()
100 |
101 |
102 | def make_dataframe_from_history_records(records):
103 | columns = ['model_name', 'date', 'dataset_name', 'dataset_version', 'PPL', 'R3', 'R5', 'R-L']
104 | res_records = []
105 |
106 | for rec in records:
107 | r = list(rec[1:])
108 | d = datetime.fromtimestamp(r[1])
109 | r[1] = str(d.day) + '.' + str(d.month) + '.' + str(d.year)
110 | for i in range(len(columns[4:])):
111 | r[i + 4] = round(r[i + 4], 2)
112 | res_records.append(r)
113 |
114 | df = pd.DataFrame(res_records, columns=columns)
115 |
116 | return df
117 |
118 |
119 | def split_text_by_whitespaces(text):
120 | text = text.replace('\n', ' ')
121 | text = re.sub(r'\W+', ' ', text)
122 | text = re.sub(r'\s+', ' ', text)
123 | tokens = re.split(r"\s", text)
124 | return tokens
125 |
126 |
127 | def words_n_gramm(text, n_gramm=3):
128 | tx = text.replace('\n', ' ')
129 | tx = re.sub(r' +', ' ', tx)
130 | w = tx.split(' ')
131 |
132 | ng = []
133 |
134 | for i, word in enumerate(w):
135 | ng.append(' '.join(w[i:i + n_gramm]))
136 |
137 | return list(set(ng))
138 |
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/metric_calculation_interface.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import pandas as pd
3 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModel, AutoModelForCausalLM
4 |
5 | from explainitall.metrics.RougeAndPPL.Metrics import Metric_ppl, MetricRougeL, MetricRougeN
6 | from explainitall.metrics.RougeAndPPL.Metrics_calculator import Metrics_calculator
7 | from explainitall.metrics.RougeAndPPL.create_database import data_table, ENGINE
8 | from explainitall.metrics.RougeAndPPL.helpers import (
9 | get_max_dataset_version, generate_candidates, calculate_average_metric_values,
10 | insert_new_record, get_records_from_database, make_dataframe_from_history_records)
11 |
12 |
13 | class MetricCalculationInterface:
14 | demo_ = None
15 |
16 | model_checkbox_ = None
17 | metrics_checkbox_ = None
18 |
19 | calculator_ = None
20 |
21 | model_name_ = None
22 | model_ = None
23 | tokenizer_ = None
24 | model_successfully_loaded_ = None
25 |
26 | dataset_name_ = None
27 | dataset_version_ = None
28 | contexts_ = None
29 | references_ = None
30 | dataset_successfully_loaded_ = None
31 |
32 | conn_ = None
33 |
34 | def __init__(self):
35 | self.conn_ = ENGINE.connect()
36 |
37 | self.demo_ = gr.Blocks()
38 |
39 | with self.demo_:
40 | with gr.Tabs():
41 | with gr.TabItem("Load"):
42 | gr.Markdown("""
43 | Загрузите CSV файл с набором данных, который будет использоваться для анализа и оценки качества модели генерации текста
44 | Набор данных должен содержать как минимум две колонки:
45 | - `context`: текстовый контекст или вводные данные, на основе которых модель будет генерировать текст
46 | - `reference`: эталонный текст или ожидаемый ответ модели на заданный контекст
47 | Данный набор данных позволит оценить, насколько хорошо модель способна генерировать текст, соответствующий заданному контексту и эталонным ответам
48 | Пример файла example_data/metrix.csv
49 | """)
50 | with gr.Column():
51 | with gr.Row():
52 | with gr.Column():
53 | with gr.Row():
54 | with gr.Column():
55 | dataset_file = gr.File(label='Dataset (CSV)')
56 | dataset_title = gr.Text(label='Dataset title',
57 | info="Это поле обязательно к заполнению. Без указания названия датасета процесс не будет запущен.",
58 | placeholder="Введите название датасета")
59 |
60 | dataset_visualization = gr.Dataframe(label='Dataset',
61 | headers=['context', 'reference'],
62 | wrap=False,
63 | height=500)
64 | with gr.Row():
65 | dataset_load_button = gr.Button("Load dataset")
66 | dataset_checkbox = gr.Checkbox(label='dataset loaded', interactive=False)
67 |
68 | with gr.Row():
69 | with gr.Column():
70 | model_name_or_path = gr.Text(label='Model name or path',
71 | placeholder="distilgpt2",
72 | info="Введите название предварительно обученной модели (например, distilgpt2, gpt2 или sberbank-ai/rugpt3small_based_on_gpt2) или путь к вашей модели")
73 |
74 | with gr.Row():
75 | model_load_button = gr.Button("Load model")
76 | model_checkbox = gr.Checkbox(label='model loaded', interactive=False)
77 | with gr.Row():
78 | launch_button = gr.Button("Launch")
79 | metrics_checkbox = gr.Checkbox(label='metrics calculated, откройте вкладку Result и обновите данные', interactive=False)
80 |
81 | with gr.TabItem("Result"):
82 | with gr.Row():
83 | filter_field_dropdown = gr.Dropdown(["None", "model_name", "dataset_name"],
84 | label='Filter by field')
85 | filter_value_text = gr.Text(label='Filter value')
86 | refresh_button = gr.Button("Refresh")
87 | history_table = gr.Dataframe(
88 | headers=['model_name', 'date', 'dataset_name', 'dataset_version', 'PPL', 'R3', 'R5', 'R-L'],
89 | label='All Data')
90 |
91 | dataset_load_button.click(self.load_dataset_,
92 | inputs=[dataset_file, dataset_title],
93 | outputs=[dataset_visualization, dataset_checkbox])
94 | model_load_button.click(self.load_model_,
95 | inputs=[model_name_or_path],
96 | outputs=[model_checkbox])
97 | launch_button.click(self.calculate_metrics_,
98 | inputs=None,
99 | outputs=[metrics_checkbox])
100 | refresh_button.click(self.refresh_history_,
101 | inputs=[filter_field_dropdown, filter_value_text],
102 | outputs=[history_table])
103 |
104 | self.model_checkbox_ = model_checkbox
105 | self.dataset_checkbox_ = dataset_checkbox
106 | self.metrics_checkbox_ = metrics_checkbox
107 |
108 | def launch(self):
109 | self.demo_.launch(share=True, debug=False, server_name="127.0.0.1", inbrowser=True)
110 |
111 | def load_model_(self, model_name_or_path):
112 |         if not model_name_or_path:
113 |             print("Model name or path is required.")
114 |             return False
115 | self.model_successfully_loaded_ = False
116 |
117 | self.tokenizer_ = AutoTokenizer.from_pretrained(model_name_or_path)
118 | if self.tokenizer_.pad_token is None:
119 | self.tokenizer_.pad_token = self.tokenizer_.eos_token
120 | self.model_ = AutoModelForCausalLM.from_pretrained(model_name_or_path)
121 |
122 | self.model_name_ = str(model_name_or_path)
123 |
124 | self.calculator_ = Metrics_calculator(self.tokenizer_)
125 |
126 | self.calculator_.add_metric('PPL', Metric_ppl(self.model_, stride=512))
127 |
128 | self.calculator_.add_metric('R3', MetricRougeN(3))
129 | self.calculator_.add_metric('R5', MetricRougeN(5))
130 |
131 | self.calculator_.add_metric('R-L', MetricRougeL(3))
132 |
133 | self.model_successfully_loaded_ = True
134 | return True
135 |
136 | def load_dataset_(self, csvfile, title):
137 | self.dataset_successfully_loaded_ = False
138 |
139 | print("csvfile", csvfile)
140 | print("title", title)
141 |
142 |         if csvfile is None or not title:
143 | return pd.DataFrame(), False
144 | dataframe = pd.read_csv(csvfile.name, delimiter=',', encoding='utf-8')
145 |
146 | self.dataset_name_ = title
147 | self.dataset_version_ = get_max_dataset_version(self.dataset_name_, self.conn_, data_table) + 1
148 |
149 | self.contexts_ = list(dataframe['context'].values)
150 | self.references_ = list(dataframe['reference'].values)
151 |
152 | self.dataset_successfully_loaded_ = True
153 | return dataframe.head(10), True
154 |
155 | def calculate_metrics_(self):
156 | if self.dataset_successfully_loaded_ is False or self.model_successfully_loaded_ is False:
157 | return False
158 |
159 | model = self.model_
160 | candidates = generate_candidates(model,
161 | self.tokenizer_,
162 | self.contexts_,
163 | model.config.n_positions,
164 | max_new_tokens=10)
165 |
166 | res = self.calculator_.calculate(self.contexts_, self.references_, candidates)
167 | metric_values = calculate_average_metric_values(res)
168 |
169 | insert_new_record(data_table=data_table,
170 | conn=self.conn_,
171 | model_name=self.model_name_,
172 | dataset_name=self.dataset_name_,
173 | dataset_version=self.dataset_version_,
174 | metric_values=metric_values)
175 |
176 |
177 | return True
178 |
179 | def refresh_history_(self, filter_field, filter_value):
180 | specific_column_value = None
181 |         if filter_field and filter_field != 'None':
182 | specific_column_value = {filter_field: filter_value}
183 |
184 | res = get_records_from_database(data_table, self.conn_, specific_column_value)
185 | df = make_dataframe_from_history_records(res)
186 |
187 | return df
188 |
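189 | # Illustrative sketch (editor's addition, not in the original module): running the
190 | # file directly starts the Gradio app; launch() binds to 127.0.0.1 and opens a browser tab.
191 | if __name__ == '__main__':
192 |     MetricCalculationInterface().launch()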
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/rouge_L.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 |
3 | def find_lowest_position_higher_then_current(pos_list: List[int], current_position: Optional[int]) -> Optional[int]:
4 |     """Finds the lowest position in the list that is greater than the current one"""
5 | for pos in pos_list:
6 | if current_position is None or pos > current_position:
7 | return pos
8 | return None
9 |
10 | def get_element_positions(sequence: List[int], element: int) -> List[int]:
11 |     """Returns the list of positions at which the element occurs in the sequence"""
12 | return [i for i, v in enumerate(sequence) if v == element]
13 |
14 | def is_seq1_better_then_seq2(seq1: List[int], seq2: List[int]) -> bool:
15 |     """Determines whether seq1 dominates seq2 (each is an [end_position, length] pair): at least as long and ending strictly earlier"""
16 | if seq1[1] < seq2[1]:
17 | return False
18 | if seq1[0] < seq2[0]:
19 | return True
20 | return False
21 |
22 | def unite_sequencies(existing_sequencies: List[List[int]], new_sequencies: List[List[int]]) -> List[List[int]]:
23 |     """Merges the existing and new sequences, keeping only the non-dominated ones"""
24 | i = len(existing_sequencies) - 1
25 | while i >= 0:
26 | g = len(new_sequencies) - 1
27 | while g >= 0:
28 | if is_seq1_better_then_seq2(existing_sequencies[i], new_sequencies[g]):
29 | del new_sequencies[g]
30 | elif is_seq1_better_then_seq2(new_sequencies[g], existing_sequencies[i]):
31 | del existing_sequencies[i]
32 | break
33 | g -= 1
34 | i -= 1
35 | return existing_sequencies + new_sequencies
36 |
37 | def update_existing_sequencies_with_next_element(existing_sequencies: List[List[int]], element: List[int]) -> List[List[int]]:
38 |     """Extends the existing subsequences with the occurrence positions of the next candidate element"""
39 | new_sequencies = [[element[0], 1]]
40 | for existing_seq in existing_sequencies:
41 | appropriate_ind = find_lowest_position_higher_then_current(element, existing_seq[0])
42 | if appropriate_ind is not None:
43 | new_sequencies = unite_sequencies(new_sequencies, [[appropriate_ind, existing_seq[1] + 1]])
44 | return unite_sequencies(existing_sequencies, new_sequencies)
45 |
46 | def fbeta_score(precision: float, recall: float, beta: float = 1.0) -> float:
47 |     """Computes the F-beta score (F1 by default): (1 + beta^2) * P * R / (beta^2 * P + R)"""
48 | if precision == 0.0 and recall == 0.0:
49 | return 0.0
50 | if beta == 1.0:
51 | return 2 * precision * recall / (precision + recall)
52 | return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
53 |
54 | def rouge_L(reference: List[int], candidate: List[int]) -> Dict[str, float]:
55 |     """Computes the Rouge-L score (longest common subsequence based) between reference and candidate sequences"""
56 | reference_length = len(reference)
57 | candidate_length = len(candidate)
58 |
59 | if candidate_length == 0 or reference_length == 0:
60 | return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
61 |
62 | candidate_positions_in_reference = [get_element_positions(reference, el) for el in candidate]
63 |
64 | candidate_positions_in_reference = [positions for positions in candidate_positions_in_reference if positions]
65 |
66 |     # No candidate element occurs in the reference after filtering
67 | if not candidate_positions_in_reference:
68 | return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
69 |
70 | existing_sequences = []
71 | for element_positions in candidate_positions_in_reference:
72 | existing_sequences = update_existing_sequencies_with_next_element(existing_sequences, element_positions)
73 |
74 |     # No common subsequences were found
75 | if not existing_sequences:
76 | return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
77 |
78 | max_sequence_length = max(existing_sequences, key=lambda seq: seq[1])[1]
79 |
80 | precision = max_sequence_length / candidate_length
81 | recall = max_sequence_length / reference_length
82 | f1 = fbeta_score(precision, recall)
83 |
84 | return {'precision': precision, 'recall': recall, 'f1': f1}
85 |
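86 | # Illustrative sketch (editor's addition, not in the original module): rouge_L works
87 | # on token-id sequences. Here the longest common subsequence is [1, 2, 4] (length 3),
88 | # so precision = 3/4 and recall = 3/5.
89 | if __name__ == '__main__':
90 |     print(rouge_L(reference=[1, 2, 3, 4, 5], candidate=[1, 2, 4, 9]))
91 |     # -> {'precision': 0.75, 'recall': 0.6, 'f1': 0.666...}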
--------------------------------------------------------------------------------
/explainitall/metrics/RougeAndPPL/rouge_N.py:
--------------------------------------------------------------------------------
1 | from itertools import islice
2 | from typing import List, Tuple, Iterator, Dict
3 |
4 | def split_into_overlapping_chunks(iterable: List[int], chunk_size: int) -> Iterator[Tuple[int, ...]]:
5 |     """
6 |     Splits a sequence into overlapping chunks (n-grams) of the given size
7 |     """
8 | iterator = iter(iterable)
9 | res = tuple(islice(iterator, chunk_size))
10 | if len(res) == chunk_size:
11 | yield res
12 | for el in iterator:
13 | res = res[1:] + (el,)
14 | yield res
15 |
16 | def list_intersection(l1: List[int], l2: List[int]) -> int:
17 |     """
18 |     Counts and returns the number of elements shared by the two lists (l1 and l2),
19 |     removing each match it finds from a copy of the second list
20 |     """
21 | l1_copy = l1[:]
22 | l2_copy = l2[:]
23 |
24 | res = 0
25 | for el in l1_copy:
26 | i = len(l2_copy) - 1
27 | while i >= 0:
28 | if el == l2_copy[i]:
29 | del l2_copy[i]
30 | res += 1
31 | break
32 | i -= 1
33 |
34 | return res
35 |
36 | def rouge_N(reference: List[int], candidate: List[int], n: int) -> Dict[str, float]:
37 |     """
38 |     Computes the Rouge-N score between a reference sequence and a candidate
39 | 
40 |     :param reference: Reference sequence
41 |     :param candidate: Candidate sequence
42 |     :param n: n-gram size
43 |     """
44 | reference_chunks = list(split_into_overlapping_chunks(reference, n))
45 | reference_length = len(reference_chunks)
46 |
47 | candidate_chunks = list(split_into_overlapping_chunks(candidate, n))
48 | candidate_length = len(candidate_chunks)
49 |
50 | if candidate_length == 0 or reference_length == 0:
51 | return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
52 |
53 | intersection = list_intersection(reference_chunks, candidate_chunks)
54 |
55 | precision = intersection / candidate_length
56 | recall = intersection / reference_length
57 |     f1 = 2 * precision * recall / (precision + recall) if intersection else 0.0
58 |
59 | return {'precision': precision, 'recall': recall, 'f1': f1}
60 |
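61 | # Illustrative sketch (editor's addition, not in the original module): with n=2 the
62 | # reference yields bigrams (1,2),(2,3),(3,4) and the candidate (1,2),(2,3),(3,9);
63 | # two of them match, so precision = recall = 2/3.
64 | if __name__ == '__main__':
65 |     print(rouge_N(reference=[1, 2, 3, 4], candidate=[1, 2, 3, 9], n=2))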
--------------------------------------------------------------------------------
/explainitall/nlp.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import gensim
4 | import pymorphy2
5 |
6 |
7 | class WordProcessor:
8 |
9 | def __init__(self, gensim_nlp_embeddings: gensim.models.keyedvectors.KeyedVectors):
10 | self._morph = pymorphy2.MorphAnalyzer()
11 | self.gensim_nlp_embeddings = gensim_nlp_embeddings
12 |
13 | @lru_cache(maxsize=None)
14 | def get_clean_word(self, word: str):
15 | return word.lower().strip()
16 |
17 |     # no lru_cache here: a list argument is unhashable
18 | def get_embeddable_words_batch(self, words: list):
19 | cleaned = [self.get_clean_word(w) for w in words]
20 | embeddable = [self.get_embeddable_word_or_none(w) for w in cleaned if w]
21 | return [w for w in embeddable if w]
22 |
23 | @lru_cache(maxsize=None)
24 | def get_morph_or_none(self, word: str):
25 | morphed = self._morph.parse(word)
26 | known_words = [x for x in morphed if x.is_known]
27 | if not known_words:
28 | return None
29 | return known_words[0]
30 |
31 | @lru_cache(maxsize=None)
32 | def get_normal_form_or_none(self, word: str):
33 | clean_word = self.get_clean_word(word)
34 | if not clean_word:
35 | return None
36 | word_morph = self.get_morph_or_none(clean_word)
37 | if not word_morph:
38 | return None
39 | return word_morph.normal_form
40 |
41 | @lru_cache(maxsize=None)
42 | def get_grammeme_or_none(self, word: str):
43 | clean_word = self.get_clean_word(word)
44 | if not clean_word:
45 | return None
46 | word_morph = self.get_morph_or_none(clean_word)
47 | if not word_morph:
48 | return None
49 | tag = word_morph.tag.POS
50 |         tag = 'VERB' if tag == 'INFN' else tag
51 | return tag
52 |
53 | def get_embeddable_word_or_none(self, word: str):
54 | normal_form = self.get_normal_form_or_none(word)
55 | grammeme = self.get_grammeme_or_none(word)
56 | if not normal_form or not grammeme:
57 | return None
58 | word_tagged = f'{normal_form}_{grammeme}'
59 | if word_tagged not in self.gensim_nlp_embeddings:
60 | return None
61 | return word_tagged
62 |
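63 | # Illustrative sketch (editor's addition, not in the original module): WordProcessor
64 | # expects lemma_POS keys, as in RusVectores-style word2vec models. The toy
65 | # KeyedVectors below stands in for a real model (e.g. the one loaded in main.py).
66 | if __name__ == '__main__':
67 |     import numpy as np
68 |     kv = gensim.models.KeyedVectors(vector_size=3)
69 |     kv.add_vector('кошка_NOUN', np.ones(3, dtype=np.float32))
70 |     print(WordProcessor(kv).get_embeddable_words_batch(['Кошки', 'qwerty']))
71 |     # -> ['кошка_NOUN']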
--------------------------------------------------------------------------------
/explainitall/stat_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Dict
3 |
4 | import numpy as np
5 | from sklearn import mixture
6 | from sklearn.mixture import GaussianMixture
7 |
8 |
9 | def rayleigh_el(el: float, disp: float) -> float:
10 | return (el / disp) * math.exp(-el ** 2 / disp)
11 |
12 |
13 | def rayleigh(arr: np.ndarray, dispersion: float) -> np.ndarray:
14 | return np.array(list(map(lambda x: rayleigh_el(x, dispersion), arr)))
15 |
16 |
17 | def rayleigh_el_integral(el: float, disp: float) -> float:
18 | return 1 - math.exp(-el ** 2 / disp)
19 |
20 |
21 | def rayleigh_integral(arr: np.ndarray, dispersion: float) -> np.ndarray:
22 | return np.array(list(map(lambda x: rayleigh_el_integral(x, dispersion), arr)))
23 |
24 |
25 | def gaussian_integral_single(element, mean, std, sqrt2):
26 | return 0.5 + 0.5 * math.erf((element - mean) / (sqrt2 * std))
27 |
28 |
29 | def gaussian_mixture_integral_single(element: float, gaussian_mixture_model: GaussianMixture, sqrt2: float) -> float:
30 | means = gaussian_mixture_model.means_
31 | weights = gaussian_mixture_model.weights_
32 | variances = gaussian_mixture_model.covariances_[:, 0]
33 |
34 | integral_element = np.sum([
35 | weight * gaussian_integral_single(mean=mean[0], std=np.sqrt(variance[0]), element=element, sqrt2=sqrt2)
36 | for weight, mean, variance in zip(weights, means, variances)
37 | ])
38 | return float(integral_element)
39 |
40 |
41 | def gaussian_mixture_integral(arr: np.ndarray, gmm: GaussianMixture) -> np.ndarray:
42 | sqrt2 = math.sqrt(2)
43 | if len(arr.shape) == 1:
44 | return np.array([gaussian_mixture_integral_single(x, gmm, sqrt2) for x in arr])
45 |
46 | if len(arr.shape) == 2:
47 | dim_1, dim_2 = arr.shape
48 | reshaped_arr = arr.reshape((dim_1 * dim_2))
49 | reshaped_arr = np.array([gaussian_mixture_integral_single(x, gmm, sqrt2) for x in reshaped_arr])
50 | return reshaped_arr.reshape((dim_1, dim_2))
51 |     raise ValueError("gaussian_mixture_integral expects a 1-D or 2-D array")
52 |
53 | def denormalize_array(array: np.ndarray) -> np.ndarray:
54 | edited = array.copy()
55 | non_zero_counts = np.count_nonzero(edited, axis=0)
56 | edited *= non_zero_counts
57 | return edited
58 |
59 |
60 | def calc_gauss_mixture_stat_params(array: np.ndarray,
61 |                                    num_components: int = 3,
62 |                                    seed: Optional[int] = 0) -> np.ndarray:
63 |     """Computes a new array based on a Gaussian mixture fit to the non-NaN values"""
64 | d_array_2d = denormalize_array(array)
65 |
66 | d_array_1d = d_array_2d[~np.isnan(array)]
67 | d_array_1d = d_array_1d.reshape(len(d_array_1d), 1)
68 |
69 |     gmm = mixture.GaussianMixture(n_components=num_components, random_state=seed)
70 | gmm.fit(d_array_1d)
71 | return gaussian_mixture_integral(d_array_2d, gmm)
72 |
73 |
74 | def calc_gmm_stat_params(array: np.ndarray) -> Dict:
75 | array_1d = array[np.logical_not(np.isnan(array))]
76 |
77 | mean = float(np.mean(array_1d))
78 | std = float(np.std(array_1d))
79 | new_arr = calc_gauss_mixture_stat_params(array)
80 |
81 | return {'new_arr': new_arr, "mean": mean, "std": std}
82 |
83 |
84 | def compute_gaussian_integral(array: np.ndarray, mean: float, std_dev: float) -> np.ndarray:
85 |     """Integral (CDF) Gaussian function for a 1-D or 2-D array"""
86 | sqrt2 = math.sqrt(2)
87 | vectorized_gaussian = np.vectorize(gaussian_integral_single)
88 | array_1d = array.flatten()
89 | result_array = vectorized_gaussian(array_1d, mean, std_dev, sqrt2)
90 | return result_array.reshape(array.shape)
91 |
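92 | # Illustrative sketch (editor's addition, not in the original module). Note the
93 | # convention above: the Rayleigh helpers take disp = 2*sigma^2, so
94 | # rayleigh_el_integral(x, disp) matches the standard CDF 1 - exp(-x^2 / (2*sigma^2)).
95 | if __name__ == '__main__':
96 |     xs = np.array([0.5, 1.0, 2.0])
97 |     print(rayleigh_integral(xs, 2.0))               # ~ [0.1175, 0.3935, 0.8647]
98 |     print(compute_gaussian_integral(xs, 1.0, 1.0))  # ~ [0.3085, 0.5000, 0.8413]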
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import gensim
2 | from inseq import load_model
3 |
4 | from explainitall.gpt_like_interp import viz, interp
5 | from explainitall.gpt_like_interp.downloader import DownloadManager
6 |
7 |
8 | def load_nlp_model(nlp_model_url):
9 | nlp_model_path = DownloadManager.load_zip(nlp_model_url)
10 | return gensim.models.KeyedVectors.load_word2vec_format(nlp_model_path, binary=True)
11 |
12 |
13 | # 'ID': 180
14 | # 'Vector size': 300
15 | # 'Corpus': 'Russian National Corpus'
16 | # 'Vocabulary size': 189193
17 | # 'Algorithm': 'Gensim Continuous Bag-of-Words'
18 | # 'Lemmatization': True
19 |
20 | nlp_model = load_nlp_model('http://vectors.nlpl.eu/repository/20/180.zip')
21 |
22 |
23 | def load_gpt_model(gpt_model_name):
24 | return load_model(model=gpt_model_name,
25 | attribution_method="integrated_gradients")
26 |
27 |
28 | # 'Framework': 'transformers'
29 | # 'Training tokens': '80 billion'
30 | # 'Context size': 2048
31 | gpt_model = load_gpt_model("sberbank-ai/rugpt3small_based_on_gpt2")
32 |
33 | clusters_discr = [
34 | {'name': 'Животные', 'centroid': ['собака', 'кошка', 'заяц'], 'top_k': 140},
35 | {'name': 'Лекарства', 'centroid': ['уколы', 'таблетки', 'микстуры'], 'top_k': 160},
36 | {'name': 'Болезни', 'centroid': ['простуда', 'орви', 'орз', 'грипп'], 'top_k': 20},
37 | {'name': 'Аллергия', 'centroid': ['аллергия'], 'top_k': 20}
38 | ]
39 |
40 | explainer = interp.ExplainerGPT2(gpt_model=gpt_model, nlp_model=nlp_model)
41 |
42 | expl_data = explainer.interpret(
43 | input_texts='у кошки грипп и аллергия на антибиотбиотики вопрос: чем лечить кошку? ответ:',
44 | generated_texts='лечичичичите ее уколами',
45 | clusters_description=clusters_discr,
46 | batch_size=50,
47 | steps=34,
48 | # max_new_tokens=19
49 | )
50 |
51 | print("\nCluster importance map")
52 | print(expl_data.cluster_imp_df)
53 | 
54 | print("\nCluster importance heatmap")
55 | expl_data.show_cluster_imp_heatmap()
56 | 
57 | print("\nCluster importance map, aggregated")
58 | print(expl_data.cluster_imp_aggr_df)
59 | 
60 | print("\nCluster importance heatmap, aggregated")
61 | expl_data.show_cluster_imp_aggr_heatmap()
62 | 
63 | print("\nWord importance map")
64 | print(expl_data.word_imp_df)
65 | 
66 | print("\nWord importance heatmap")
67 | expl_data.show_word_imp_heatmap()
68 | 
69 | print("\nWord importance map, normalized")
70 | print(expl_data.word_imp_norm_df)
71 | 
72 | print("\nWord importance heatmap, normalized")
73 | expl_data.show_word_imp_norm_heatmap()
74 | 
75 | print("\nDistribution histogram")
76 | viz.show_distribution_histogram(expl_data.attributions.array)
77 | print("\nDistribution plot")
78 | viz.show_distribution_plot(expl_data.attributions.array)
79 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = -v
3 | python_files = test_*.py
4 | python_functions = test_*
5 | testpaths = tests
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim==4.3.2
2 | gradio==4.4.1
3 | gradio_client==0.7.0
4 | inseq==0.5.0
5 | matplotlib==3.8.4
6 | nltk==3.8.1
7 | pandas==2.2.1
8 | pymorphy2==0.9.1
9 | scikit-learn==1.4.1.post1
10 | seaborn==0.13.2
11 | sentence_transformers==2.6.1
12 | SQLAlchemy==2.0.29
13 | statsmodels==0.14.1
14 | torch>=2.0
15 | transformers>=4.39
16 | uvicorn==0.29
17 | scipy==1.12.0
18 | numpy==1.26.4
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Note: distutils.core has no find_namespace_packages, so the old try/except
2 | # fallback to distutils could never work; setuptools is required.
3 | from setuptools import setup, find_namespace_packages
4 | 
5 |
6 | REQUIREMENTS = [i.strip() for i in open("requirements.txt").readlines()]
7 |
8 | setup(
9 | name='explainitall',
10 | version='1.0.2',
11 | long_description=open('README.md', encoding='utf-8', errors='ignore').read(),
12 | long_description_content_type='text/markdown',
13 | packages=find_namespace_packages(include=['explainitall', 'explainitall.*']),
14 | install_requires=REQUIREMENTS
15 | )
16 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bots-Avatar/ExplainitAll/0339ea5c09c3cd309d53c23b403465c821a778d0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_inseq_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from explainitall.gpt_like_interp import inseq_helpers
5 |
6 | test_data = [
7 | (['Reference! site - a-bout Lorem-Ipsum,',
8 | [' ref', '!ere_', 'nce((', '__site', 'a-bout__', '@Lorem', '*Ipsum'],
9 | [['Ref', 'ere', 'nce'], ['site'], ['a-bout'], ['Lorem', ('-', 0), 'Ipsum']]]),
10 | (['Reference! q321 - a-bout Lorem-Ipsum,',
11 | [' ref', '!ere_', 'nce((', '__q321', 'a-bout__', '@Lorem', '*Ipsum'],
12 | [['Ref', 'ere', 'nce'], ['q321'], ['a-bout'], ['Lorem', ('-', 0), 'Ipsum']]]),
13 | (['giving information on its origins, as well ',
14 | ['🃯🃯giving🃯', 'ÐinformationÐ ', 'on__', 'its', '__origin', '&&&s', 'as', 'well'],
15 | [['giving'], ['information'], ['on'], ['its'], ['origin', 's'], ['as'], ['well']]]),
16 | ([' as a random Lipsum generator.',
17 | ['as🃯🃯', 'a', 'Ðrandom', '🃯lipsum', '###generator!'],
18 | [['as'], ['a'], ['random'], ['Lipsum'], ['generator']]]),
19 | (['Папа у Васи чуть-чуть силён в Математике!',
20 | ['папа🃯🃯', 'у', 'Ðва', '🃯с', '###и!', '!чуть', '-', 'чуть', 'силён', '#в!', 'математике'],
21 | [['Папа'], ['у'], ['Ва', 'с', 'и'], ['чуть', ('-', 0), '', 'чуть'], ['силён'], ['в'], ['Математике']]])]
22 |
23 |
24 | # Apply parametrization
25 | @pytest.mark.parametrize("inp_text,inp_pairs,expected_rez", test_data)
26 | def test_detokenizer(inp_text, inp_pairs, expected_rez):
27 | fact_rez = inseq_helpers.Detokenizer(inp_text, inp_pairs).group_text()
28 | assert fact_rez == expected_rez
29 |
30 |
31 | def test_squash_arr():
32 | array = np.array([[4., 7., 3., 8., 6.],
33 | [4., 9., 7., 4., 1.],
34 | [2., 6., 0., 0., 7.],
35 | [2., 5., 5., 4., 8.],
36 | [1., 4., 6., 10., 4.]])
37 |
38 | squashed_arr = inseq_helpers.squash_arr(arr=array,
39 | squash_row_mask=[[0, 1], [1, 5]],
40 | squash_col_mask=[[0, 3], [3, 5]],
41 | aggr_f=np.max)
42 |
43 | expected = np.array([[7., 8.],
44 | [9., 10.]])
45 | np.testing.assert_array_equal(squashed_arr, expected)
46 |
47 |
48 | def test_calculate_mask():
49 | grouped_data = [[1, 2, 3], [11, 22], [333]]
50 | expected_output = [[0, 3], [3, 5], [5, 6]]
51 | assert inseq_helpers.calculate_mask(grouped_data) == expected_output
52 |
--------------------------------------------------------------------------------
/tests/test_stat_helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import pytest
5 | from sklearn import mixture
6 |
7 | from explainitall import (stat_helpers)
8 |
9 |
10 | def test_rayleigh_el():
11 | el = 1.0
12 | disp = 2.0
13 | expected = 0.303
14 | result = stat_helpers.rayleigh_el(el, disp)
15 | assert math.isclose(result, expected, rel_tol=1e-3)
16 |
17 |
18 | @pytest.mark.parametrize("arr, dispersion, expected", [
19 | (np.array([1, 2, 3]), 2.0, np.array([0.303, 0.135, 0.016])),
20 | ])
21 | def test_rayleigh(arr, dispersion, expected):
22 | result = stat_helpers.rayleigh(arr, dispersion)
23 | np.testing.assert_almost_equal(result, expected, decimal=3)
24 |
25 |
26 | @pytest.mark.parametrize("el, disp, expected", [
27 | (1.0, 2.0, 0.3934),
28 | (2.0, 1.5, 0.9305),
29 | (0.5, 3.0, 0.0799)
30 | ])
31 | def test_rayleigh_el_int(el, disp, expected):
32 | result = stat_helpers.rayleigh_el_integral(el, disp)
33 | assert math.isclose(result, expected, rel_tol=1e-3)
34 |
35 |
36 | @pytest.mark.parametrize("arr, dispersion, expected", [
37 | (np.array([1, 2, 3]), 1.0, np.array([0.6321, 0.9816, 0.99987])),
38 | (np.array([0, 0.5, 1]), 0.5, np.array([0.0, 0.39346, 0.8646]))])
39 | def test_rayleigh_int(arr, dispersion, expected):
40 | result = stat_helpers.rayleigh_integral(arr, dispersion)
41 | np.testing.assert_almost_equal(result, expected, decimal=4)
42 |
43 |
44 | @pytest.mark.parametrize("el, mu, std, sqrt2, expected", [
45 | (1.0, 2.0, 0.3934, math.sqrt(2), 0.0055119),
46 | (2.0, 1.5, 0.9305, math.sqrt(2), 0.7044855),
47 | (0.5, 3.0, 0.0799, math.sqrt(2), 0.0)
48 | ])
49 | def test_gauss_integral_element(el, mu, std, sqrt2, expected):
50 | result = stat_helpers.gaussian_integral_single(el, mu, std, sqrt2)
51 | assert math.isclose(result, expected, rel_tol=1e-3)
52 |
53 |
54 | def test_gauss_m_integral_element():
55 | arr = np.array([1, 2, 3, 4, 5])
56 | gmm = mixture.GaussianMixture(n_components=2, random_state=0)
57 |     gmm.fit(arr.reshape(-1, 1))
58 | el = 1
59 | sqrt2 = math.sqrt(2)
60 | result = stat_helpers.gaussian_mixture_integral_single(el, gmm, sqrt2)
61 | expected_result = 0.07186
62 | assert math.isclose(result, expected_result, rel_tol=1e-3)
63 |
64 |
65 | def test_gauss_m_integral_1D():
66 | arr = np.array([1, 2, 3, 4, 5])
67 | gmm = mixture.GaussianMixture(n_components=2, random_state=0)
68 |     gmm.fit(arr.reshape(-1, 1))
69 | expected_result = np.array([0.0719, 0.2886, 0.5261, 0.6753, 0.9324])
70 | result = stat_helpers.gaussian_mixture_integral(arr=arr, gmm=gmm)
71 | np.testing.assert_almost_equal(result, expected_result, decimal=4)
72 |
73 |
74 | def test_gauss_m_integral_2D():
75 | arr = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
76 | gmm = mixture.GaussianMixture(n_components=2, random_state=0)
77 |     gmm.fit(arr.reshape(-1, 1))
78 | expected_result = np.array([[0.1676, 0.4391, 0.8942],
79 | [0.1676, 0.4391, 0.8942],
80 | [0.1676, 0.4391, 0.8942]])
81 | result = stat_helpers.gaussian_mixture_integral(arr=arr, gmm=gmm)
82 | np.testing.assert_almost_equal(result, expected_result, decimal=4)
83 |
84 |
85 | def test_gauss_integral_1D():
86 | arr = np.array([1, 2, 3, 4, 5])
87 | mu = 3
88 | std = 1
89 | expected_result = np.array([0.0227, 0.1586, 0.5000, 0.8413, 0.9772])
90 | result = stat_helpers.compute_gaussian_integral(arr, mu, std)
91 | np.testing.assert_almost_equal(result, expected_result, decimal=4)
92 |
93 | def test_gauss_integral_2D():
94 | arr = np.array([[1, 2], [3, 4]])
95 | mu = 2
96 | std = 1
97 | expected_result = np.array([[0.15865525, 0.5], [0.84134475, 0.97724987]])
98 | result = stat_helpers.compute_gaussian_integral(arr, mu, std)
99 | np.testing.assert_almost_equal(result, expected_result, decimal=4)
100 |
101 |
102 | def test_denormalize_array():
103 | # Test case 1: Normalized array with non-zero values
104 | arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
105 | expected = np.array([[2., 4., 6.], [8., 10., 12.]])
106 | np.testing.assert_almost_equal(stat_helpers.denormalize_array(arr), expected, decimal=1)
107 |
108 |
109 | def test_calc_gauss_mixture_stat_params():
110 | arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
111 | result = stat_helpers.calc_gauss_mixture_stat_params(arr, num_components=1, seed=0)
112 | expected = np.array([[0.061, 0.123, 0.219], [0.349, 0.5, 0.651], [0.781, 0.877, 0.939]])
113 | np.testing.assert_almost_equal(result, expected, decimal=3)
114 |
--------------------------------------------------------------------------------