├── .dockerignore
├── .github
└── workflows
│ └── build.yaml
├── .gitignore
├── LICENSE
├── README.md
├── autotm
├── __init__.py
├── abstract_params.py
├── algorithms_for_tuning
│ ├── __init__.py
│ ├── bayesian_optimization
│ │ ├── __init__.py
│ │ └── bayes_opt.py
│ ├── genetic_algorithm
│ │ ├── __init__.py
│ │ ├── config.yaml
│ │ ├── crossover.py
│ │ ├── ga.py
│ │ ├── genetic_algorithm.py
│ │ ├── mutation.py
│ │ ├── selection.py
│ │ ├── statistics_collector.py
│ │ ├── strategy.py
│ │ └── surrogate.py
│ ├── individuals.py
│ └── nelder_mead_optimization
│ │ ├── __init__.py
│ │ └── nelder_mead.py
├── base.py
├── batch_vect_utils.py
├── clustering.py
├── content_splitter.py
├── data_generator
│ └── synthetic_data_generation.ipynb
├── fitness
│ ├── __init__.py
│ ├── base_score.py
│ ├── cluster_tasks.py
│ ├── estimator.py
│ ├── external_scores.py
│ ├── local_tasks.py
│ └── tm.py
├── graph_ga.py
├── infer.py
├── main.py
├── main_fitness_worker.py
├── make-recovery-file.py
├── params.py
├── params_logging_utils.py
├── pipeline.py
├── preprocessing
│ ├── __init__.py
│ ├── cooc.py
│ ├── dictionaries_preparation.py
│ └── text_preprocessing.py
├── schemas.py
├── utils.py
└── visualization
│ ├── __init__.py
│ └── dynamic_tracker.py
├── cluster
├── bin
│ └── fitnessctl
├── charts
│ └── autotm
│ │ ├── .helmignore
│ │ ├── Chart.yaml
│ │ ├── templates
│ │ ├── _helpers.tpl
│ │ ├── configmaps.yaml
│ │ ├── deployments.yaml
│ │ ├── pvc.yaml
│ │ └── services.yaml
│ │ └── values.yaml
├── conf
│ └── pv.yaml
└── docker
│ ├── flower.dockerfile
│ ├── jupyter.dockerfile
│ ├── mlflow-webserver.dockerfile
│ └── worker.dockerfile
├── conf
└── config.yaml
├── data
└── sample_corpora
│ ├── clean_docs_v17_gost_only.csv
│ ├── dataset_books_stroitelstvo.csv
│ ├── imdb_100.csv
│ ├── imdb_1000.csv
│ ├── partition_df.csv
│ └── sample_dataset_lenta.csv
├── distributed
├── README.md
├── autotm_distributed
│ ├── __init__.py
│ ├── deploy_config_generator.py
│ ├── main.py
│ ├── metrics.py
│ ├── params_logging_utils.py
│ ├── preprocessing.py
│ ├── schemas.py
│ ├── tasks.py
│ ├── test_app.py
│ ├── tm.py
│ └── utils.py
├── bin
│ ├── fitnessctl
│ ├── kube-fitnessctl
│ └── trun.sh
├── deploy
│ ├── ess-small-datasets-config.yaml
│ ├── file.yaml
│ ├── fitness-worker-health-checker.yaml
│ ├── kube-fitness-client-job.yaml
│ ├── kube-fitness-client-job.yaml.j2
│ ├── kube-fitness-workers.yaml
│ ├── kube-fitness-workers.yaml.j2
│ ├── mlflow-minikube-persistent-volumes.yaml
│ ├── mlflow-persistent-volumes.yaml
│ ├── mlflow.yaml
│ ├── mongo_dev
│ │ ├── docker-compose.yaml
│ │ └── wait-for.sh
│ ├── ns.yaml
│ └── test-datasets-config.yaml
├── docker
│ ├── base.dockerfile
│ ├── cli.dockerfile
│ ├── client.dockerfile
│ ├── flower.dockerfile
│ ├── health-checker.dockerfile
│ ├── mlflow-webserver.dockerfile
│ ├── test.dockerfile
│ └── worker.dockerfile
├── poetry.lock
├── pyproject.toml
├── setup.py
└── test
│ └── topic_model_test.py
├── docs
├── Makefile
├── _static
│ └── style.css
├── _templates
│ ├── autosummary
│ │ ├── class.rst
│ │ └── module.rst
│ ├── classtemplate.rst
│ └── functiontemplate.rst
├── conf.py
├── img
│ ├── MyLogo.png
│ ├── autotm_arch_v3 (1).png
│ ├── img_library_eng.png
│ ├── photo_2023-06-29_20-20-57.jpg
│ ├── photo_2023-06-30_14-17-08.jpg
│ ├── pipeling.png
│ └── strategy.png
├── index.rst
├── make.bat
└── pages
│ ├── algorithms_for_tuning.rst
│ ├── api
│ ├── algorithms_for_tuning.rst
│ ├── fitness.rst
│ ├── index.rst
│ ├── preprocessing.rst
│ └── visualization.rst
│ ├── fitness.rst
│ ├── installation.rst
│ ├── preprocessing.rst
│ ├── userguide
│ ├── index.rst
│ ├── metrics.rst
│ └── regularizers.rst
│ └── visualization.rst
├── examples
├── demo
│ ├── autotm_demo_ru_OLD.ipynb
│ └── autotm_demo_updated_clustering.ipynb
├── demo_autotm.ipynb
├── examples_autotm_fit_predict.py
├── graph_building_for_stroitelstvo.py
├── graph_profile_example.py
└── topic_modeling_of_corporative_data.py
├── logging.config.yml
├── poetry.lock
├── pyproject.toml
├── scripts
├── __init__.py
├── algs
│ ├── full_pipeline_example.py
│ └── nelder_mead_experiments.py
├── experiments
│ ├── __init__.py
│ ├── analysis.ipynb
│ ├── experiment.py
│ ├── plot_boxplot_Final_results.png
│ ├── plot_boxplot_Start_results.png
│ ├── plot_progress_hotel-reviews_sample.png
│ ├── plots.ipynb
│ └── statistics
│ │ ├── 240602-124458_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240602-124458_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240602-124539_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240602-124539_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240602-124633_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240602-124633_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240602-124719_hotel-reviews_sample_pipeline_parameters.txt
│ │ ├── 240602-124719_hotel-reviews_sample_pipeline_progress.txt
│ │ ├── 240602-124816_hotel-reviews_sample_pipeline_parameters.txt
│ │ ├── 240602-124816_hotel-reviews_sample_pipeline_progress.txt
│ │ ├── 240602-124950_hotel-reviews_sample_pipeline_parameters.txt
│ │ ├── 240602-124950_hotel-reviews_sample_pipeline_progress.txt
│ │ ├── 240603-231051_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-231154_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-231154_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240603-231313_hotel-reviews_sample_pipeline_parameters.txt
│ │ ├── 240603-231313_hotel-reviews_sample_pipeline_progress.txt
│ │ ├── 240603-232451_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-232451_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240603-232546_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-232546_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240603-232733_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-232733_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240603-232847_hotel-reviews_sample_pipeline_parameters.txt
│ │ ├── 240603-232916_hotel-reviews_sample_fixed_parameters.txt
│ │ ├── 240603-232916_hotel-reviews_sample_fixed_progress.txt
│ │ ├── 240603-233132_hotel-reviews_sample_pipeline_parameters.txt
│ │ └── 240603-233132_hotel-reviews_sample_pipeline_progress.txt
├── other
│ ├── big_sample_4class_avg.csv
│ ├── preparation_pipeline.py
│ ├── sample_3class.csv
│ ├── sample_4class_avg.csv
│ └── toloka_markup_analysis.ipynb
└── topic_modeling_of_corporative_data.py
├── tests
├── integration
│ ├── __init__.py
│ └── test_fit_predict.py
└── unit
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_cooc.py
│ ├── test_dictionaries_preparation.py
│ ├── test_llm_fitness.py
│ └── test_preprocessing.py
└── toloka
├── Estimate_topics_interpretability
├── input-data.json
├── instructions.md
├── output-data.json
├── task.css
├── task.html
└── task.js
└── MarkupPreparation.ipynb
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./Notebooks
2 | ./src/algorithms_for_tuning/genetic_algorithm/resources/
3 | ./venv
4 | ./src/irace_config/ga_test
5 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: build
2 | run-name: ${{ github.repository }} installation test
3 |
4 | on: [push]
5 |
6 | jobs:
7 | build:
8 | runs-on: ubuntu-latest
9 | strategy:
10 | fail-fast: false
11 | matrix:
12 | python-version: ["3.9", "3.10", "3.11"]
13 |
14 | steps:
15 | - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
16 | - uses: actions/checkout@v3
17 | - name: List files in the repository
18 | run: |
19 | ls ${{ github.workspace }}
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v4
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Using python version
25 | run: python --version
26 | # - name: Install pip
27 | # run: apt install -y python3-pip
28 | - name: Install dependencies for tests
29 | run: |
30 | python -m pip install --upgrade pip
31 | python -m pip install flake8 poetry
32 | poetry install
33 | - name: Download english corpus
34 | run: poetry run python -m spacy download en_core_web_sm
35 | - name: Setting language and locale
36 | run: |
37 | sudo apt-get update
38 | sudo apt-get install -y locales
39 | sudo locale-gen ru_RU.UTF-8
40 | sudo update-locale
41 | export LC_ALL="ru_RU.UTF-8"
42 | export LANG="ru_RU.UTF-8"
43 | export LANGUAGE="ru_RU.UTF-8"
44 | - name: set pythonpath
45 | run: |
46 | echo "PYTHONPATH=." >> $GITHUB_ENV
47 | - name: Lint with flake8
48 | run: |
49 | flake8 autotm --count --select=E9,F63,F7,F82 --show-source --statistics
50 | - name: Run test code
51 | run: |
52 | poetry run pytest tests
53 | - run: echo "🍏 This job's status is ${{ job.status }}."
54 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Sphinx
132 | docs/generated
133 |
134 | autotm/algorithms_for_tuning/genetic_algorithm/resources/*
135 | .idea
136 | algorithms_for_tuning/genetic_algorithm/logs/*
137 | algorithms_for_tuning/genetic_algorithm/bigartm.*
138 | logs
139 | new_logs
140 | resources
141 | *.whl
142 | surrogate_logs/*
143 | *.orig
144 | *.pickle
145 | Notebooks/
146 | logs/
147 | examples/mlruns/
148 | data/experiment_datasets/
149 | data/processed_sample_corpora
150 | data/sample_corpora/*.txt
151 | data/sample_corpora/*.pkl
152 | data/sample_corpora/batches/*
153 | examples/tmp
154 | examples/metrics
155 | examples/*.txt
156 | examples/out
157 | bigartm.*
158 | coverage_re/
159 | tests/integration/mlruns/*
160 | tests/integration/metrics
161 | **/mlruns/*
162 | **/metrics/*
163 | autotm_workdir_*
164 | requirements.txt
165 | *.patch
166 |
167 | mixtures.csv
168 | model.artm
169 |
170 | # sphinx docs
171 | docs/_build
172 |
173 | examples/experiments/*.txt
174 | examples/experiments/*/*.txt
175 | examples/experiments/*.png
176 | examples/experiments/*/*.png
177 | scripts/experiments/metrics
178 | scripts/experiments/mlruns
179 |
180 | rsync-repo.sh
181 |
182 | tmp
183 |
184 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Industrial AI Lab Team
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | AutoTM
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | :sparkles:**News**:sparkles: We have fully updated our framework to version 2.0, enriched with new functionality! Stay tuned!
30 |
31 | Automatic parameter selection for topic models (ARTM approach) using evolutionary and Bayesian algorithms.
32 | AutoTM provides the necessary tools to preprocess English and Russian text datasets and to tune additively regularized topic models.
33 |
34 | ## What is AutoTM?
35 | Topic modeling is one of the basic methods for a whole range of tasks:
36 |
37 | * Exploratory data analysis of unlabelled text data
38 | * Extracting interpretable features (topics and their combinations) from text data
39 | * Searching for hidden insights in the data
40 |
41 | While the ARTM (additive regularization for topic models) approach offers significant flexibility and quality comparable to or better than neural
42 | approaches, such models are hard to tune due to the number of hyperparameters and their possible combinations. That is why we provide optimization pipelines to effortlessly process custom datasets.
43 |
44 | To overcome these tuning problems, AutoTM provides an easy way to define a learning strategy for training models on input corpora. We implement two strategy variants:
45 |
46 | * a fixed-size variant that provides a learning strategy following the best practices collected from manual tuning history
47 |
48 |
49 |
50 | * a graph-based variant with more flexibility and an unfixed ordering and number of stages (**New in AutoTM 2.0**). An example pipeline is shown below:
51 |
52 |
53 |
54 |
55 |
56 | The optimization procedure is performed by a genetic algorithm (GA) whose operators are specifically tuned for each of the strategy creation variants (the GA for the graph-based variant is **New in AutoTM 2.0**). Bayesian optimization is available only for the fixed-size strategy.
57 |
58 | To speed up the procedure, AutoTM also contains a surrogate modeling implementation for fixed-size and graph-based (**New in AutoTM 2.0**) learning strategies that, for some iterations,
59 | approximates the fitness function to reduce the computational cost of training topic models.
60 |
61 |
62 |
63 |
64 |
65 | AutoTM also offers a range of metrics that can be used as the fitness function, from classical ones such as coherence to LLM-based metrics (**New in AutoTM 2.0**).
66 |
67 | ## Installation
68 |
69 | ! Note: Topic model training is available only on Linux distributions.
70 |
71 | **Via pip:**
72 |
73 | ```pip install autotm```
74 |
75 | ```python -m spacy download en_core_web_sm```
76 |
77 | **From source:**
78 |
79 | ```poetry install```
80 |
81 | ```python -m spacy download en_core_web_sm```
82 |
83 | [//]: # (## Dataset and )
84 |
85 | ## Quickstart
86 |
87 | Start with the notebook [Easy topic modeling with AutoTM](https://github.com/aimclub/AutoTM/blob/main/examples/demo_autotm.ipynb) or with the script [AutoTM configurations](https://github.com/aimclub/AutoTM/blob/main/examples/examples_autotm_fit_predict.py).
88 |
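A minimal Python sketch of the same fit/predict workflow is shown below. It is an illustration only: it assumes that the `AutoTM` estimator from `autotm/base.py` exposes a scikit-learn-like `fit_predict`, as used in `examples/examples_autotm_fit_predict.py`, and the parameter names shown here may differ from the actual signature.

```python
# Minimal sketch, not the authoritative API: the AutoTM constructor arguments below
# are assumptions; see examples/examples_autotm_fit_predict.py for the real usage.
import pandas as pd
from autotm.base import AutoTM

df = pd.read_csv("data/sample_corpora/sample_dataset_lenta.csv")

autotm = AutoTM(
    topic_count=20,               # assumed: number of topics to learn
    texts_column_name="text",     # assumed: column holding the raw documents
    working_dir_path="autotm_workdir",
)

# Train a topic model with a tuned learning strategy and get document-topic mixtures
mixtures = autotm.fit_predict(df)
print(mixtures.head())
```
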
89 | ## Running from the command line
90 |
91 | To fit a model:
92 | ```autotmctl --verbose fit --config conf/config.yaml --in data/sample_corpora/sample_dataset_lenta.csv```
93 |
94 | To predict with a fitted model:
95 | ```autotmctl predict --in data/sample_corpora/sample_dataset_lenta.csv --model model.artm```
96 |
97 |
98 | ## Citation
99 |
100 | ```bibtex
101 | @article{10.1093/jigpal/jzac019,
102 | author = {Khodorchenko, Maria and Butakov, Nikolay and Sokhin, Timur and Teryoshkin, Sergey},
103 | title = "{ Surrogate-based optimization of learning strategies for additively regularized topic models}",
104 | journal = {Logic Journal of the IGPL},
105 | year = {2022},
106 | month = {02},
107 | issn = {1367-0751},
108 | doi = {10.1093/jigpal/jzac019},
109 | url = {https://doi.org/10.1093/jigpal/jzac019},}
110 |
111 | ```
112 |
--------------------------------------------------------------------------------
/autotm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/__init__.py
--------------------------------------------------------------------------------
/autotm/abstract_params.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from typing import List
4 |
5 |
6 | class AbstractParams(ABC):
7 | @property
8 | @abstractmethod
9 | def basic_topics(self) -> int:
10 | ...
11 |
12 | @property
13 | @abstractmethod
14 | def mutation_probability(self):
15 | ...
16 |
17 | @abstractmethod
18 | def make_params_dict(self):
19 | ...
20 |
21 | @abstractmethod
22 | def run_train(self, model):
23 | """
24 | Trains the topic model
25 | :param model: an instance of TopicModel
26 | :return:
27 | """
28 | ...
29 |
30 | @abstractmethod
31 | def validate_params(self) -> bool:
32 | ...
33 |
34 | @abstractmethod
35 | def crossover(self, parent2: "AbstractParams", **kwargs) -> List["AbstractParams"]:
36 | ...
37 |
38 | @abstractmethod
39 | def mutate(self, **kwargs) -> "AbstractParams":
40 | ...
41 |
42 | @abstractmethod
43 | def to_vector(self) -> List[float]:
44 | ...
45 |
46 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/__init__.py
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/bayesian_optimization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/bayesian_optimization/__init__.py
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/genetic_algorithm/__init__.py
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/config.yaml:
--------------------------------------------------------------------------------
1 | appName: geneticAlgo
2 | logLevel: WARN
3 |
4 | testMode: False
5 |
6 | gaAlgoParams:
7 | numEvals: 20
8 |
9 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/crossover.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from typing import List, Callable
4 |
5 |
6 | def crossover_pmx(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]:
7 | """
8 | Pmx crossover
9 |
10 | Exchange chromosome parts
11 |
12 | Parameters
13 | ----------
14 | parent_1: List[float]
15 | The first individual to be processed
16 | parent_2: List[float]
17 | The second individual to be processed
18 |
19 | Returns
20 | ----------
21 | Updated individuals with exchanged chromosome parts
22 | """
23 | points_num = len(parent_1)
24 | while True:
25 | cut_ix = np.random.choice(points_num + 1, 2, replace=False)
26 | min_ix = np.min(cut_ix)
27 | max_ix = np.max(cut_ix)
28 | part = parent_1[min_ix:max_ix]
29 | if len(part) != len(parent_1):
30 | break
31 | parent_1[min_ix:max_ix] = parent_2[min_ix:max_ix]
32 | parent_2[min_ix:max_ix] = part
33 | return [parent_1, parent_2]
34 |
35 |
36 | # discrete crossover
37 | def crossover_one_point(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]:
38 | """
39 | One-point crossover
40 |
41 | Exchange points between chromosomes
42 |
43 | Parameters
44 | ----------
45 | parent_1: List[float]
46 | The first individual to be processed
47 | parent_2: List[float]
48 | The second individual to be processed
49 |
50 | Returns
51 | ----------
52 | Updated individuals with exchanged chromosome parts
53 | """
54 | elem_cross_prob = kwargs["elem_cross_prob"]
55 | for i in range(len(parent_1)):
56 | # removed mutation preservation
57 | if random.random() < elem_cross_prob:
58 | parent_1[i], parent_2[i] = parent_2[i], parent_1[i]
59 | return [parent_1, parent_2]
60 |
61 |
62 | def crossover_blend_new(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]:
63 | """
64 | Blend crossover
65 |
66 |     Makes a combination of the parents' solutions using a blending coefficient
67 |
68 | Parameters
69 | ----------
70 | parent_1: List[float]
71 | The first individual to be processed
72 | parent_2: List[float]
73 | The second individual to be processed
74 | alpha: float
75 | Blending coefficient
76 |
77 | Returns
78 | ----------
79 |         Two blended child solutions
80 | """
81 | alpha = kwargs["alpha"]
82 | child_1 = []
83 | child_2 = []
84 | u = random.random()
85 | gamma = (1.0 + 2.0 * alpha) * u - alpha # fixed (1. + 2. * alpha) * u - alpha
86 | for i in range(len(parent_1)):
87 | child_1.append((1.0 - gamma) * parent_1[i] + gamma * parent_2[i])
88 | child_2.append(gamma * parent_1[i] + (1.0 - gamma) * parent_2[i])
89 |
90 | # TODO: reconsider this
91 | child_1[12:15] = parent_1[12:15]
92 | child_2[12:15] = parent_2[12:15]
93 | return [child_1, child_2]
94 |
95 |
96 | def crossover_blend(parent_1: List[float], parent_2: List[float], **kwargs) -> List[List[float]]:
97 | """
98 | Blend crossover
99 |
100 |     Makes a combination of the parents' solutions using a blending coefficient
101 |
102 | Parameters
103 | ----------
104 | parent_1: List[float]
105 | The first individual to be processed
106 | parent_2: List[float]
107 | The second individual to be processed
108 | alpha: float
109 | Blending coefficient
110 |
111 | Returns
112 | ----------
113 |         A single blended child solution
114 | """
115 | alpha = kwargs["alpha"]
116 | child = []
117 | u = random.random()
118 | gamma = (1 - 2 * alpha) * u - alpha
119 | for i in range(len(parent_1)):
120 | child.append((1 - gamma) * parent_1[i] + gamma * parent_2[i])
121 | if random.random() > 0.5:
122 | child[12:15] = parent_1[12:15]
123 | else:
124 | child[12:15] = parent_2[12:15]
125 | return [child]
126 |
127 |
128 | def crossover(crossover_type: str = "crossover_one_point") -> Callable:
129 | """
130 | Crossover function
131 |
132 | Parameters
133 | ----------
134 | crossover_type : str, default="crossover_one_point"
135 | Crossover to be used in the genetic algorithm
136 | """
137 | if crossover_type == "crossover_pmx":
138 | return crossover_pmx
139 | if crossover_type == "crossover_one_point":
140 | return crossover_one_point
141 | if crossover_type == "blend_crossover":
142 | return crossover_blend
143 |
--------------------------------------------------------------------------------
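The following self-contained sketch illustrates how the `crossover` factory above can be applied to two candidate solutions; the 16-element parent vectors and the `elem_cross_prob`/`alpha` values are made up for demonstration only.

```python
# Illustrative usage of the crossover operators defined in crossover.py;
# the parent vectors and operator parameters below are example values only.
import random

from autotm.algorithms_for_tuning.genetic_algorithm.crossover import crossover

random.seed(42)

parent_1 = [float(i) for i in range(16)]
parent_2 = [float(i) * 10 for i in range(16)]

# Discrete (per-element) exchange controlled by elem_cross_prob
one_point = crossover("crossover_one_point")
child_a, child_b = one_point(list(parent_1), list(parent_2), elem_cross_prob=0.3)

# Blend crossover producing a single child as a linear combination of the parents
blend = crossover("blend_crossover")
(blend_child,) = blend(list(parent_1), list(parent_2), alpha=0.5)

print(child_a)
print(blend_child)
```
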
/autotm/algorithms_for_tuning/genetic_algorithm/mutation.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List
3 |
4 | import numpy as np
5 |
6 |
7 | def mutation_one_param(
8 | individ: List[float],
9 | low_spb: float,
10 | high_spb: float,
11 | low_spm: float,
12 | high_spm: float,
13 | low_n: int,
14 | high_n: int,
15 | low_back: int,
16 | high_back: int,
17 | low_decor: float,
18 | high_decor: float,
19 | elem_mutation_prob: float = 0.1,
20 | ):
21 | """
22 | One-point mutation
23 |
24 |     Checks the mutation probability for each element of the individual
25 |
26 | Parameters
27 | ----------
28 | individ: List[float]
29 | Individual to be processed
30 | low_spb: float
31 | The lower possible bound for sparsity regularizer of back topics
32 | high_spb: float
33 | The higher possible bound for sparsity regularizer of back topics
34 | low_spm: float
35 | The lower possible bound for sparsity regularizer of specific topics
36 | high_spm: float
37 | The higher possible bound for sparsity regularizer of specific topics
38 | low_n: int
39 | The lower possible bound for amount of iterations between stages
40 | high_n: int
41 | The higher possible bound for amount of iterations between stages
42 |     low_back: int
43 |         The lower possible bound for the number of background topics
44 |     high_back: int
45 |         The higher possible bound for the number of background topics
46 |     elem_mutation_prob: float, default=0.1
47 |         Probability of mutating each element of the individual
48 |     Returns
49 |     ----------
50 |         The mutated individual
51 | """
52 | for i in range(len(individ)):
53 | if random.random() <= elem_mutation_prob:
54 | if i in [2, 3]:
55 | individ[i] = np.random.uniform(low=low_spb, high=high_spb, size=1)[0]
56 |             elif i in [5, 6, 8, 9]:
57 |                 individ[i] = np.random.uniform(low=low_spm, high=high_spm, size=1)[0]
58 |             elif i in [1, 4, 7, 10]:
59 |                 individ[i] = float(np.random.randint(low=low_n, high=high_n, size=1)[0])
60 |             elif i == 11:
61 |                 individ[i] = float(np.random.randint(low=low_back, high=high_back, size=1)[0])
62 |             elif i in [0, 15]:
63 | individ[i] = np.random.uniform(low=low_decor, high=high_decor, size=1)[0]
64 | return individ
65 |
66 |
67 | def positioning_mutation(individ, elem_mutation_prob=0.1, **kwargs):
68 | for i in range(len(individ)):
69 | set_1 = set([i])
70 | if random.random() <= elem_mutation_prob:
71 | if i in [0, 15]:
72 | set_2 = set([0, 15])
73 | ix = np.random.choice(list(set_2.difference(set_1)))
74 | tmp = individ[ix]
75 | individ[ix] = individ[i]
76 | individ[i] = tmp
77 | elif i in [1, 4, 7, 10, 11]:
78 | set_2 = set([1, 4, 7, 10, 11])
79 | ix = np.random.choice(list(set_2.difference(set_1)))
80 | tmp = individ[ix]
81 | individ[ix] = individ[i]
82 | individ[i] = tmp
83 | elif i in [2, 3]:
84 | set_2 = set([2, 3])
85 | ix = np.random.choice(list(set_2.difference(set_1)))
86 | tmp = individ[ix]
87 | individ[ix] = individ[i]
88 | individ[i] = tmp
89 | elif i in [5, 6, 8, 9]:
90 | set_2 = set([5, 6, 8, 9])
91 | ix = np.random.choice(list(set_2.difference(set_1)))
92 | tmp = individ[ix]
93 | individ[ix] = individ[i]
94 | individ[i] = tmp
95 | return individ
96 |
97 |
98 | def mutation_combined(individ, elem_mutation_prob=0.1, **kwargs):
99 | if random.random() <= individ[14]: # TODO: check 14th position
100 | return mutation_one_param(individ, elem_mutation_prob)
101 | else:
102 | return positioning_mutation(individ, elem_mutation_prob)
103 | pass
104 |
105 |
106 | def do_swap_in_ranges(individ, i, ranges):
107 | swap_range = next((r for r in ranges if i in r), None)
108 | if swap_range is not None:
109 | swap_range = [j for j in swap_range if i != j]
110 | j = np.random.choice(swap_range)
111 | individ[i], individ[j] = individ[j], individ[i]
112 |
113 |
114 | PSM_NEW_SWAP_RANGES = [[1, 4, 7, 10], [2, 3], [5, 6, 8, 9]]
115 |
116 |
117 | def mutation_psm_new(individ, elem_mutation_prob, **kwargs):
118 | for i in range(len(individ)):
119 | if random.random() < elem_mutation_prob:
120 | if i == 0:
121 | individ[i] = individ[15]
122 | elif i == 15:
123 | individ[i] = individ[0]
124 | else:
125 | do_swap_in_ranges(individ, i, PSM_NEW_SWAP_RANGES)
126 | return individ
127 |
128 |
129 | PSM_SWAP_RANGES = [[1, 4, 7], [2, 5], [3, 6]]
130 |
131 |
132 | def mutation_psm(individ, elem_mutation_prob=0.1, **kwargs):
133 | for i in range(len(individ)):
134 | if random.random() < elem_mutation_prob:
135 | if i == 0:
136 | individ[i] = np.random.uniform(low=1, high=100, size=1)[0]
137 | else:
138 | do_swap_in_ranges(individ, i, PSM_SWAP_RANGES)
139 | return individ
140 |
141 |
142 | def mutation(mutation_type="mutation_one_param"):
143 | if mutation_type == "mutation_one_param":
144 | return mutation_one_param
145 | if mutation_type == "combined":
146 | return mutation_combined
147 | if mutation_type == "psm":
148 | return mutation_psm
149 | if mutation_type == "positioning_mutation":
150 | return positioning_mutation
151 |
--------------------------------------------------------------------------------
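Below is a small, self-contained sketch of the mutation operators above; the 16-element individual and all bound values are illustrative only (they roughly follow the default ranges used by the Nelder-Mead optimizer in this repository), not canonical settings.

```python
# Illustrative usage of the mutation operators defined in mutation.py;
# the individual and all bound values below are example values only.
import random

from autotm.algorithms_for_tuning.genetic_algorithm.mutation import mutation

random.seed(0)

# A 16-element hyperparameter vector, as assumed by the index groups in the operators
individ = [float(i) for i in range(16)]

mutate = mutation("mutation_one_param")
mutated = mutate(
    list(individ),
    low_spb=0.0, high_spb=100.0,      # sparsity bounds for background topics
    low_spm=-100.0, high_spm=100.0,   # sparsity bounds for specific topics
    low_n=1, high_n=30,               # iterations between stages
    low_back=1, high_back=5,          # number of background topics
    low_decor=0.0, high_decor=1e5,    # decorrelation bounds
    elem_mutation_prob=0.2,
)

swapped = mutation("positioning_mutation")(list(individ), elem_mutation_prob=0.2)

print(mutated)
print(swapped)
```
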
/autotm/algorithms_for_tuning/genetic_algorithm/selection.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import numpy as np
3 | import random
4 |
5 |
6 | # TODO: roulette wheel selection, stochastic universal sampling and tournament selection
7 |
8 |
9 | def yield_matching_pairs(pairs, population):
10 | population.sort(key=operator.attrgetter("fitness_value"))
11 | population_pairs_pool = []
12 |
13 | while len(population_pairs_pool) < pairs:
14 | chosen = []
15 | idx = 0
16 | selection_probability = random.random()
17 | for ix, individ in enumerate(population):
18 | if selection_probability <= individ._prob:
19 | idx = ix
20 | chosen.append(individ)
21 | break
22 |
23 | selection_probability = random.random()
24 | for k, individ in enumerate(population):
25 | if k != idx:
26 | if selection_probability <= individ._prob:
27 | elems = frozenset((idx, k))
28 | if (len(population_pairs_pool) == 0) or (
29 | (len(population_pairs_pool) > 0)
30 | and (elems not in population_pairs_pool)
31 | ):
32 | chosen.append(individ)
33 | population_pairs_pool.append(elems)
34 | break
35 | else:
36 | continue
37 | if len(chosen) == 1:
38 | selection_idx = np.random.choice(
39 | [m for m in [i for i in range(len(population))] if m != idx]
40 | )
41 | chosen.append(population[selection_idx])
42 | if len(chosen) == 0:
43 | yield None, None
44 | else:
45 | yield chosen[0], chosen[1]
46 |
47 |
48 | def selection_fitness_prop(population, best_proc, children_num):
49 | all_fitness = []
50 |
51 | for individ in population:
52 | all_fitness.append(individ.fitness_value)
53 | fitness_std = np.std(all_fitness)
54 | fitness_mean = np.mean(all_fitness)
55 | cumsum_fitness = 0
56 | # adjust probabilities with sigma scaling
57 | c = 2
58 | for individ in population:
59 | updated_individ_fitness = max(individ.fitness_value - (fitness_mean - c * fitness_std), 0)
60 | cumsum_fitness += updated_individ_fitness
61 | individ._prob = updated_individ_fitness / cumsum_fitness
62 | pairs_count = len(population) * (1 - best_proc)
63 | if children_num == 2:
64 | pairs_count //= 2
65 | return yield_matching_pairs(round(pairs_count), population)
66 |
67 |
68 | def selection_rank_based(population, best_proc, children_num):
69 | population.sort(key=operator.attrgetter("fitness_value"))
70 | for ix, individ in enumerate(population):
71 | individ._prob = 2 * (ix + 1) / (len(population) * (len(population) - 1))
72 | if children_num == 2:
73 | # new population size
74 | return yield_matching_pairs(
75 | round((len(population) * (1 - best_proc))), population
76 | )
77 | else:
78 | return yield_matching_pairs(
79 | round((len(population) * (1 - best_proc))), population
80 | )
81 |
82 |
83 | def stochastic_universal_sampling():
84 | raise NotImplementedError
85 |
86 |
87 | def selection(selection_type="fitness_prop"):
88 | if selection_type == "fitness_prop":
89 | return selection_fitness_prop
90 | if selection_type == "rank_based":
91 | return selection_rank_based
92 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/statistics_collector.py:
--------------------------------------------------------------------------------
1 | from autotm.algorithms_for_tuning.individuals import Individual
2 |
3 |
4 | class StatisticsCollector:
5 | """
6 | This logger handles collection of statistics
7 | """
8 |
9 | def log_iteration(self, evaluations: int, best_fitness: float):
10 | """
11 | :param evaluations: the number of used evaluations
12 | :param best_fitness: the best fitness in the current iteration
13 | """
14 | pass
15 |
16 | def log_individual(self, individual: Individual):
17 | """
18 | :param individual: a new evaluated individual
19 | """
20 | pass
21 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/strategy.py:
--------------------------------------------------------------------------------
1 | # strategies of automatic EA configuration
2 |
3 | # CMA-ES (Covariance Matrix Adaptation Evolution Strategy)
4 |
5 | from math import sqrt, log
6 |
7 | # code source: https://github.com/DEAP/deap/blob/master/deap/cma.py
8 | import numpy
9 |
10 |
11 | class Strategy(object):
12 | def __init__(self, centroid, sigma, **kwargs):
13 | self.params = kwargs
14 |
15 | # Create a centroid as a numpy array
16 | self.centroid = numpy.array(centroid)
17 |
18 | self.dim = len(self.centroid)
19 | self.sigma = sigma
20 | self.pc = numpy.zeros(self.dim)
21 | self.ps = numpy.zeros(self.dim)
22 | self.chiN = sqrt(self.dim) * (
23 | 1 - 1.0 / (4.0 * self.dim) + 1.0 / (21.0 * self.dim**2)
24 | )
25 |
26 | self.C = self.params.get("cmatrix", numpy.identity(self.dim))
27 | self.diagD, self.B = numpy.linalg.eigh(self.C)
28 |
29 | indx = numpy.argsort(self.diagD)
30 | self.diagD = self.diagD[indx] ** 0.5
31 | self.B = self.B[:, indx]
32 | self.BD = self.B * self.diagD
33 |
34 | self.cond = self.diagD[indx[-1]] / self.diagD[indx[0]]
35 |
36 | self.lambda_ = self.params.get("lambda_", int(4 + 3 * log(self.dim)))
37 | self.update_count = 0
38 | self.computeParams(self.params)
39 |
40 | def generate(self, ind_init):
41 | r"""Generate a population of :math:`\lambda` individuals of type
42 | *ind_init* from the current strategy.
43 | :param ind_init: A function object that is able to initialize an
44 | individual from a list.
45 | :returns: A list of individuals.
46 | """
47 | arz = numpy.random.standard_normal((self.lambda_, self.dim))
48 | arz = self.centroid + self.sigma * numpy.dot(arz, self.BD.T)
49 | return [ind_init(a) for a in arz]
50 |
51 | def update(self, population):
52 | """Update the current covariance matrix strategy from the
53 | *population*.
54 | :param population: A list of individuals from which to update the
55 | parameters.
56 | """
57 | population.sort(key=lambda ind: ind.fitness, reverse=True)
58 |
59 | old_centroid = self.centroid
60 | self.centroid = numpy.dot(self.weights, population[0: self.mu])
61 |
62 | c_diff = self.centroid - old_centroid
63 |
64 | # Cumulation : update evolution path
65 | self.ps = (1 - self.cs) * self.ps + sqrt(
66 | self.cs * (2 - self.cs) * self.mueff
67 | ) / self.sigma * numpy.dot(
68 | self.B, (1.0 / self.diagD) * numpy.dot(self.B.T, c_diff)
69 | )
70 |
71 | hsig = float(
72 | (
73 | numpy.linalg.norm(self.ps)
74 | / sqrt(1.0 - (1.0 - self.cs) ** (2.0 * (self.update_count + 1.0)))
75 | / self.chiN
76 | < (1.4 + 2.0 / (self.dim + 1.0))
77 | )
78 | )
79 |
80 | self.update_count += 1
81 |
82 | self.pc = (1 - self.cc) * self.pc + hsig * sqrt(
83 | self.cc * (2 - self.cc) * self.mueff
84 | ) / self.sigma * c_diff
85 |
86 | # Update covariance matrix
87 | artmp = population[0: self.mu] - old_centroid
88 | self.C = (
89 | (
90 | 1
91 | - self.ccov1
92 | - self.ccovmu
93 | + (1 - hsig) * self.ccov1 * self.cc * (2 - self.cc)
94 | )
95 | * self.C
96 | + self.ccov1 * numpy.outer(self.pc, self.pc)
97 | + self.ccovmu * numpy.dot((self.weights * artmp.T), artmp) / self.sigma**2
98 | )
99 |
100 | self.sigma *= numpy.exp(
101 | (numpy.linalg.norm(self.ps) / self.chiN - 1.0) * self.cs / self.damps
102 | )
103 |
104 | self.diagD, self.B = numpy.linalg.eigh(self.C)
105 | indx = numpy.argsort(self.diagD)
106 |
107 | self.cond = self.diagD[indx[-1]] / self.diagD[indx[0]]
108 |
109 | self.diagD = self.diagD[indx] ** 0.5
110 | self.B = self.B[:, indx]
111 | self.BD = self.B * self.diagD
112 |
113 | def computeParams(self, params):
114 | r"""Computes the parameters depending on :math:`\lambda`. It needs to
115 | be called again if :math:`\lambda` changes during evolution.
116 | :param params: A dictionary of the manually set parameters.
117 | """
118 | self.mu = params.get("mu", int(self.lambda_ / 2))
119 | rweights = params.get("weights", "superlinear")
120 | if rweights == "superlinear":
121 | self.weights = log(self.mu + 0.5) - numpy.log(numpy.arange(1, self.mu + 1))
122 | elif rweights == "linear":
123 | self.weights = self.mu + 0.5 - numpy.arange(1, self.mu + 1)
124 | elif rweights == "equal":
125 | self.weights = numpy.ones(self.mu)
126 | else:
127 | raise RuntimeError("Unknown weights : %s" % rweights)
128 |
129 | self.weights /= sum(self.weights)
130 | self.mueff = 1.0 / sum(self.weights**2)
131 |
132 | self.cc = params.get("ccum", 4.0 / (self.dim + 4.0))
133 | self.cs = params.get("cs", (self.mueff + 2.0) / (self.dim + self.mueff + 3.0))
134 | self.ccov1 = params.get("ccov1", 2.0 / ((self.dim + 1.3) ** 2 + self.mueff))
135 | self.ccovmu = params.get(
136 | "ccovmu",
137 | 2.0 * (self.mueff - 2.0 + 1.0 / self.mueff) / ((self.dim + 2.0) ** 2 + self.mueff),
138 | )
139 | self.ccovmu = min(1 - self.ccov1, self.ccovmu)
140 | self.damps = (
141 | 1.0 + 2.0 * max(0.0, sqrt((self.mueff - 1.0) / (self.dim + 1.0)) - 1.0) + self.cs
142 | )
143 | self.damps = params.get("damps", self.damps)
144 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/genetic_algorithm/surrogate.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | from sklearn.ensemble import RandomForestRegressor, BaggingRegressor
5 | from sklearn.gaussian_process import GaussianProcessRegressor
6 | from sklearn.gaussian_process.kernels import RBF, ConstantKernel, Matern, WhiteKernel, ExpSineSquared, RationalQuadratic
7 | from sklearn.metrics import mean_squared_error
8 | from sklearn.neural_network import MLPRegressor
9 | from sklearn.svm import SVR
10 | from sklearn.tree import DecisionTreeRegressor
11 |
12 | logger = logging.getLogger("GA_algo")
13 |
14 |
15 | # TODO: Add fitness type
16 | def set_surrogate_fitness(value, fitness_type="avg_coherence_score"):
17 | npmis = {
18 | "npmi_50": None,
19 | "npmi_15": None,
20 | "npmi_25": None,
21 | "npmi_50_list": None,
22 | }
23 | scores_dict = {
24 | fitness_type: value,
25 | "perplexityScore": None,
26 | "backgroundTokensRatioScore": None,
27 | "contrast": None,
28 | "purity": None,
29 | "kernelSize": None,
30 | "npmi_50_list": [None], # npmi_values_50_list,
31 | "npmi_50": None,
32 | "sparsity_phi": None,
33 | "sparsity_theta": None,
34 | "topic_significance_uni": None,
35 | "topic_significance_vacuous": None,
36 | "topic_significance_back": None,
37 | "switchP_list": [None],
38 | "switchP": None,
39 | "all_topics": None,
40 | # **coherence_scores,
41 | **npmis,
42 | }
43 | return scores_dict
44 |
45 |
46 | class Surrogate:
47 | def __init__(self, surrogate_name, **kwargs):
48 | self.name = surrogate_name
49 | self.kwargs = kwargs
50 | self.surrogate = None
51 | self.br_n_estimators = None
52 | self.br_n_jobs = None
53 | self.gpr_kernel = None
54 |
55 | def create(self):
56 | kernel = self.kwargs["gpr_kernel"]
57 | del self.kwargs["gpr_kernel"]
58 | gpr_alpha = self.kwargs["gpr_alpha"]
59 | del self.kwargs["gpr_alpha"]
60 | normalize_y = self.kwargs["normalize_y"]
61 | del self.kwargs["normalize_y"]
62 |
63 | if self.name == "random-forest-regressor":
64 | self.surrogate = RandomForestRegressor(**self.kwargs)
65 | elif self.name == "mlp-regressor":
66 | if not self.br_n_estimators:
67 | self.br_n_estimators = self.kwargs["br_n_estimators"]
68 | del self.kwargs["br_n_estimators"]
69 | self.br_n_jobs = self.kwargs["n_jobs"]
70 | del self.kwargs["n_jobs"]
71 | self.kwargs["alpha"] = self.kwargs["mlp_alpha"]
72 | del self.kwargs["mlp_alpha"]
73 | self.surrogate = BaggingRegressor(
74 | base_estimator=MLPRegressor(**self.kwargs),
75 | n_estimators=self.br_n_estimators,
76 | n_jobs=self.br_n_jobs,
77 | )
78 | elif self.name == "GPR": # tune ??
79 | if not self.gpr_kernel:
80 | if kernel == "RBF":
81 | self.gpr_kernel = 1.0 * RBF(1.0)
82 | elif kernel == "RBFwithConstant":
83 | self.gpr_kernel = 1.0 * RBF(1.0) + ConstantKernel()
84 | elif kernel == "Matern":
85 | self.gpr_kernel = 1.0 * Matern(1.0)
86 | elif kernel == "WhiteKernel":
87 | self.gpr_kernel = 1.0 * WhiteKernel(1.0)
88 | elif kernel == "ExpSineSquared":
89 | self.gpr_kernel = ExpSineSquared()
90 | elif kernel == "RationalQuadratic":
91 | self.gpr_kernel = RationalQuadratic(1.0)
92 | self.kwargs["kernel"] = self.gpr_kernel
93 | self.kwargs["alpha"] = gpr_alpha
94 | self.kwargs["normalize_y"] = normalize_y
95 | self.surrogate = GaussianProcessRegressor(**self.kwargs)
96 | elif self.name == "decision-tree-regressor":
97 | try:
98 | if self.kwargs["max_depth"] == 0:
99 | self.kwargs["max_depth"] = None
100 | except KeyError:
101 | logger.error("No max_depth")
102 | self.surrogate = DecisionTreeRegressor(**self.kwargs)
103 | elif self.name == "SVR":
104 | self.surrogate = SVR(**self.kwargs)
105 | # else:
106 | # raise Exception('Undefined surr')
107 |
108 | def fit(self, X, y):
109 | logger.debug(f"X: {X}, y: {y}")
110 | self.create()
111 | self.surrogate.fit(X, y)
112 |
113 | def score(self, X, y):
114 | r_2 = self.surrogate.score(X, y)
115 | y_pred = self.surrogate.predict(X)
116 | mse = mean_squared_error(y, y_pred)
117 | rmse = np.sqrt(mse)
118 | return r_2, mse, rmse
119 |
120 | def predict(self, X):
121 | m = self.surrogate.predict(X)
122 | return m
123 |
124 |
125 | def get_prediction_uncertanty(model, X, surrogate_name, percentile=90):
126 | interval_len = []
127 | if surrogate_name == "random-forest-regressor":
128 | for x in range(len(X)):
129 | preds = []
130 | for pred in model.estimators_:
131 | prediction = pred.predict(np.array(X[x]).reshape(1, -1))
132 | preds.append(prediction[0])
133 | err_down = np.percentile(preds, (100 - percentile) / 2.0)
134 | err_up = np.percentile(preds, 100 - (100 - percentile) / 2.0)
135 | interval_len.append(err_up - err_down)
136 | elif surrogate_name == "GPR":
137 | y_hat, y_sigma = model.predict(X, return_std=True)
138 | interval_len = list(y_sigma)
139 | elif surrogate_name == "decision-tree-regressor":
140 | raise NotImplementedError
141 | return interval_len
142 |
--------------------------------------------------------------------------------
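A short usage sketch of the `Surrogate` wrapper above follows. Note that `create()` unconditionally pops `gpr_kernel`, `gpr_alpha` and `normalize_y` from the keyword arguments, so they must be supplied even for non-GPR surrogates; the training data here is random and purely illustrative.

```python
# Illustrative usage of the Surrogate wrapper defined in surrogate.py;
# the training data and hyperparameter values below are made up for demonstration.
import numpy as np

from autotm.algorithms_for_tuning.genetic_algorithm.surrogate import Surrogate

rng = np.random.default_rng(0)
X = rng.uniform(size=(32, 16))   # 32 evaluated individuals, 16 hyperparameters each
y = rng.uniform(size=32)         # their (fake) fitness values

surrogate = Surrogate(
    "random-forest-regressor",
    gpr_kernel=None,             # popped by create() even for non-GPR surrogates
    gpr_alpha=1e-10,
    normalize_y=True,
    n_estimators=100,
)

surrogate.fit(X, y)
r2, mse, rmse = surrogate.score(X, y)
print(r2, mse, rmse)
print(surrogate.predict(X[:3]))
```
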
/autotm/algorithms_for_tuning/individuals.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from abc import ABC, abstractmethod
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from autotm.abstract_params import AbstractParams
9 | from autotm.schemas import IndividualDTO
10 | from autotm.utils import AVG_COHERENCE_SCORE, LLM_SCORE
11 |
12 | SPARSITY_PHI = "sparsity_phi"
13 | SPARSITY_THETA = "sparsity_theta"
14 | SWITCHP_SCORE = "switchP"
15 | DF_NAMES = {"20ng": 0, "lentaru": 1, "amazon_food": 2}
16 |
17 | METRICS_COLS = [
18 | "avg_coherence_score",
19 | "perplexityScore",
20 | "backgroundTokensRatioScore",
21 | "avg_switchp",
22 | "coherence_10",
23 | "coherence_15",
24 | "coherence_20",
25 | "coherence_25",
26 | "coherence_30",
27 | "coherence_35",
28 | "coherence_40",
29 | "coherence_45",
30 | "coherence_50",
31 | "coherence_55",
32 | "contrast",
33 | "purity",
34 | "kernelSize",
35 | "sparsity_phi",
36 | "sparsity_theta",
37 | "topic_significance_uni",
38 | "topic_significance_vacuous",
39 | "topic_significance_back",
40 | "npmi_15",
41 | "npmi_25",
42 | "npmi_50",
43 | ]
44 |
45 | PATH_TO_LEARNED_SCORING = "./scoring_func"
46 |
47 |
48 | class Individual(ABC):
49 | id: str
50 |
51 | @property
52 | @abstractmethod
53 | def dto(self) -> IndividualDTO:
54 | ...
55 |
56 | @property
57 | @abstractmethod
58 | def fitness_value(self) -> float:
59 | ...
60 |
61 | @property
62 | @abstractmethod
63 | def params(self) -> AbstractParams:
64 | ...
65 |
66 |
67 | class BaseIndividual(Individual, ABC):
68 | def __init__(self, dto: IndividualDTO):
69 | self._dto = dto
70 |
71 | @property
72 | def dto(self) -> IndividualDTO:
73 | return self._dto
74 |
75 | @property
76 | def params(self) -> AbstractParams:
77 | return self.dto.params
78 |
79 |
80 | class RegularFitnessIndividual(BaseIndividual):
81 | @property
82 | def fitness_value(self) -> float:
83 | return self.dto.fitness_value[AVG_COHERENCE_SCORE]
84 |
85 |
86 | class LearnedModel:
87 | def __init__(self, save_path, dataset_name):
88 | dataset_id = DF_NAMES[dataset_name]
89 | general_save_path = os.path.join(save_path, "general")
90 | native_save_path = os.path.join(save_path, "native")
91 | with open(
92 | os.path.join(general_save_path, f"general_automl_{dataset_id}.pickle"), "rb"
93 | ) as f:
94 | self.general_model = pickle.load(f)
95 | self.native_model = []
96 | for i in range(5):
97 | with open(
98 | os.path.join(
99 | native_save_path, f"native_automl_{dataset_id}_fold_{i}.pickle"
100 | ),
101 | "rb",
102 | ) as f:
103 | self.native_model.append(pickle.load(f))
104 |
105 | def general_predict(self, df: pd.DataFrame):
106 | y = self.general_model.predict(df[METRICS_COLS])
107 | return y
108 |
109 | def native_predict(self, df: pd.DataFrame):
110 | y = []
111 | for k, nm in enumerate(self.native_model):
112 | y.append(nm.predict(df[METRICS_COLS]))
113 | y = np.array(y)
114 | return np.mean(y, axis=0)
115 |
116 |
117 | class LearnedScorerFitnessIndividual(BaseIndividual):
118 | @property
119 | def fitness_value(self) -> float:
120 | # dataset_name = self.dto.dataset # TODO: check namings
121 | # m = LearnedModel(save_path=PATH_TO_LEARNED_SCORING, dataset_name=dataset_name)
122 | # TODO: predict from metrics df
123 | raise NotImplementedError()
124 |
125 |
126 | class SparsityScalerBasedFitnessIndividual(BaseIndividual):
127 | @property
128 | def fitness_value(self) -> float:
129 |         # handles the situation when a fitness worker wasn't able to correctly calculate this individual
130 |         # due to some error in the process,
131 |         # and thus the fitness value has no metrics except the dummy AVG_COHERENCE_SCORE equal to zero
132 | if self.dto.fitness_value[AVG_COHERENCE_SCORE] < 0.00000001:
133 | return 0.0
134 |
135 | alpha = 0.7
136 | if 0.2 <= self.dto.fitness_value[SPARSITY_THETA] <= 0.8:
137 | alpha = 1
138 | # if SWITCHP_SCORE in self.dto.fitness_value:
139 | # return alpha * (self.dto.fitness_value[AVG_COHERENCE_SCORE] + self.dto.fitness_value[SWITCHP_SCORE])
140 | # else:
141 | # return alpha * self.dto.fitness_value[AVG_COHERENCE_SCORE]
142 | return alpha * self.dto.fitness_value[AVG_COHERENCE_SCORE]
143 |
144 |
145 | class LLMBasedFitnessIndividual(BaseIndividual):
146 | @property
147 | def fitness_value(self) -> float:
148 | return self.dto.fitness_value.get(LLM_SCORE, 0.0)
149 |
150 |
151 | class IndividualBuilder:
152 | SUPPORTED_IND_TYPES = ["regular", "sparse", "llm"]
153 |
154 | def __init__(self, ind_type: str = "regular"):
155 | self._ind_type = ind_type
156 |
157 | if self._ind_type not in self.SUPPORTED_IND_TYPES:
158 | raise ValueError(f"Unsupported ind type: {self._ind_type}")
159 |
160 | @property
161 | def individual_type(self) -> str:
162 | return self._ind_type
163 |
164 | def make_individual(self, dto: IndividualDTO) -> Individual:
165 | if self._ind_type == "regular":
166 | return RegularFitnessIndividual(dto=dto)
167 | elif self._ind_type == "sparse":
168 | return SparsityScalerBasedFitnessIndividual(dto=dto)
169 | else:
170 | return LLMBasedFitnessIndividual(dto=dto)
171 |
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/nelder_mead_optimization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/algorithms_for_tuning/nelder_mead_optimization/__init__.py
--------------------------------------------------------------------------------
/autotm/algorithms_for_tuning/nelder_mead_optimization/nelder_mead.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | from scipy.optimize import minimize
5 |
6 | from autotm.fitness.tm import FitnessCalculatorWrapper
7 |
8 |
9 | class NelderMeadOptimization:
10 | def __init__(
11 | self,
12 | dataset,
13 | data_path,
14 | exp_id,
15 | topic_count,
16 | train_option,
17 | low_decor=0,
18 | high_decor=1e5,
19 | low_n=0,
20 | high_n=30,
21 | low_back=0,
22 | high_back=5,
23 | low_spb=0,
24 | high_spb=1e2,
25 | low_spm=-1e-3,
26 | high_spm=1e2,
27 | low_sp_phi=-1e3,
28 | high_sp_phi=1e3,
29 | low_prob=0,
30 | high_prob=1,
31 | ):
32 |         self.dataset = dataset
33 | self.data_path = data_path
34 | self.exp_id = exp_id
35 | self.topic_count = topic_count
36 | self.train_option = train_option
37 | self.high_decor = high_decor
38 | self.low_decor = low_decor
39 | self.low_n = low_n
40 | self.high_n = high_n
41 | self.low_back = low_back
42 | self.high_back = high_back
43 | self.high_spb = high_spb
44 | self.low_spb = low_spb
45 | self.low_spm = low_spm
46 | self.high_spm = high_spm
47 | self.low_sp_phi = low_sp_phi
48 | self.high_sp_phi = high_sp_phi
49 | self.low_prob = low_prob
50 | self.high_prob = high_prob
51 |
52 | def initialize_params(self):
53 | val_decor, val_decor_2 = np.random.uniform(low=self.low_decor, high=self.high_decor, size=2)
54 | var_n = np.random.randint(low=self.low_n, high=self.high_n, size=4)
55 | var_back = np.random.randint(low=self.low_back, high=self.high_back, size=1)[0]
56 | var_sm = np.random.uniform(low=self.low_spb, high=self.high_spb, size=2)
57 | var_sp = np.random.uniform(low=self.low_sp_phi, high=self.high_sp_phi, size=4)
58 | params = [
59 | val_decor,
60 | var_n[0],
61 | var_sm[0],
62 | var_sm[1],
63 | var_n[1],
64 | var_sp[0],
65 | var_sp[1],
66 | var_n[2],
67 | var_sp[2],
68 | var_sp[3],
69 | var_n[3],
70 | var_back,
71 | val_decor_2,
72 | ]
73 | params = [float(i) for i in params]
74 | return params
75 |
76 | def run_algorithm(self, num_iterations: int = 400, ini_point: list = None):
77 | fitness_calculator = FitnessCalculatorWrapper(
78 | self.dataset, self.data_path, self.topic_count, self.train_option
79 | )
80 |
81 | if ini_point is None:
82 | initial_point = self.initialize_params()
83 | else:
84 | assert len(ini_point) == 13
85 | logging.info(ini_point) # TODO: remove this
86 | ini_point = [float(i) for i in ini_point]
87 | initial_point = ini_point
88 |
89 | res = minimize(
90 | fitness_calculator.run,
91 | initial_point,
92 | bounds=[
93 | (self.low_decor, self.high_decor),
94 | (self.low_n, self.high_n),
95 | (self.low_spb, self.high_spb),
96 | (self.low_spb, self.high_spb),
97 | (self.low_n, self.high_n),
98 | (self.low_sp_phi, self.high_sp_phi),
99 | (self.low_sp_phi, self.high_sp_phi),
100 | (self.low_n, self.high_n),
101 | (self.low_sp_phi, self.high_sp_phi),
102 | (self.low_sp_phi, self.high_sp_phi),
103 | (self.low_n, self.high_n),
104 | (self.low_back, self.high_back),
105 | (self.low_decor, self.high_decor),
106 | ],
107 | method="Nelder-Mead",
108 | options={"return_all": True, "maxiter": num_iterations},
109 | )
110 |
111 | return res
112 |
--------------------------------------------------------------------------------
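A brief sketch of constructing the optimizer above and sampling an initial point follows; the constructor arguments are placeholders (an actual run additionally requires a preprocessed dataset so that `FitnessCalculatorWrapper` can evaluate candidate points), so `run_algorithm` is only indicated, not executed.

```python
# Illustrative construction of NelderMeadOptimization; the dataset name, paths and
# other arguments below are placeholders, not values shipped with the repository.
from autotm.algorithms_for_tuning.nelder_mead_optimization.nelder_mead import (
    NelderMeadOptimization,
)

optimizer = NelderMeadOptimization(
    dataset="my_dataset",               # placeholder dataset name
    data_path="path/to/preprocessed",   # placeholder path to preprocessed data
    exp_id="nm_demo",                   # placeholder experiment id
    topic_count=20,
    train_option="offline",             # assumed training option label
)

# Random starting point (13 values: decorrelation, iteration counts, sparsity bounds, ...)
initial_point = optimizer.initialize_params()
print(initial_point)

# An actual optimization run would train real topic models and take a long time:
# result = optimizer.run_algorithm(num_iterations=100, ini_point=initial_point)
```
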
/autotm/batch_vect_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from artm.batches_utils import BatchVectorizer, Batch
3 | import glob
4 | import random
5 |
6 |
7 | class SampleBatchVectorizer(BatchVectorizer):
8 | def __init__(self, sample_size=10, **kwargs):
9 | self.sample_size = sample_size
10 | super().__init__(**kwargs)
11 |
12 | def _parse_batches(self, data_weight, batches):
13 | if self._process_in_memory:
14 | self._model.master.import_batches(batches)
15 | self._batches_list = [batch.id for batch in batches]
16 | return
17 |
18 | data_paths, data_weights, target_folders = self._populate_data(
19 | data_weight, True
20 | )
21 | for data_p, data_w, target_f in zip(data_paths, data_weights, target_folders):
22 | if batches is None:
23 | batch_filenames = glob.glob(os.path.join(data_p, "*.batch"))
24 | if len(batch_filenames) < self.sample_size:
25 | self.sample_size = len(batch_filenames)
26 | batch_filenames = random.sample(batch_filenames, self.sample_size)
27 | self._batches_list += [Batch(filename) for filename in batch_filenames]
28 |
29 | if len(self._batches_list) < 1:
30 | raise RuntimeError("No batches were found")
31 |
32 | self._weights += [data_w for i in range(len(batch_filenames))]
33 | else:
34 | self._batches_list += [
35 | Batch(os.path.join(data_p, batch)) for batch in batches
36 | ]
37 | self._weights += [data_w for i in range(len(batches))]
38 |
--------------------------------------------------------------------------------
/autotm/clustering.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 |
4 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.cluster import KMeans
9 | from sklearn.manifold import TSNE
10 | from sklearn.preprocessing import StandardScaler
11 |
12 | warnings.filterwarnings('ignore')
13 |
14 |
15 | def cluster_phi(phi_df: pd.DataFrame, n_clusters=10, plot_img=True):
16 | _phi_df = copy.deepcopy(phi_df)
17 | y = _phi_df.index.values
18 | x = _phi_df.values
19 | standardized_x = StandardScaler().fit_transform(x)
20 | y_kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(standardized_x)
21 |
22 | if plot_img:
23 | tsne = TSNE(n_components=2).fit_transform(standardized_x)
24 | plt.scatter(tsne[:, 0], tsne[:, 1], c=y_kmeans.labels_, s=6, cmap='Spectral')
25 | plt.gca().set_aspect('equal', 'datalim')
26 | plt.colorbar(boundaries=np.arange(11) - 0.5).set_ticks(np.arange(10))
27 | plt.title('Scatterplot of lenta.ru data', fontsize=24)
28 |
29 | _phi_df['labels'] = y_kmeans.labels_
30 | return _phi_df
31 |
32 |
33 |
--------------------------------------------------------------------------------
/autotm/content_splitter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class BaseTextSplitter(ABC):
5 |
6 | def text_process(self):
7 | raise NotImplementedError
8 |
9 | def transform(self):
10 | raise NotImplementedError
11 |
12 | def text_split(self):
13 | raise NotImplementedError
14 |
15 | class TextSplitter(BaseTextSplitter):
16 | def content_splitting(self):
17 | raise NotImplementedError
18 |
19 |
--------------------------------------------------------------------------------
/autotm/data_generator/synthetic_data_generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "outputs": [],
7 | "source": [
8 | "import numpy as np"
9 | ],
10 | "metadata": {
11 | "collapsed": false
12 | }
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "outputs": [],
18 | "source": [
19 | "def data_generator(\n",
20 | " vocab_size=10000,\n",
21 | " num_documents=5000,\n",
22 | " mean_text_len=15,\n",
23 | " std_text_len=5,\n",
24 | " topics_num=10,\n",
25 | " distribution=\"dirichlet\",\n",
26 | "):\n",
27 | " vocabulary = [i for i in range(vocab_size)]\n",
28 | "\n",
29 | " for doc in range(num_documents):\n",
30 | " doc_len = int(np.random.normal(loc=mean_text_len, scale=std_text_len))"
31 | ],
32 | "metadata": {
33 | "collapsed": false
34 | }
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "outputs": [],
40 | "source": [],
41 | "metadata": {
42 | "collapsed": false
43 | }
44 | }
45 | ],
46 | "metadata": {
47 | "kernelspec": {
48 | "display_name": "Python 3",
49 | "language": "python",
50 | "name": "python3"
51 | },
52 | "language_info": {
53 | "codemirror_mode": {
54 | "name": "ipython",
55 | "version": 2
56 | },
57 | "file_extension": ".py",
58 | "mimetype": "text/x-python",
59 | "name": "python",
60 | "nbconvert_exporter": "python",
61 | "pygments_lexer": "ipython2",
62 | "version": "2.7.6"
63 | }
64 | },
65 | "nbformat": 4,
66 | "nbformat_minor": 0
67 | }
68 |
--------------------------------------------------------------------------------
/autotm/fitness/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # local or cluster
4 | SUPPORTED_EXEC_MODES = ['local', 'cluster']
5 | AUTOTM_EXEC_MODE = os.environ.get("AUTOTM_EXEC_MODE", "local")
6 |
7 |
8 | # head or worker
9 | SUPPORTED_COMPONENTS = ['head', 'worker']
10 | AUTOTM_COMPONENT = os.environ.get("AUTOTM_COMPONENT", "head")
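# Usage sketch (illustrative): both variables are read at import time, so set them
# in the environment before importing autotm.fitness, e.g.
#   export AUTOTM_EXEC_MODE=cluster
#   export AUTOTM_COMPONENT=worker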
11 |
--------------------------------------------------------------------------------
/autotm/fitness/external_scores.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # additional scores to calculate
5 |
6 |
7 | # Topic Significance
8 | def kl_divergence(p, q):
9 | return np.sum(np.where(p != 0, p * np.log(p / q), 0))
10 |
11 |
12 | # Uniform Distribution Over Words (W-Uniform)
13 | def ts_uniform(topic_word_dist):
14 | n_words = topic_word_dist.shape[0]
15 | w_uniform = np.ones(n_words) / n_words
16 | uniform_distances_kl = [kl_divergence(p, w_uniform) for p in topic_word_dist.T]
17 | return uniform_distances_kl
18 |
19 |
20 | # Vacuous Semantic Distribution (W-Vacuous)
21 | def ts_vacuous(doc_topic_dist, topic_word_dist, total_tokens):
22 | n_words = topic_word_dist.shape[0]
23 | # n_tokens = np.sum([len(text) for text in texts])
24 | p_k = np.sum(doc_topic_dist, axis=1) / total_tokens
25 | w_vacuous = np.sum(topic_word_dist * np.tile(p_k, (n_words, 1)), axis=1)
26 | vacuous_distances_kl = [kl_divergence(p, w_vacuous) for p in topic_word_dist.T]
27 | return vacuous_distances_kl
28 |
29 |
30 | # Background Distribution (D-BGround)
31 | def ts_bground(doc_topic_dist):
32 | n_documents = doc_topic_dist.shape[1]
33 | d_bground = np.ones(n_documents) / n_documents
34 | d_bground_distances_kl = [kl_divergence(p.T, d_bground) for p in doc_topic_dist]
35 | return d_bground_distances_kl
36 |
37 |
38 | # SwitchP
39 | def switchp(phi, texts):
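"""SwitchP: map every word to its most probable topic (argmax over phi) and, for each
document, compute the share of adjacent word pairs whose assigned topics differ."""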
40 | words = phi.index.to_list()
41 | max_topic_word_dist = np.argmax(phi.to_numpy(), axis=1)
42 | max_topic_word_dist = dict(zip(words, max_topic_word_dist))
43 | switchp_scores = []
44 | for text in texts:
45 | mapped_text = [
46 | max_topic_word_dist[word]
47 | for word in text.split()
48 | if word in max_topic_word_dist
49 | ]
50 | switches = (np.diff(mapped_text) != 0).sum()
51 | switchp_scores.append(switches / (len(mapped_text) - 1))
52 |
53 | return switchp_scores
54 |
--------------------------------------------------------------------------------
/autotm/fitness/local_tasks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | from autotm.algorithms_for_tuning.individuals import Individual, IndividualBuilder
5 | from autotm.fitness.tm import fit_tm_of_individual
6 | from autotm.params_logging_utils import log_params_and_artifacts, log_stats, model_files
7 | from autotm.schemas import IndividualDTO
8 |
9 | logger = logging.getLogger("root")
10 |
11 |
12 | def do_fitness_calculating(
13 | individual: IndividualDTO,
14 | log_artifact_and_parameters: bool = False,
15 | log_run_stats: bool = False,
16 | alg_args: Optional[str] = None,
17 | is_tmp: bool = False,
18 | ) -> IndividualDTO:
19 | # make copy
20 | individual_json = individual.model_dump_json()
21 | individual = IndividualDTO.model_validate_json(individual_json)
22 | logger.info("Calculating fitness for individual %s" % individual_json)
23 |
24 | with fit_tm_of_individual(
25 | dataset=individual.dataset,
26 | data_path=individual.data_path,
27 | params=individual.params,
28 | fitness_name=individual.fitness_name,
29 | topic_count=individual.topic_count,
30 | force_dataset_settings_checkout=individual.force_dataset_settings_checkout,
31 | train_option=individual.train_option,
32 | ) as (time_metrics, metrics, tm):
33 | individual.fitness_value = metrics
34 |
35 | with model_files(tm) as tm_files:
36 | if log_artifact_and_parameters:
37 | log_params_and_artifacts(
38 | tm, tm_files, individual, time_metrics, alg_args, is_tmp=is_tmp
39 | )
40 |
41 | if log_run_stats:
42 | log_stats(tm, tm_files, individual, time_metrics, alg_args)
43 |
44 | return individual
45 |
46 |
47 | def calculate_fitness(
48 | individual: IndividualDTO,
49 | log_artifact_and_parameters: bool = False,
50 | log_run_stats: bool = False,
51 | alg_args: Optional[str] = None,
52 | is_tmp: bool = False,
53 | ) -> IndividualDTO:
54 | return do_fitness_calculating(
55 | individual, log_artifact_and_parameters, log_run_stats, alg_args, is_tmp
56 | )
57 |
58 |
59 | def estimate_fitness(ibuilder: IndividualBuilder, population: List[Individual]) -> List[Individual]:
60 | logger.info("Calculating fitness...")
61 | population_with_fitness = []
62 | for individual in population:
63 | if individual.dto.fitness_value is not None:
64 | logger.info("Fitness value already calculated")
65 | population_with_fitness.append(individual)
66 | continue
67 | individ_with_fitness = calculate_fitness(individual.dto)
68 | population_with_fitness.append(ibuilder.make_individual(individ_with_fitness))
69 | logger.info("The fitness results have been obtained")
70 | return population_with_fitness
71 |
72 |
73 | def log_best_solution(
74 | ibuilder: IndividualBuilder,
75 | individual: Individual,
76 | wait_for_result_timeout: Optional[float] = None,
77 | alg_args: Optional[str] = None,
78 | is_tmp: bool = False,
79 | ):
80 | logger.info("Sending a best individual to be logged")
81 | res = ibuilder.make_individual(calculate_fitness(individual.dto,
82 | log_artifact_and_parameters=True,
83 | is_tmp=is_tmp))
84 |
85 | # TODO: write logging
86 | return res
87 |
--------------------------------------------------------------------------------
/autotm/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pprint
4 | import tempfile
5 | from contextlib import contextmanager
6 | from typing import Optional
7 |
8 | import click
9 | import pandas as pd
10 | import yaml
11 |
12 | from autotm.base import AutoTM
13 |
14 | # TODO: add proper logging format initialization
15 | logging.basicConfig(level=logging.INFO)
16 | logger = logging.getLogger()
17 |
18 |
19 | def obtain_autotm_params(
20 | config_path: str,
21 | topic_count: Optional[int],
22 | lang: Optional[str],
23 | alg: Optional[str],
24 | surrogate_alg: Optional[str],
25 | log_file: Optional[str]
26 | ):
27 | # TODO: define clearly format of config
28 | if config_path is not None:
29 | logger.info(f"Reading config from path: {os.path.abspath(config_path)}")
30 | with open(config_path, "r") as f:
31 | config = yaml.safe_load(f)
32 | else:
33 | config = dict()
34 |
35 | if topic_count is not None:
36 | config['topic_count'] = topic_count
37 | if lang is not None:
38 | pp = config.get('preprocessing_params', dict())
39 | pp['lang'] = lang
40 | config['preprocessing_params'] = pp
41 | if alg is not None:
42 | config['alg_name'] = alg
43 | if surrogate_alg is not None:
44 | config['surrogate_alg_name'] = surrogate_alg
45 | if log_file is not None:
46 | config['log_file_path'] = log_file
47 |
48 | return config
49 |
50 |
51 | @contextmanager
52 | def prepare_working_dir(working_dir: Optional[str] = None):
53 | if working_dir is None:
54 | with tempfile.TemporaryDirectory(prefix="autotm_wd_") as tmp_working_dir:
55 | yield tmp_working_dir
56 | else:
57 | yield working_dir
58 |
59 |
60 | @click.group()
61 | @click.option('-v', '--verbose', is_flag=True, show_default=True, default=False, help="Verbose output")
62 | def cli(verbose: bool):
63 | if verbose:
64 | logging.basicConfig(level=logging.DEBUG, force=True)
65 |
66 |
67 | @cli.command()
68 | @click.option('--config', 'config_path', type=str, help="A path to config for fitting the model")
69 | @click.option(
70 | '--working-dir',
71 | type=str,
72 | help="A path to a working directory used by AutoTM for storing intermediate files. "
73 | "If not specified, a temporary directory will be created "
74 | "and deleted upon successful completion."
75 | )
76 | @click.option('--in', 'in_', type=str, required=True, help="A file in csv format with a text corpus to build the model on")
77 | @click.option(
78 | '--out',
79 | type=str,
80 | default="mixtures.csv",
81 | help="A path to a file in csv format that will contain topic mixtures for texts from the incoming corpus"
82 | )
83 | @click.option('--model', type=str, default='model.artm', help="A path that will contain fitted ARTM model")
84 | @click.option('-t', '--topic-count', type=int, help="Number of topics to fit model with")
85 | @click.option('--lang', type=str, help='Language of the dataset')
86 | @click.option('--alg', type=str, help="Hyperparameters tuning algorithm. Available: ga, bayes")
87 | @click.option('--surrogate-alg', type=str, help="Surrogate algorithm to use.")
88 | @click.option('--log-file', type=str, help="Log file path")
89 | @click.option(
90 | '--overwrite',
91 | is_flag=True,
92 | show_default=True,
93 | default=False,
94 | help="Overwrite if model or/and mixture files already exist"
95 | )
96 | def fit(
97 | config_path: Optional[str],
98 | working_dir: Optional[str],
99 | in_: str,
100 | out: str,
101 | model: str,
102 | topic_count: Optional[int],
103 | lang: Optional[str],
104 | alg: Optional[str],
105 | surrogate_alg: Optional[str],
106 | log_file: Optional[str],
107 | overwrite: bool
108 | ):
109 | config = obtain_autotm_params(config_path, topic_count, lang, alg, surrogate_alg, log_file)
110 |
111 | logger.debug(f"Running AutoTM with params: {pprint.pformat(config, indent=4)}")
112 |
113 | logger.info(f"Loading data from {os.path.abspath(in_)}")
114 | df = pd.read_csv(in_)
115 |
116 | with prepare_working_dir(working_dir) as work_dir:
117 | logger.info(f"Using working directory {os.path.abspath(work_dir)} for AutoTM")
118 | autotm = AutoTM(
119 | **config,
120 | working_dir_path=work_dir
121 | )
122 | mixtures = autotm.fit_predict(df)
123 |
124 | logger.info(f"Saving model to {os.path.abspath(model)}")
125 | autotm.save(model, overwrite=overwrite)
126 | logger.info(f"Saving mixtures to {os.path.abspath(out)}")
127 | mixtures.to_csv(out, mode='w' if overwrite else 'x')
128 |
129 | logger.info("Finished AutoTM")
130 |
131 |
132 | @cli.command()
133 | @click.option('--model', type=str, required=True, help="A path to fitted saved ARTM model")
134 | @click.option('--in', 'in_', type=str, required=True, help="A file in csv format with a text corpus to compute topic mixtures for")
135 | @click.option(
136 | '--out',
137 | type=str,
138 | default="mixtures.csv",
139 | help="A path to a file in csv format that will contain topic mixtures for texts from the incoming corpus"
140 | )
141 | @click.option(
142 | '--overwrite',
143 | is_flag=True,
144 | show_default=True,
145 | default=False,
146 | help="Overwrite if the mixture file already exists"
147 | )
148 | def predict(model: str, in_: str, out: str, overwrite: bool):
149 | logger.info(f"Loading model from {os.path.abspath(model)}")
150 | autotm_loaded = AutoTM.load(model)
151 |
152 | logger.info(f"Predicting mixtures for data from {os.path.abspath(in_)}")
153 | df = pd.read_csv(in_)
154 | mixtures = autotm_loaded.predict(df)
155 |
156 | logger.info(f"Saving mixtures to {os.path.abspath(out)}")
157 | mixtures.to_csv(out, mode='w' if overwrite else 'x')
158 |
159 |
160 | if __name__ == "__main__":
161 | cli()
162 |
--------------------------------------------------------------------------------
/autotm/main_fitness_worker.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 |
6 | def main():
7 | os.environ['AUTOTM_COMPONENT'] = 'worker'
8 | os.environ['AUTOTM_EXEC_MODE'] = 'cluster'
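# Note: these variables must be set before the imports below, because
# autotm.fitness reads AUTOTM_EXEC_MODE and AUTOTM_COMPONENT at import time.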
9 |
10 | from autotm.fitness.cluster_tasks import make_celery_app
11 | from autotm.fitness.tm import TopicModelFactory
12 |
13 | if "DATASETS_CONFIG" in os.environ:
14 | with open(os.environ["DATASETS_CONFIG"], "r") as f:
15 | config = yaml.safe_load(f)
16 | dataset_settings = config["datasets"]
17 | else:
18 | dataset_settings = None
19 |
20 | TopicModelFactory.init_factory_settings(
21 | num_processors=os.getenv("NUM_PROCESSORS", None),
22 | dataset_settings=dataset_settings
23 | )
24 |
25 | celery_app = make_celery_app()
26 | celery_app.worker_main()
27 |
28 |
29 | if __name__ == '__main__':
30 | main()
31 |
--------------------------------------------------------------------------------
/autotm/make-recovery-file.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import datetime
3 | import logging
4 | import os
5 | import shutil
6 | import sys
7 | from glob import glob
8 |
9 | logger = logging.getLogger()
10 |
11 |
12 | # example: irace-log-2021-07-19T21:48:00-2b5faba3-99a8-47ba-a5cd-2c99e6877160.Rdata
13 | def fpath_to_datetime(fpath: str) -> datetime.datetime:
14 | fname = os.path.basename(fpath)
15 | dt_str = fname[len("irace-log-"): len("irace-log-") + len("YYYY-mm-ddTHH:MM:ss")]
16 | dt = datetime.datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S")
17 | return dt
18 |
19 |
20 | if __name__ == "__main__":
21 | save_dir = sys.argv[1]
22 | logger.info(f"Looking for data files in {save_dir}")
23 | data_files = [
24 | (fpath, fpath_to_datetime(fpath)) for fpath in glob(f"{save_dir}/*.Rdata")
25 | ]
26 |
27 | if len(data_files) == 0:
28 | logger.info(
29 | f"No data files have been found in {save_dir}. It is a normal situation. Interrupting execution."
30 | )
31 | sys.exit(0)
32 |
33 | last_data_filepath, _ = max(data_files, key=lambda x: x[1])
34 | recovery_filepath = os.path.join(save_dir, "recovery_rdata.checkpoint")
35 |
36 | logger.info(
37 | f"Copying last generated data file ({last_data_filepath}) as recovery file ({recovery_filepath})"
38 | )
39 | shutil.copy(last_data_filepath, recovery_filepath)
40 | logger.info("Copying done")
41 |
--------------------------------------------------------------------------------
/autotm/pipeline.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List, Union
3 |
4 | from pydantic import BaseModel
5 |
6 |
7 | class IntRangeDistribution(BaseModel):
8 | low: int
9 | high: int
10 |
11 | def create_value(self) -> int:
12 | return random.randint(self.low, self.high)
13 |
14 | def clip(self, value) -> int:
15 | value = int(value)
16 | return int(max(self.low, min(self.high, value)))
17 |
18 |
19 | class FloatRangeDistribution(BaseModel):
20 | low: float
21 | high: float
22 |
23 | def create_value(self) -> float:
24 | return random.uniform(self.low, self.high)
25 |
26 | def clip(self, value) -> float:
27 | return max(self.low, min(self.high, value))
28 |
29 |
30 | class Param(BaseModel):
31 | """
32 | Single parameter of a stage.
33 | Distribution can be unavailable after serialisation.
34 | """
35 | name: str
36 | distribution: Union[IntRangeDistribution, FloatRangeDistribution]
37 |
38 | def create_value(self):
39 | if self.distribution is None:
40 | raise ValueError("Distribution is unavailable. One must restore the distribution after serialisation.")
41 | return self.distribution.create_value()
42 |
43 |
44 | class StageType(BaseModel):
45 | """
46 | Stage template that defines params for a stage.
47 | See create_stage generator.
48 | """
49 | name: str
50 | params: List[Param]
51 |
52 |
53 | class Stage(BaseModel):
54 | """
55 | Stage instance with parameter values.
56 | """
57 | stage_type: StageType
58 | values: List
59 |
60 | def __init__(self, **data):
61 | """
62 | :param stage_type: back reference to the template for parameter mutation
63 | :param values: instance's parameter values
64 | """
65 | super().__init__(**data)
66 | if len(self.stage_type.params) != len(self.values):
67 | raise ValueError("Number of values does not match number of parameters.")
68 |
69 | def __str__(self):
70 | return f"{self.stage_type.name}{self.values}"
71 |
72 | def clip_values(self):
73 | self.values = [param.distribution.clip(value) for param, value in zip(self.stage_type.params, self.values)]
74 |
75 |
76 | def create_stage(stage_type: StageType) -> Stage:
77 | return Stage(stage_type=stage_type, values=[param.create_value() for param in stage_type.params])
78 |
79 |
80 | class Pipeline(BaseModel):
81 | """
82 | List of stages that can be mutated.
83 | """
84 | stages: List[Stage]
85 | required_params: Stage
86 |
87 | def __str__(self):
88 | return f'{str(self.required_params)} {" ".join(map(str, self.stages))}'
89 |
90 | def random_stage_index(self, with_last: bool = False):
91 | last = len(self.stages)
92 | if with_last:
93 | last += 1
94 | return random.randint(0, last - 1)
95 |
96 | def __lt__(self, other):
97 | # important for sort method usage
98 | return len(self.stages) < len(other.stages)
99 |
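
# A minimal usage sketch (illustrative only; the stage and parameter names below are
# hypothetical and not taken from AutoTM's predefined configurations):
#
#   tau = Param(name="tau", distribution=FloatRangeDistribution(low=0.0, high=100.0))
#   smooth = StageType(name="smooth", params=[tau])
#   stage = create_stage(smooth)  # draws a random value for every parameter
#   stage.clip_values()           # clamps the values back into their allowed ranges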
--------------------------------------------------------------------------------
/autotm/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | # former 'ppp.csv'
2 | PREPOCESSED_DATASET_FILENAME = "dataset_processed.csv"
3 | RESERVED_TUPLE = ("_SERVICE_", "total_pairs_count")
4 |
--------------------------------------------------------------------------------
/autotm/schemas.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from typing import Union
3 |
4 | from pydantic import BaseModel
5 |
6 | from autotm.params import PipelineParams, FixedListParams
7 | from autotm.utils import MetricsScores
8 |
9 | AnyParams = Union[FixedListParams, PipelineParams]
10 |
11 |
12 | class IndividualDTO(BaseModel):
13 | id: str
14 | data_path: str
15 | params: AnyParams
16 | fitness_name: str = "regular"
17 | dataset: str = "default"
18 | force_dataset_settings_checkout: bool = False
19 | fitness_value: Optional[MetricsScores] = None
20 | exp_id: Optional[int] = None
21 | alg_id: Optional[str] = None
22 | tag: Optional[str] = None
23 | iteration_id: int = 0
24 | topic_count: Optional[int] = None
25 | train_option: str = "offline"
26 |
27 | class Config:
28 | arbitrary_types_allowed = True
29 |
30 | def make_params_dict(self):
31 | return self.params.make_params_dict()
32 |
33 |
--------------------------------------------------------------------------------
/autotm/visualization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/autotm/visualization/__init__.py
--------------------------------------------------------------------------------
/cluster/bin/fitnessctl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -ex
4 |
5 | if [[ -z "${KUBE_NAMESPACE}" ]]
6 | then
7 | kubectl_args=""
8 | else
9 | kubectl_args="-n ${KUBE_NAMESPACE}"
10 | fi
11 |
12 | release_name="autotm"
13 |
14 | registry="node2.bdcl:5000"
15 | mlflow_image="${registry}/mlflow-webserver:latest"
16 | flower_image="${registry}/flower:latest"
17 | fitness_worker_image="${registry}/fitness-worker:latest"
18 | jupyter_image="${registry}/autotm-jupyter:latest"
19 |
20 | base_dir="./cluster"
21 | docker_files_dir="${base_dir}/docker/"
22 | chart_path="${base_dir}/charts/autotm"
23 |
24 | function build_app() {
25 | echo "Building app..."
26 |
27 | poetry export --without-hashes > requirements.txt
28 | poetry build
29 |
30 | echo "Finished app building"
31 | }
32 |
33 | function build_images(){
34 | echo "Building images..."
35 |
36 | build_app
37 |
38 | docker build -f "${docker_files_dir}/mlflow-webserver.dockerfile" -t ${mlflow_image} .
39 | docker build -f "${docker_files_dir}/flower.dockerfile" -t ${flower_image} .
40 | docker build -f "${docker_files_dir}/worker.dockerfile" -t ${fitness_worker_image} .
41 | docker build -f "${docker_files_dir}/jupyter.dockerfile" -t ${jupyter_image} .
42 |
43 | echo "Finished images building"
44 | }
45 |
46 | function push_images() {
47 | echo "Pushing images..."
48 |
49 | docker push ${mlflow_image}
50 | docker push ${flower_image}
51 | docker push ${fitness_worker_image}
52 | docker push ${jupyter_image}
53 |
54 | echo "Finished pushing images"
55 | }
56 |
57 | function install() {
58 | echo "Installing images..."
59 |
60 | build_images
61 | push_images
62 |
63 | echo "Finished installing images"
64 | }
65 |
66 | function create_pv() {
67 | kubectl ${kubectl_args} apply -f "${base_dir}/conf/pv.yaml"
68 | }
69 |
70 | function delete_pv() {
71 | kubectl ${kubectl_args} delete -f "${base_dir}/conf/pv.yaml" --ignore-not-found
72 | }
73 |
74 | function recreate_pv() {
75 | delete_pv
76 | create_pv
77 | }
78 |
79 | function install_chart() {
80 | helm install \
81 | --create-namespace --namespace="${release_name}" \
82 | --set=autotm_prefix='autotm' --set=mongo_enabled="false" \
83 | ${release_name} ${chart_path}
84 | }
85 |
86 | function uninstall_chart() {
87 | helm uninstall autotm --ignore-not-found --wait
88 | }
89 |
90 | function upgrade_chart() {
91 | # helm upgrade --namespace="${release_name}" --reuse-values autotm ${chart_path}
92 | helm upgrade --namespace="${release_name}" autotm ${chart_path}
93 | }
94 |
95 | function dry_run_chart() {
96 | helm install \
97 | --create-namespace --namespace=${release_name} \
98 | --dry-run=server \
99 | --set=autotm_prefix='autotm' --set=mongo_enabled="false" \
100 | ${release_name} ${chart_path} --wait
101 | }
102 |
103 | function help() {
104 | echo "
105 | Supported env variables:
106 | KUBE_NAMESPACE - a kubernetes namespace to make actions in
107 |
108 | List of commands.
109 | build-app - build the app as a .whl distribution
110 | build-images - build all required docker images
111 | push-images - push all required docker images to the private registry on node2.bdcl
112 | install-images - build-images and push-images
113 | create-pv - creates persistent volumes required for functioning of the chart
114 | install-chart - install autotm deployment
115 | uninstall-chart - uninstall autotm deployment
116 | upgrade-chart - upgrade autotm deployment
117 | help - prints this message
118 | "
119 | }
120 |
121 | function main () {
122 | cmd="$1"
123 |
124 | if [ -z "${cmd}" ]
125 | then
126 | echo "No command is provided."
127 | help
128 | exit 1
129 | fi
130 |
131 | shift 1
132 |
133 | echo "Executing command: ${cmd}"
134 |
135 | case "${cmd}" in
136 | "build-app")
137 | build_app
138 | ;;
139 |
140 | "build-images")
141 | build_images
142 | ;;
143 |
144 | "push-images")
145 | push_images
146 | ;;
147 |
148 | "install-images")
149 | install
150 | ;;
151 |
152 | "create-pv")
153 | create_pv
154 | ;;
155 |
156 | "delete-pv")
157 | delete_pv
158 | ;;
159 |
160 | "recreate-pv")
161 | recreate_pv
162 | ;;
163 |
164 | "install-chart")
165 | install_chart
166 | ;;
167 |
168 | "uninstall-chart")
169 | uninstall_chart
170 | ;;
171 |
172 | "upgrade-chart")
173 | upgrade_chart
174 | ;;
175 |
176 | "dry-run-chart")
177 | dry_run_chart
178 | ;;
179 |
180 | "help")
181 | help
182 | ;;
183 |
184 | *)
185 | echo "Unknown command: ${cmd}"
186 | ;;
187 |
188 | esac
189 | }
190 |
191 | main "${@}"
192 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/.helmignore:
--------------------------------------------------------------------------------
1 | # Patterns to ignore when building packages.
2 | # This supports shell glob matching, relative path matching, and
3 | # negation (prefixed with !). Only one pattern per line.
4 | .DS_Store
5 | # Common VCS dirs
6 | .git/
7 | .gitignore
8 | .bzr/
9 | .bzrignore
10 | .hg/
11 | .hgignore
12 | .svn/
13 | # Common backup files
14 | *.swp
15 | *.bak
16 | *.tmp
17 | *.orig
18 | *~
19 | # Various IDEs
20 | .project
21 | .idea/
22 | *.tmproj
23 | .vscode/
24 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: autotm
3 | description: A Helm chart for AutoTM
4 |
5 | # A chart can be either an 'application' or a 'library' chart.
6 | #
7 | # Application charts are a collection of templates that can be packaged into versioned archives
8 | # to be deployed.
9 | #
10 | # Library charts provide useful utilities or functions for the chart developer. They're included as
11 | # a dependency of application charts to inject those utilities and functions into the rendering
12 | # pipeline. Library charts do not define any templates and therefore cannot be deployed.
13 | type: application
14 |
15 | # This is the chart version. This version number should be incremented each time you make changes
16 | # to the chart and its templates, including the app version.
17 | # Versions are expected to follow Semantic Versioning (https://semver.org/)
18 | version: 0.1.0
19 |
20 | # This is the version number of the application being deployed. This version number should be
21 | # incremented each time you make changes to the application. Versions are not expected to
22 | # follow Semantic Versioning. They should reflect the version the application is using.
23 | # It is recommended to use it with quotes.
24 | appVersion: "1.0.0"
25 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/templates/_helpers.tpl:
--------------------------------------------------------------------------------
1 | {{/*
2 | Expand the name of the chart.
3 | */}}
4 | {{- define "st-workspace.name" -}}
5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
6 | {{- end }}
7 |
8 | {{/*
9 | Create a default fully qualified app name.
10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
11 | If release name contains chart name it will be used as a full name.
12 | */}}
13 | {{- define "st-workspace.fullname" -}}
14 | {{- if .Values.fullnameOverride }}
15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
16 | {{- else }}
17 | {{- $name := default .Chart.Name .Values.nameOverride }}
18 | {{- if contains $name .Release.Name }}
19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }}
20 | {{- else }}
21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
22 | {{- end }}
23 | {{- end }}
24 | {{- end }}
25 |
26 | {{/*
27 | Create chart name and version as used by the chart label.
28 | */}}
29 | {{- define "st-workspace.chart" -}}
30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
31 | {{- end }}
32 |
33 | {{/*
34 | Common labels
35 | */}}
36 | {{- define "st-workspace.labels" -}}
37 | helm.sh/chart: {{ include "st-workspace.chart" . }}
38 | {{ include "st-workspace.selectorLabels" . }}
39 | {{- if .Chart.AppVersion }}
40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
41 | {{- end }}
42 | app.kubernetes.io/managed-by: {{ .Release.Service }}
43 | {{- end }}
44 |
45 | {{/*
46 | Selector labels
47 | */}}
48 | {{- define "st-workspace.selectorLabels" -}}
49 | app.kubernetes.io/name: {{ include "st-workspace.name" . }}
50 | app.kubernetes.io/instance: {{ .Release.Name }}
51 | {{- end }}
52 |
53 | {{/*
54 | Create the name of the service account to use
55 | */}}
56 | {{- define "st-workspace.serviceAccountName" -}}
57 | {{- if .Values.serviceAccount.create }}
58 | {{- default (include "st-workspace.fullname" .) .Values.serviceAccount.name }}
59 | {{- else }}
60 | {{- default "default" .Values.serviceAccount.name }}
61 | {{- end }}
62 | {{- end }}
63 |
64 | {{/*
65 | Helper function to define prefix for all entities
66 | */}}
67 | {{- define "autotm.prefix" -}}
68 | {{- $prefix := default .Values.autotm_prefix .Release.Name | trunc 16 -}}
69 | {{- ternary $prefix (printf "%s-" $prefix) (empty $prefix) -}}
70 | {{- end -}}
71 |
72 | {{/*
73 | Mlflow db url
74 | */}}
75 | {{- define "autotm.mlflow_db_url" -}}
76 | {{- if .Values.mlflow_enabled -}}
77 | {{- printf "mysql+pymysql://%s:%s@%smlflow-db:3306/%s" (.Values.mlflow_mysql_user) (.Values.mlflow_mysql_password) (include "autotm.prefix" . ) (.Values.mlflow_mysql_database) -}}
78 | {{- else -}}
79 | ""
80 | {{- end -}}
81 | {{- end -}}
82 |
83 | {{/*
84 | Mongo url
85 | */}}
86 | {{- define "autotm.mongo_url" -}}
87 | {{- if .Values.mongo_enabled -}}
88 | {{- printf "mongodb://%s:%s@%smongo-tm-experiments-db:27017" (.Values.mongo_user) (.Values.mongo_password) (include "autotm.prefix" .) -}}
89 | {{- else -}}
90 | ""
91 | {{- end -}}
92 | {{- end -}}
93 |
94 | {{/*
95 | Celery broker url
96 | */}}
97 | {{- define "autotm.celery_broker_url" -}}
98 | {{- printf "amqp://guest:guest@%srabbitmq-service:5672" (include "autotm.prefix" .) -}}
99 | {{- end -}}
100 |
101 | {{/*
102 | Celery result backend url
103 | */}}
104 | {{- define "autotm.celery_result_backend" -}}
105 | {{- printf "redis://%sredis:6379/1" (include "autotm.prefix" .) -}}
106 | {{- end -}}
107 |
108 | {{/*
109 | Check if required persistent volume exists.
110 | One should invoke it like '{{ include "autotm.find_pv" (list . "pv_mlflow_db") }}', where 'pv_mlflow_db' is a variable from values.yaml
111 | See:
112 | - https://stackoverflow.com/questions/70803242/helm-pass-single-string-to-a-template
113 | - https://austindewey.com/2021/06/02/writing-function-based-templates-in-helm/
114 | */}}
115 | {{- define "autotm.find_pv" -}}
116 | {{- $root := index . 0 -}}
117 | {{- $pv_var := index . 1 -}}
118 | {{- $pv_name := (printf "%s%s" (include "autotm.prefix" $root) (get $root.Values $pv_var)) -}}
119 | {{- $result := (lookup "v1" "PersistentVolume" "" $pv_name) -}}
120 | {{- if empty $result -}}
121 | {{- (printf "Persistent volume with name '%s' not found" $pv_name) | fail -}}
122 | {{- else -}}
123 | {{- $result.metadata.name -}}
124 | {{- end -}}
125 | {{- end -}}
126 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/templates/configmaps.yaml:
--------------------------------------------------------------------------------
1 | {{ if .Values.mlflow_enabled }}
2 | ##################
3 | #### MLFlow
4 | ##################
5 | ---
6 | apiVersion: v1
7 | kind: ConfigMap
8 | metadata:
9 | name: {{ include "autotm.prefix" . }}mlflow-mysql-cnf
10 | labels:
11 | {{- range $key, $val := .Values.required_labels }}
12 | {{ $key }}: {{ $val | quote }}
13 | {{- end }}
14 | data:
15 | my.cnf: |
16 | [mysqld]
17 | max_connections=500
18 | wait_timeout=15
19 | interactive_timeout=15
20 | {{ end }}
21 | ##################
22 | #### Celery & Fitness Workers
23 | ##################
24 | ---
25 | apiVersion: v1
26 | kind: ConfigMap
27 | metadata:
28 | name: {{ include "autotm.prefix" . }}rabbitmq-config
29 | labels:
30 | {{- range $key, $val := .Values.required_labels }}
31 | {{ $key }}: {{ $val | quote }}
32 | {{- end }}
33 | data:
34 | consumer-settings.conf: |
35 | ## Consumer timeout
36 | ## If a message delivered to a consumer has not been acknowledged before this timer
37 | ## triggers, the channel will be force closed by the broker. This ensures that
38 | ## faulty consumers that never ack will not hold on to messages indefinitely.
39 | ##
40 | consumer_timeout = 1800000
41 | ---
42 | apiVersion: v1
43 | kind: ConfigMap
44 | metadata:
45 | name: {{ include "autotm.prefix" . }}worker-config
46 | labels:
47 | {{- range $key, $val := .Values.required_labels }}
48 | {{ $key }}: {{ $val | quote }}
49 | {{- end }}
50 | data:
51 | datasets-config.yaml: |
52 | {{ .Values.worker_datasets_config_content | indent 4}}
53 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/templates/pvc.yaml:
--------------------------------------------------------------------------------
1 | {{ if .Values.pvc_create_enabled }}
2 | {{ if .Values.mlflow_enabled }}
3 | ##################
4 | #### MLFlow
5 | ##################
6 | ---
7 | apiVersion: v1
8 | kind: PersistentVolumeClaim
9 | metadata:
10 | name: {{ include "autotm.prefix" . }}mlflow-db-pvc
11 | labels:
12 | {{- range $key, $val := .Values.required_labels }}
13 | {{ $key }}: {{ $val | quote }}
14 | {{- end }}
15 | spec:
16 | storageClassName: {{ .Values.storage_class }}
17 | volumeName: {{ include "autotm.find_pv" (list . "pv_mlflow_db") }}
18 | accessModes:
19 | - ReadWriteOnce
20 | resources:
21 | requests:
22 | storage: 25Gi
23 | ---
24 | apiVersion: v1
25 | kind: PersistentVolumeClaim
26 | metadata:
27 | name: {{ include "autotm.prefix" . }}mlflow-artifact-store-pvc
28 | labels:
29 | {{- range $key, $val := .Values.required_labels }}
30 | {{ $key }}: {{ $val | quote }}
31 | {{- end }}
32 | spec:
33 | storageClassName: {{ .Values.storage_class }}
34 | volumeName: {{ include "autotm.find_pv" (list . "pv_mlflow_artifact_store") }}
35 | accessModes:
36 | - ReadWriteMany
37 | resources:
38 | requests:
39 | storage: 25Gi
40 | {{ end }}
41 | {{ if .Values.mongo_enabled }}
42 | ##################
43 | #### Mongo
44 | ##################
45 | ---
46 | apiVersion: v1
47 | kind: PersistentVolumeClaim
48 | metadata:
49 | name: {{ include "autotm.prefix" . }}mongo-tm-experiments-pvc
50 | labels:
51 | {{- range $key, $val := .Values.required_labels }}
52 | {{ $key }}: {{ $val | quote }}
53 | {{- end }}
54 | spec:
55 | storageClassName: {{ .Values.storage_class }}
56 | volumeName: {{ include "autotm.find_pv" (list . "pv_mongo_db") }}
57 | accessModes:
58 | - ReadWriteOnce
59 | resources:
60 | requests:
61 | storage: 50Gi
62 | {{ end }}
63 | ###################
64 | ##### Celery & Fitness Workers
65 | ###################
66 | ---
67 | apiVersion: v1
68 | kind: PersistentVolumeClaim
69 | metadata:
70 | name: {{ include "autotm.prefix" . }}datasets
71 | labels:
72 | {{- range $key, $val := .Values.required_labels }}
73 | {{ $key }}: {{ $val | quote }}
74 | {{- end }}
75 | spec:
76 | storageClassName: {{ .Values.storage_class }}
77 | volumeName: {{ include "autotm.find_pv" (list . "pv_dataset_store") }}
78 | accessModes:
79 | - ReadOnlyMany
80 | resources:
81 | requests:
82 | storage: 50Gi
83 | {{ end }}
84 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/templates/services.yaml:
--------------------------------------------------------------------------------
1 | {{ if .Values.jupyter_enabled }}
2 | ##################
3 | #### Jupyter
4 | ##################
5 | ---
6 | apiVersion: v1
7 | kind: Service
8 | metadata:
9 | name: {{ include "autotm.prefix" . }}jupyter
10 | # labels:
11 | # {{- range $key, $val := .Values.required_labels }}
12 | # {{ $key }}:{{ $val | quote }}
13 | # {{- end }}
14 | spec:
15 | type: NodePort
16 | ports:
17 | - port: 8888
18 | protocol: TCP
19 | name: jupyter
20 | - port: 4040
21 | protocol: TCP
22 | name: spark
23 | selector:
24 | app: {{ include "autotm.prefix" . }}jupyter
25 | {{ end }}
26 | {{ if .Values.mlflow_enabled }}
27 | ##################
28 | #### MLFlow
29 | ##################
30 | ---
31 | apiVersion: v1
32 | kind: Service
33 | metadata:
34 | name: {{ include "autotm.prefix" . }}mlflow
35 | labels:
36 | {{- range $key, $val := .Values.required_labels }}
37 | {{ $key }}: {{ $val | quote }}
38 | {{- end }}
39 | spec:
40 | type: NodePort
41 | ports:
42 | - port: 5000
43 | selector:
44 | app: {{ include "autotm.prefix" . }}mlflow
45 | ---
46 | apiVersion: v1
47 | kind: Service
48 | metadata:
49 | name: {{ include "autotm.prefix" . }}mlflow-db
50 | labels:
51 | {{- range $key, $val := .Values.required_labels }}
52 | {{ $key }}: {{ $val | quote }}
53 | {{- end }}
54 | spec:
55 | ports:
56 | - port: 3306
57 | selector:
58 | app: {{ include "autotm.prefix" . }}mlflow-db
59 | {{ end }}
60 | {{ if .Values.mongo_enabled }}
61 | ##################
62 | #### Mongo
63 | ##################
64 | ---
65 | apiVersion: v1
66 | kind: Service
67 | metadata:
68 | name: {{ include "autotm.prefix" . }}mongo-tm-experiments-db
69 | labels:
70 | {{- range $key, $val := .Values.required_labels }}
71 | {{ $key }}: {{ $val | quote }}
72 | {{- end }}
73 | spec:
74 | type: NodePort
75 | ports:
76 | - port: 27017
77 | selector:
78 | app: {{ include "autotm.prefix" . }}mongo-tm-experiments-db
79 | ---
80 | apiVersion: v1
81 | kind: Service
82 | metadata:
83 | name: {{ include "autotm.prefix" . }}mongo-express
84 | labels:
85 | {{- range $key, $val := .Values.required_labels }}
86 | {{ $key }}: {{ $val | quote }}
87 | {{- end }}
88 | spec:
89 | type: NodePort
90 | ports:
91 | - port: 8081
92 | selector:
93 | app: {{ include "autotm.prefix" . }}mongo-express-tm-experiments
94 | {{ end }}
95 | ##################
96 | #### Celery & Fitness Workers
97 | ##################
98 | ---
99 | apiVersion: v1
100 | kind: Service
101 | metadata:
102 | name: {{ include "autotm.prefix" . }}rabbitmq-service
103 | labels:
104 | {{- range $key, $val := .Values.required_labels }}
105 | {{ $key }}: {{ $val | quote }}
106 | {{- end }}
107 | spec:
108 | type: NodePort
109 | ports:
110 | - port: 5672
111 | selector:
112 | app: {{ include "autotm.prefix" . }}rabbitmq
113 | ---
114 | apiVersion: v1
115 | kind: Service
116 | metadata:
117 | name: {{ include "autotm.prefix" . }}redis
118 | labels:
119 | {{- range $key, $val := .Values.required_labels }}
120 | {{ $key }}: {{ $val | quote }}
121 | {{- end }}
122 | spec:
123 | type: NodePort
124 | ports:
125 | - port: 6379
126 | selector:
127 | app: {{ include "autotm.prefix" . }}redis
128 | ---
129 | apiVersion: v1
130 | kind: Service
131 | metadata:
132 | name: {{ include "autotm.prefix" . }}celery-flower-service
133 | labels:
134 | {{- range $key, $val := .Values.required_labels }}
135 | {{ $key }}: {{ $val | quote }}
136 | {{- end }}
137 | spec:
138 | type: NodePort
139 | ports:
140 | - port: 5555
141 | selector:
142 | app: {{ include "autotm.prefix" . }}celery-flower
143 |
--------------------------------------------------------------------------------
/cluster/charts/autotm/values.yaml:
--------------------------------------------------------------------------------
1 | ### General section
2 | autotm_prefix: ""
3 | required_labels:
4 | owner: "autotm"
5 |
6 | ### Images
7 | # pull policy
8 | pull_policy: "IfNotPresent"
9 | worker_image_pull_policy: "Always"
10 | # images
11 | mysql_image: "mysql/mysql-server:5.7.28"
12 | phpmyadmin_image: "phpmyadmin:5.1.1"
13 | mongo_image: "mongo:4.4.6-bionic"
14 | mongoexpress_image: "mongo-express:latest"
15 | rabbitmq_image: "node2.bdcl:5000/rabbitmq:3.8-management-alpine"
16 | redis_image: "node2.bdcl:5000/redis:6.2"
17 | mlflow_image: "node2.bdcl:5000/mlflow-webserver:latest"
18 | flower_image: "node2.bdcl:5000/flower:latest"
19 | worker_image: "node2.bdcl:5000/fitness-worker:latest"
20 | jupyter_image: "node2.bdcl:5000/autotm-jupyter:latest"
21 |
22 | ### Volumes
23 | pvc_create_enabled: "true"
24 | storage_class: "manual"
25 | pv_mlflow_db: "mlflow-db-pv"
26 | pv_mlflow_artifact_store: "mlflow-artifact-store-pv"
27 | pv_mongo_db: "mongo-tm-experiments-pv"
28 | pv_dataset_store: "datasets"
29 |
30 | ### Jupyter
31 | jupyter_enabled: "true"
32 | jupyter_cpu_limits: "4"
33 | jupyter_mem_limits: "16Gi"
34 |
35 | ### Mlflow
36 | mlflow_enabled: "true"
37 | mlflow_mysql_database: "mlflow"
38 | mlflow_mysql_user: "mlflow"
39 | mlflow_mysql_password: "mlflow"
40 | mlflow_mysql_root_password: "mlflow"
41 |
42 | ### Mongo
43 | mongo_enabled: "false"
44 | mongo_user: "mongoadmin"
45 | mongo_password: "secret"
46 |
47 | ### Fitness Worker settings
48 | worker_datasets_dir_path: "/storage"
49 | worker_count: "1"
50 | worker_cpu: "4"
51 | worker_mem: "12Gi"
52 | worker_mongo_collection: "tm_stats"
53 | worker_datasets_config_content: |
54 | datasets:
55 | # the rest of the settings will be the same as for the first dataset but will be added automatically
56 | books_stroyitelstvo_2030:
57 | base_path: "/storage/books_stroyitelstvo_2030_sample"
58 | topic_count: 10
59 |
60 | 20newsgroups:
61 | base_path: "/storage/20newsgroups_sample"
62 | topic_count: 20
63 |
64 | clickhouse_issues:
65 | base_path: "/storage/clickhouse_issues_sample"
66 | labels: yes
67 | topic_count: 50
68 |
69 | # the rest of the settings will be the same as for the first dataset but will be added automatically
70 | banners:
71 | base_path: "/storage/banners_sample"
72 | topic_count: 20
73 |
74 | amazon_food:
75 | base_path: "/storage/amazon_food_sample"
76 | topic_count: 20
77 |
78 | hotel-reviews:
79 | base_path: "/storage/hotel-reviews_sample"
80 | topic_count: 20
81 |
82 | lenta_ru:
83 | base_path: "/storage/lenta_ru_sample"
84 | topic_count: 20
85 |
--------------------------------------------------------------------------------
/cluster/conf/pv.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: PersistentVolume
4 | metadata:
5 | name: autotm-mlflow-db-pv
6 | spec:
7 | storageClassName: manual
8 | capacity:
9 | storage: 25Gi
10 | volumeMode: Filesystem
11 | accessModes:
12 | - ReadWriteOnce
13 | persistentVolumeReclaimPolicy: Retain
14 | local:
15 | path: /data/hdd/autotm-mlflow-db
16 | nodeAffinity:
17 | required:
18 | nodeSelectorTerms:
19 | - matchExpressions:
20 | - key: kubernetes.io/hostname
21 | operator: In
22 | values:
23 | - node5.bdcl
24 | ---
25 | apiVersion: v1
26 | kind: PersistentVolume
27 | metadata:
28 | name: autotm-mlflow-artifact-store-pv
29 | spec:
30 | storageClassName: manual
31 | capacity:
32 | storage: 50Gi
33 | volumeMode: Filesystem
34 | accessModes:
35 | - ReadWriteMany
36 | persistentVolumeReclaimPolicy: Retain
37 | hostPath:
38 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/autotm-mlflow-artifact-store
39 | type: DirectoryOrCreate
40 | ---
41 | apiVersion: v1
42 | kind: PersistentVolume
43 | metadata:
44 | name: autotm-mongo-tm-experiments-pv
45 | spec:
46 | storageClassName: manual
47 | capacity:
48 | storage: 50Gi
49 | volumeMode: Filesystem
50 | accessModes:
51 | - ReadWriteOnce
52 | persistentVolumeReclaimPolicy: Retain
53 | local:
54 | path: /data/hdd/autotm-mongo-db-tm-experiments
55 | nodeAffinity:
56 | required:
57 | nodeSelectorTerms:
58 | - matchExpressions:
59 | - key: kubernetes.io/hostname
60 | operator: In
61 | values:
62 | - node5.bdcl
63 | ---
64 | apiVersion: v1
65 | kind: PersistentVolume
66 | metadata:
67 | name: autotm-datasets
68 | spec:
69 | storageClassName: manual
70 | capacity:
71 | storage: 50Gi
72 | volumeMode: Filesystem
73 | accessModes:
74 | - ReadOnlyMany
75 | persistentVolumeReclaimPolicy: Retain
76 | hostPath:
77 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/GOTM/datasets_TM_scoring
78 | type: Directory
79 |
--------------------------------------------------------------------------------
/cluster/docker/flower.dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:alpine
2 |
3 | # Get latest root certificates
4 | RUN apk add --no-cache ca-certificates && update-ca-certificates
5 |
6 | # Install the required packages
7 | RUN pip install --no-cache-dir redis flower==1.0.0
8 |
9 | # PYTHONUNBUFFERED: Force stdin, stdout and stderr to be totally unbuffered. (equivalent to `python -u`)
10 | # PYTHONHASHSEED: Enable hash randomization (equivalent to `python -R`)
11 | # PYTHONDONTWRITEBYTECODE: Do not write byte files to disk, since we maintain it as readonly. (equivalent to `python -B`)
12 | ENV PYTHONUNBUFFERED=1 PYTHONHASHSEED=random PYTHONDONTWRITEBYTECODE=1
13 |
14 | # Default port
15 | EXPOSE 5555
16 |
17 | ENV FLOWER_DATA_DIR /data
18 | ENV PYTHONPATH ${FLOWER_DATA_DIR}
19 |
20 | WORKDIR $FLOWER_DATA_DIR
21 |
22 | # Add a user with an explicit UID/GID and create necessary directories
23 | RUN set -eux; \
24 | addgroup -g 1000 flower; \
25 | adduser -u 1000 -G flower flower -D; \
26 | mkdir -p "$FLOWER_DATA_DIR"; \
27 | chown flower:flower "$FLOWER_DATA_DIR"
28 | USER flower
29 |
30 | VOLUME $FLOWER_DATA_DIR
31 |
32 | # for '-A distributed_fitness' see kube_fitness.tasks.make_celery_app
33 | ENTRYPOINT celery flower --address=0.0.0.0 --port=5555
--------------------------------------------------------------------------------
/cluster/docker/jupyter.dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.9
2 |
3 | RUN apt-get update && apt-get install -y bash python3 python3-pip
4 |
5 | COPY requirements.txt /tmp/
6 |
7 | RUN pip3 install -r /tmp/requirements.txt
8 |
9 | COPY dist/autotm-0.1.0-py3-none-any.whl /tmp
10 |
11 | RUN pip3 install --no-deps /tmp/autotm-0.1.0-py3-none-any.whl
12 |
13 | RUN pip3 install jupyter
14 |
15 | RUN pip3 install datasets
16 |
17 | RUN python -m spacy download en_core_web_sm
18 |
19 | COPY data /root
20 |
21 | COPY examples/autotm_demo_updated.ipynb /root
22 |
23 | ENTRYPOINT ["jupyter", "notebook", "--notebook-dir=/root", "--ip=0.0.0.0", "--port=8888", "--allow-root", "--no-browser", "--NotebookApp.token=''"]
24 |
--------------------------------------------------------------------------------
/cluster/docker/mlflow-webserver.dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | RUN pip install PyMySQL==1.1.0 sqlalchemy==2.0.23 psycopg2-binary==2.9.9 protobuf==3.20.0 mlflow[extras]==2.9.2
4 |
5 | ENTRYPOINT ["mlflow", "server"]
6 |
--------------------------------------------------------------------------------
/cluster/docker/worker.dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 |
3 | RUN apt-get update && apt-get install -y bash python3 python3-pip
4 |
5 | #RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2'
6 |
7 | COPY requirements.txt /tmp/
8 |
9 | RUN pip3 install -r /tmp/requirements.txt
10 |
11 | COPY dist/autotm-0.1.0-py3-none-any.whl /tmp
12 |
13 | RUN pip3 install --no-deps /tmp/autotm-0.1.0-py3-none-any.whl
14 |
15 | ENTRYPOINT fitness-worker \
16 | --concurrency 1 \
17 | --queues fitness_tasks \
18 | --loglevel INFO
19 |
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | preprocessing_params:
3 | lang: "ru"
4 | min_tokens_count: 3
5 |
6 | alg_name: "ga"
7 | alg_params:
8 | num_iterations: 2
9 | num_individuals: 4
10 | use_nelder_mead_in_mutation: no
11 | use_nelder_mead_in_crossover: no
12 | use_nelder_mead_in_selector: no
13 | train_option: "offline"
14 |
--------------------------------------------------------------------------------
/distributed/README.md:
--------------------------------------------------------------------------------
1 | Before you start:
2 | 1. One should remember that there is a 'help' command that describes the available operations:
3 | ```
4 | ./bin/fitnessctl help
5 | ```
6 |
7 | 2. All operations with 'fitnessctl' can be executed in a different namespace.
8 | One only needs to set the 'KUBE_NAMESPACE' variable.
9 | ```
10 | KUBE_NAMESPACE=custom_namespace ./bin/fitnessctl deploy
11 | ```
12 |
13 | To run kube-fitness app in production with mlflow one should:
14 |
15 | 1. Set up .kube/config the way it can access and manage remote cluster.
16 |
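For example (assuming a kubeconfig context for the target cluster already exists; the context name below is illustrative):
```
kubectl config use-context my-cluster
kubectl get nodes
```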
17 | 2. Build the app's wheel:
18 | ```
19 | ./bin/fitnessctl build-app
20 | ```
21 |
22 | 3. Build docker images and install them into the shared private repository
23 | `./bin/fitnessctl install-images`. To accomplish this operation successfully you should have the repository added to your docker settings.
24 | See the example of the settings below.
25 |
26 | 4. Create mlflow volumes and PVCs (the desired namespace must be supplied because of the PVCs) by issuing the command:
27 | ```
28 | kubectl -n <desired_namespace> apply -f deploy/mlflow-persistent-volumes.yaml
29 | ```
30 | Do not forget that the directories required by volumes should exist beforehand.
31 |
32 | 5. Deploy mlflow:
33 | ```
34 | ./bin/fitnessctl create-mlflow
35 | ```
36 |
37 | 6. Generate the production config either using the 'gencfg' command with appropriate arguments or using 'prodcfg' alone:
38 | ```
39 | ./bin/fitnessctl prodcfg
40 | ```
41 |
42 | 7. Finally, deploy the kube-fitness app with the command:
43 | ```
44 | ./bin/fitnessctl create
45 | ```
46 |
47 | Alternatively, one may safely use the 'recreate' command if another deployment may already exist:
48 | ```
49 | ./bin/fitnessctl recreate
50 | ```
51 |
52 | 8. To check if everything works fine, one may create a test client and verify that it completes successfully:
53 | ```
54 | ./bin/fitnessctl recreate-client
55 | ```
56 |
57 | NOTE: there is a shortcut for the above sequence (EXCEPT p.4 and p.5):
58 | ```
59 | ./bin/fitnessctl deploy
60 | ```
61 |
62 | To deploy the kube-fitness app on a kubernetes cluster one should do the following:
63 |
64 | 1. Build docker images and install them into the shared private repository
65 | `./bin/fitnessctl install-images`. To accomplish this operation successfully you should have the repository added to your docker settings.
66 | See the example of the settings below.
67 |
68 | 2. Build and install the wheel into the private PyPI registry (private is the name of the registry).
69 | ```
70 | python setup.py sdist bdist_wheel register -r private upload -r private
71 | ```
72 | To perform this operation you need to have a proper config at **~/.pypirc**
73 | See the example below.
74 |
75 |
76 | 3. Go to the remote server that has access to your kubernetes cluster and has *kubectl* installed.
77 | Install the **kube_fitness** wheel:
78 | ```
79 | pip3 install --trusted-host node2.bdcl --index http://node2.bdcl:30496 kube_fitness --user --upgrade --force-reinstall
80 | ```
81 |
82 | The wheel should be installed with **--user** setting as it puts necessary files to **$HOME/.local/share/kube-fitness-data** directory.
83 | Other options are recommended for second and further reinstalls.
84 |
85 | 4. Add the line `export PATH="$HOME/.local/bin:$PATH"` into your **~/.bash_profile** to be able to call the **kube-fitnessctl** utility.
86 |
87 | 5. Run `kube-fitnessctl gencfg ` to create a deployment file for kubernetes.
88 |
89 | 6. Run `kube-fitnessctl create` or `kube-fitnessctl recreate` to deploy fitness workers on your cluster.
90 |
91 | 7. To access the app from your jupyter server, you should install the wheel as well using pip.
92 |
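For example, using the same private index as in the server installation step above (the index host and port are taken from that step and may differ in your setup):
```
pip3 install --trusted-host node2.bdcl --index http://node2.bdcl:30496 kube_fitness
```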
93 | 8. See an example of how to run the client in kube_fitness/test_app.py.
94 |
95 | .pypirc example config:
96 | ```
97 | [distutils]
98 | index-servers =
99 | private
100 |
101 | [private]
102 | repository: http://:/
103 | username:
104 | password:
105 | ```
106 |
107 |
108 | Docker daemon settings (**/etc/docker/daemon.json**) to access a private registry
109 | `"insecure-registries":["node2.bdcl:5000"]`
110 |
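As a complete file, **/etc/docker/daemon.json** would look like:
```
{
  "insecure-registries": ["node2.bdcl:5000"]
}
```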
--------------------------------------------------------------------------------
/distributed/autotm_distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/distributed/autotm_distributed/__init__.py
--------------------------------------------------------------------------------
/distributed/autotm_distributed/deploy_config_generator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | from jinja2 import Environment, FileSystemLoader
6 |
7 | logger = logging.getLogger()
8 |
9 |
10 | def main():
11 | parser = argparse.ArgumentParser(description='Generate deploy config with'
12 | ' entities for distributed fitness estimation')
13 |
14 | parser.add_argument("--host-data-dir",
15 | type=str,
16 | required=True,
17 | help="Path to a directory on hosts to be mounted in containers of workers")
18 |
19 | parser.add_argument("--datasets-config",
20 | type=str,
21 | required=True,
22 | help="Path to a datasets config file (absolute path or relative to --config_template_dir)")
23 |
24 | parser.add_argument("--host-data-dir-mount-path", type=str, default="/storage",
25 | help="Path inside containers of workers where to mount host_data_dir")
26 |
27 | parser.add_argument("--registry", type=str, default=None,
28 | help="Registry to push images to")
29 |
30 | parser.add_argument("--client_image", type=str, default="fitness-client:latest",
31 | help="Image to use for a client pod")
32 |
33 | parser.add_argument("--flower_image", type=str, default="flower:latest",
34 | help="Image to use for Flower - a celery monitoring tool")
35 |
36 | parser.add_argument("--worker_image", type=str, default="fitness-worker:latest",
37 | help="Image to use for worker pods")
38 |
39 | parser.add_argument("--worker_count", type=int, default=1,
40 | help="Count of workers to be launched on the kubernetes cluster")
41 |
42 | parser.add_argument("--worker_cpu", type=int, default=1,
43 | help="Count of cpu (int) to be allocated per worker")
44 |
45 | parser.add_argument("--worker_mem", type=str, default="1G",
46 | help="Amount of memory to be allocated per worker")
47 |
48 | parser.add_argument("--config_template_dir", default="deploy",
49 | help="Path to a template's dir the config will be generated from")
50 |
51 | parser.add_argument("--out_dir", type=str, default=None,
52 | help="Path to a directory where generated config files will be written")
53 |
54 | parser.add_argument("--mongo_collection", type=str, default='main_tm_stats',
55 | help="Mongo collection")
56 |
57 | args = parser.parse_args()
58 |
59 | worker_template_path = "kube-fitness-workers.yaml.j2"
60 | client_template_path = "kube-fitness-client-job.yaml.j2"
61 |
62 | if args.out_dir:
63 | worker_cfg_out_path = os.path.join(args.out_dir, "kube-fitness-workers.yaml")
64 | client_cfg_out_path = os.path.join(args.out_dir, "kube-fitness-client-job.yaml")
65 | else:
66 | worker_cfg_out_path = os.path.join(args.config_template_dir, "kube-fitness-workers.yaml")
67 | client_cfg_out_path = os.path.join(args.config_template_dir, "kube-fitness-client-job.yaml")
68 |
69 | datasets_config = args.datasets_config \
70 | if os.path.isabs(args.datasets_config) else os.path.join(args.config_template_dir, args.datasets_config)
71 |
72 | logging.info(f"Reading datasets config from {args.datasets_config}")
73 | with open(datasets_config, "r") as f:
74 | datasets_config_content = f.read()
75 |
76 | logging.info(f"Using template dir: {args.config_template_dir}")
77 | logging.info(f"Using template {worker_template_path}")
78 | logging.info(f"Generating config file {worker_cfg_out_path}")
79 |
80 | env = Environment(loader=FileSystemLoader(args.config_template_dir))
81 | template = env.get_template(worker_template_path)
82 | template.stream(
83 | flower_image=f"{args.registry}/{args.flower_image}" if args.registry else args.flower_image,
84 | image=f"{args.registry}/{args.worker_image}" if args.registry else args.worker_image,
85 | pull_policy="Always" if args.registry else "IfNotPresent",
86 | worker_count=args.worker_count,
87 | worker_cpu=args.worker_cpu,
88 | worker_mem=args.worker_mem,
89 | host_data_dir=args.host_data_dir,
90 | host_data_dir_mount_path=args.host_data_dir_mount_path,
91 | datasets_config_content=datasets_config_content,
92 | mongo_collection=args.mongo_collection
93 | ).dump(worker_cfg_out_path)
94 |
95 | template = env.get_template(client_template_path)
96 | template.stream(
97 | image=f"{args.registry}/{args.client_image}" if args.registry else args.client_image,
98 | pull_policy="Always" if args.registry else "IfNotPresent",
99 | ).dump(client_cfg_out_path)
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
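For reference, a hypothetical invocation of this generator, assuming it is run from the distributed/ directory (so the default --config_template_dir=deploy resolves) via the deploy-config-generator console script registered in setup.py; the host path /mnt/datasets and the sizing values below are made-up examples:

    deploy-config-generator \
        --host-data-dir /mnt/datasets \
        --datasets-config test-datasets-config.yaml \
        --registry node2.bdcl:5000 \
        --worker_count 4 \
        --worker_cpu 2 \
        --worker_mem 4G

This would render deploy/kube-fitness-workers.yaml and deploy/kube-fitness-client-job.yaml from the Jinja2 templates of the same name, embedding the datasets config into the worker ConfigMap.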
--------------------------------------------------------------------------------
/distributed/autotm_distributed/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 | from autotm_distributed.tasks import make_celery_app
6 | from autotm_distributed.tm import TopicModelFactory
7 |
8 | if __name__ == '__main__':
9 | if "DATASETS_CONFIG" in os.environ:
10 | with open(os.environ["DATASETS_CONFIG"], "r") as f:
11 |             config = yaml.safe_load(f)
12 | dataset_settings = config["datasets"]
13 | else:
14 | dataset_settings = None
15 |
16 | TopicModelFactory.init_factory_settings(
17 | num_processors=os.getenv("NUM_PROCESSORS", None),
18 | dataset_settings=dataset_settings
19 | )
20 |
21 | celery_app = make_celery_app()
22 | celery_app.worker_main()
23 |
--------------------------------------------------------------------------------
/distributed/autotm_distributed/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | MetricsScores = Dict[str, Any]
4 |
5 | TimeMeasurements = Dict[str, float]
6 |
7 | AVG_COHERENCE_SCORE = "avg_coherence_score"
8 |
--------------------------------------------------------------------------------
/distributed/autotm_distributed/schemas.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import List, Optional
3 |
4 | from pydantic import BaseModel
5 |
6 | from autotm_distributed import MetricsScores
7 |
8 | PARAM_NAMES = [
9 | 'val_decor', 'var_n_0', 'var_sm_0', 'var_sm_1', 'var_n_1',
10 | 'var_sp_0', 'var_sp_1', 'var_n_2',
11 | 'var_sp_2', 'var_sp_3', "var_n_3",
12 | 'var_n_4',
13 | 'ext_mutation_prob', 'ext_elem_mutation_prob', 'ext_mutation_selector',
14 | 'val_decor_2'
15 | ]
16 |
17 |
18 | # @dataclass_json
19 | # @dataclass
20 | class IndividualDTO(BaseModel):
21 | id: str
22 | params: List[object]
23 | fitness_name: str = "default"
24 | dataset: str = "default"
25 | force_dataset_settings_checkout: bool = False
26 | fitness_value: MetricsScores = None
27 | exp_id: Optional[int] = None
28 | alg_id: Optional[str] = None
29 | tag: Optional[str] = None
30 | iteration_id: int = 0
31 | topic_count: Optional[int] = None
32 | train_option: str = 'offline'
33 |
34 | def make_params_dict(self):
35 | if len(self.params) > len(PARAM_NAMES):
36 | len_diff = len(self.params) - len(PARAM_NAMES)
37 | param_names = copy.deepcopy(PARAM_NAMES) + [f"unknown_param_#{i}" for i in range(len_diff)]
38 | else:
39 | param_names = PARAM_NAMES
40 |
41 | return {name: p_val for name, p_val in zip(param_names, self.params)}
42 |
--------------------------------------------------------------------------------
/distributed/autotm_distributed/test_app.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pprint import pformat
3 | from typing import Optional
4 |
5 | import click
6 |
7 | from autotm_distributed import AVG_COHERENCE_SCORE
8 |
9 | logging.basicConfig(level=logging.DEBUG)
10 |
11 | from autotm_distributed.tasks import make_celery_app, parallel_fitness, log_best_solution
12 | from autotm_distributed.schemas import IndividualDTO
13 |
14 | logger = logging.getLogger("TEST_APP")
15 |
16 | celery_app = make_celery_app(client_to_broker_max_retries=25)
17 |
18 | population = [
19 | IndividualDTO(id="1", dataset="first_dataset", topic_count=5, exp_id=0, alg_id="test_alg", iteration_id=100,
20 | tag="just_a_simple_tag", params=[
21 | 66717.86348968784, 2.0, 98.42286785825902,
22 | 80.18543570807961, 2.0, -19.948347560420373,
23 | -52.28141634493725, 2.0, -92.85597392137976,
24 | -60.49287378084627, 4.0, 3.0, 0.06840138630839943,
25 | 0.556001061599461, 0.9894122432621849, 11679.364068753106]),
26 |
27 | IndividualDTO(id="2", dataset="second_dataset", topic_count=5, params=[
28 | 42260.028918433134, 1.0, 7.535922806674758,
29 | 92.32509683092258, 3.0, -71.92218883101997,
30 | -56.016307418098386, 1.0, -3.237446735558109,
31 | -20.448661743825, 4.0, 7.0, 0.08031597367372223,
32 | 0.42253357962287563, 0.02249898631530911, 61969.93405537756]),
33 |
34 | IndividualDTO(id="3", dataset="first_dataset", topic_count=5, params=[
35 | 53787.42158965788, 1.0, 96.00600751876978,
36 | 82.83724552058459, 6.0, -51.625141715571715,
37 | -35.45911388077616, 0.0, -79.6738397452075,
38 | -42.96576312232228, 3.0, 6.0, 0.1842678599143196,
39 | 0.08563438015417912, 0.104280507307428, 91187.26051038165]),
40 |
41 | IndividualDTO(id="4", dataset="second_dataset", topic_count=5, params=[
42 | 24924.136392296103, 0.0, 98.63988602807903,
43 | 49.03544407815009, 5.0, -8.734591095928806,
44 | -6.99720964952175, 4.0, -32.880078901677265,
45 | -24.61400511189416, 3.0, 0.0, 0.9084621726817743,
46 | 0.6392049950522389, 0.3133878344721094, 39413.00378611856])
47 | ]
48 |
49 |
50 | @click.group()
51 | def cli():
52 | pass
53 |
54 |
55 | @cli.command()
56 | def test():
57 | logger.info("Starting parallel computations...")
58 |
59 | population_with_fitness = parallel_fitness(population)
60 |
61 | assert len(population_with_fitness) == 4, \
62 |         f"Wrong population size (should be == 4): {pformat(population_with_fitness)}"
63 |
64 | for ind in population_with_fitness:
65 | logger.info(f"Individual {ind.id} has fitness {ind.fitness_value}")
66 |
67 | assert max((ind.fitness_value[AVG_COHERENCE_SCORE] for ind in population_with_fitness)) > 0, \
68 | "At least one fitness should be more than zero"
69 |
70 |     logger.info("Test run succeeded")
71 | click.echo('Initialized the database')
72 |
73 |
74 | @cli.command()
75 | @click.option('--timeout', type=float, default=25.0)
76 | @click.option('--expid', type=int)
77 | def run_and_log(timeout: float, expid: Optional[int] = None):
78 | best_ind = population[0]
79 | if expid:
80 | best_ind.exp_id = expid
81 | logger.info(f"Individual: {best_ind}")
82 | ind = log_best_solution(individual=best_ind, wait_for_result_timeout=timeout, alg_args="--arg 1 --arg 2")
83 | logger.info(f"Logged the best solution. Obtained fitness is {ind.fitness_value[AVG_COHERENCE_SCORE]}")
84 |
85 |
86 | if __name__ == '__main__':
87 | cli()
88 |
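A hypothetical way to exercise this CLI outside the cluster, assuming the autotm_distributed package is installed and that make_celery_app picks up the broker/backend settings from the same environment variables the Kubernetes manifests set (the localhost URL below is a made-up example):

    export CELERY_BROKER_URL="amqp://guest:guest@localhost:5672"
    export CELERY_RESULT_BACKEND="rpc://"
    python3 -m autotm_distributed.test_app test
    python3 -m autotm_distributed.test_app run_and_log --timeout 60 --expid 1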
--------------------------------------------------------------------------------
/distributed/autotm_distributed/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import logging
3 | from datetime import datetime
4 | from typing import Optional
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class TqdmToLogger(io.StringIO):
11 | """
12 |     Output stream for tqdm that writes to a logger module instead of
13 |     stdout.
14 | """
15 | def __init__(self, base_logger, level=None):
16 | super(TqdmToLogger, self).__init__()
17 | self.logger = base_logger
18 | self.level = level or logging.INFO
19 | self.buf = ''
20 |
21 | def write(self, buf):
22 | self.buf = buf.strip('\r\n\t ')
23 |
24 | def flush(self):
25 | self.logger.log(self.level, self.buf)
26 |
27 |
28 | class log_exec_timer:
29 | def __init__(self, name: Optional[str] = None):
30 | self.name = name
31 | self._start = None
32 | self._duration = None
33 |
34 | def __enter__(self):
35 | self._start = datetime.now()
36 | return self
37 |
38 | def __exit__(self, typ, value, traceback):
39 | self._duration = (datetime.now() - self._start).total_seconds()
40 | msg = f"Exec time of {self.name}: {self._duration}" if self.name else f"Exec time: {self._duration}"
41 | logger.info(msg)
42 |
43 | @property
44 | def duration(self):
45 | return self._duration
46 |
--------------------------------------------------------------------------------
/distributed/bin/kube-fitnessctl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export DIST_MODE=installed
4 | exec fitnessctl "${@}"
--------------------------------------------------------------------------------
/distributed/bin/trun.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | docker run -it --rm \
4 | -v /home/nikolay/wspace/20newsgroups/:/dataset \
5 | -e "CELERY_BROKER_URL=amqp://guest:guest@rabbitmq-service:5672" \
6 | -e "CELERY_RESULT_BACKEND=rpc://" \
7 | -e "DICTIONARY_PATH=/dataset/dictionary.txt" \
8 | -e "BATCHES_DIR=/dataset/batches" \
9 | -e "MUTUAL_INFO_DICT_PATH=/dataset/mutual_info_dict.pkl" \
10 | -e "EXPERIMENTS_PATH=/tmp/tm_experiments" \
11 | -e "TOPIC_COUNT=10" \
12 | -e "NUM_PROCESSORS=1" \
13 | fitness-worker:latest h
--------------------------------------------------------------------------------
/distributed/deploy/ess-small-datasets-config.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 |   # the rest of the settings will be the same as for the first dataset and will be added automatically
3 | books_stroyitelstvo_2030:
4 | base_path: "/storage/books_stroyitelstvo_2030_sample"
5 | topic_count: 10
6 |
7 | 20newsgroups:
8 | base_path: "/storage/20newsgroups_sample"
9 | topic_count: 20
10 |
11 | clickhouse_issues:
12 | base_path: "/storage/clickhouse_issues_sample"
13 | labels: yes
14 | topic_count: 50
15 |
16 |   # the rest of the settings will be the same as for the first dataset and will be added automatically
17 | banners:
18 | base_path: "/storage/banners_sample"
19 | topic_count: 20
20 |
21 | amazon_food:
22 | base_path: "/storage/amazon_food_sample"
23 | topic_count: 20
24 |
25 | hotel-reviews:
26 | base_path: "/storage/hotel-reviews_sample"
27 | topic_count: 20
28 |
29 | lenta_ru:
30 | base_path: "/storage/lenta_ru_sample"
31 | topic_count: 20
32 |
--------------------------------------------------------------------------------
/distributed/deploy/file.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: Service
4 | metadata:
5 | name: mlflow-db-phpmyadmin
6 | spec:
7 | type: NodePort
8 | ports:
9 | - port: 80
10 | selector:
11 | app: mlflow-db-phpmyadmin
12 | ---
13 | apiVersion: apps/v1
14 | kind: Deployment
15 | metadata:
16 | name: mlflow-db-phpmyadmin
17 | spec:
18 | replicas: 1
19 | selector:
20 | matchLabels:
21 | app: mlflow-db-phpmyadmin
22 | template:
23 | metadata:
24 | annotations:
25 | "sidecar.istio.io/inject": "false"
26 | labels:
27 | app: mlflow-db-phpmyadmin
28 | spec:
29 | containers:
30 | - name: phpmyadmin
31 | image: phpmyadmin:5.1.1
32 | imagePullPolicy: IfNotPresent
33 | env:
34 | - name: PMA_HOST
35 | value: mlflow-db
36 | - name: PMA_PORT
37 | value: "3306"
38 | - name: PMA_USER
39 | value: mlflow
40 | - name: PMA_PASSWORD
41 | value: mlflow
42 | ports:
43 | - containerPort: 80
44 |
--------------------------------------------------------------------------------
/distributed/deploy/fitness-worker-health-checker.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: batch/v1beta1
3 | kind: CronJob
4 | metadata:
5 | name: fitness-worker-health-checker
6 | spec:
7 | schedule: "*/5 * * * *"
8 | concurrencyPolicy: Forbid
9 | jobTemplate:
10 | spec:
11 | template:
12 | metadata:
13 | annotations:
14 | "sidecar.istio.io/inject": "false"
15 | spec:
16 | serviceAccountName: default-editor
17 | containers:
18 | - name: health-checker
19 | image: node2.bdcl:5000/fitness-worker-health-checker:latest
20 | imagePullPolicy: IfNotPresent
21 | command:
22 | - /bin/sh
23 | - -c
24 | - >
25 | kube_ns=$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace) &&
26 |               kubectl -n ${kube_ns} logs --prefix --tail=100 -l=app=fitness-worker 2>&1 |
27 | grep -i 'Could not create logging file: File exists' |
28 | perl -n -e'/pod\/(.+)\/worker/ && print "$1\n"' |
29 | sort |
30 | uniq |
31 | xargs -n 1 -t -I {} -P8 kubectl -n ${kube_ns} delete pod {} --ignore-not-found
32 | restartPolicy: Never
--------------------------------------------------------------------------------
/distributed/deploy/kube-fitness-client-job.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: batch/v1
2 | kind: Job
3 | metadata:
4 | name: fitness-client
5 | spec:
6 | template:
7 | # works even with injection
8 | metadata:
9 | annotations:
10 | "sidecar.istio.io/inject": "false"
11 | spec:
12 | containers:
13 | - name: client
14 | image: node2.bdcl:5000/fitness-client:latest
15 | args: ["test"]
16 | imagePullPolicy: Always
17 | env:
18 | - name: CELERY_BROKER_URL
19 | value: "amqp://guest:guest@rabbitmq-service:5672"
20 | - name: CELERY_RESULT_BACKEND
21 | value: "redis://redis:6379/1" # "rpc://"
22 | restartPolicy: Never
23 | backoffLimit: 0
--------------------------------------------------------------------------------
/distributed/deploy/kube-fitness-client-job.yaml.j2:
--------------------------------------------------------------------------------
1 | apiVersion: batch/v1
2 | kind: Job
3 | metadata:
4 | name: fitness-client
5 | spec:
6 | template:
7 | # works even with injection
8 | metadata:
9 | annotations:
10 | "sidecar.istio.io/inject": "false"
11 | spec:
12 | containers:
13 | - name: client
14 | image: {{ image }}
15 | args: ["test"]
16 | imagePullPolicy: {{ pull_policy }}
17 | env:
18 | - name: CELERY_BROKER_URL
19 | value: "amqp://guest:guest@rabbitmq-service:5672"
20 | - name: CELERY_RESULT_BACKEND
21 | value: "redis://redis:6379/1" # "rpc://"
22 | restartPolicy: Never
23 | backoffLimit: 0
--------------------------------------------------------------------------------
/distributed/deploy/kube-fitness-workers.yaml.j2:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: apps/v1
3 | kind: Deployment
4 | metadata:
5 | name: rabbitmq
6 | spec:
7 | replicas: 1
8 | selector:
9 | matchLabels:
10 | app: rabbitmq
11 | template:
12 | metadata:
13 | annotations:
14 | "sidecar.istio.io/inject": "false"
15 | labels:
16 | app: rabbitmq
17 | spec:
18 | volumes:
19 | - name: config-volume
20 | configMap:
21 | name: rabbitmq-config
22 | containers:
23 | - name: rabbitmq
24 | image: node2.bdcl:5000/rabbitmq:3.8-management-alpine
25 | imagePullPolicy: IfNotPresent
26 | ports:
27 | - containerPort: 5672
28 | volumeMounts:
29 | - name: config-volume
30 | mountPath: /etc/rabbitmq/conf.d/consumer-settings.conf
31 | subPath: consumer-settings.conf
32 | ---
33 | apiVersion: v1
34 | kind: Service
35 | metadata:
36 | name: rabbitmq-service
37 | spec:
38 | ports:
39 | - port: 5672
40 | selector:
41 | app: rabbitmq
42 | ---
43 | apiVersion: v1
44 | kind: ConfigMap
45 | metadata:
46 | name: rabbitmq-config
47 | data:
48 | consumer-settings.conf: |
49 | ## Consumer timeout
50 |     ## If a message delivered to a consumer has not been acknowledged before this timer
51 |     ## triggers, the channel will be force-closed by the broker. This ensures that
52 |     ## faulty consumers that never ack will not hold on to messages indefinitely.
53 | ##
54 | consumer_timeout = 1800000
55 | ---
56 | apiVersion: apps/v1 # for k8s versions before 1.9.0 use apps/v1beta2 and before 1.8.0 use extensions/v1beta1
57 | kind: Deployment
58 | metadata:
59 | name: redis
60 | spec:
61 | replicas: 1
62 | selector:
63 | matchLabels:
64 | app: redis
65 | template:
66 | metadata:
67 | annotations:
68 | "sidecar.istio.io/inject": "false"
69 | labels:
70 | app: redis
71 | spec:
72 | containers:
73 | - name: redis
74 | image: node2.bdcl:5000/redis:6.2
75 | imagePullPolicy: IfNotPresent
76 | resources:
77 | requests:
78 | cpu: 100m
79 | memory: 100Mi
80 | ports:
81 | - containerPort: 6379
82 | ---
83 | apiVersion: v1
84 | kind: Service
85 | metadata:
86 | name: redis
87 | spec:
88 | ports:
89 | - port: 6379
90 | selector:
91 | app: redis
92 | ---
93 | apiVersion: apps/v1
94 | kind: Deployment
95 | metadata:
96 | name: celery-flower
97 | spec:
98 | replicas: 1
99 | selector:
100 | matchLabels:
101 | app: celery-flower
102 | template:
103 | metadata:
104 | annotations:
105 | "sidecar.istio.io/inject": "false"
106 | labels:
107 | app: celery-flower
108 | spec:
109 | containers:
110 | - name: flower
111 | image: {{ flower_image }}
112 | imagePullPolicy: {{ pull_policy }}
113 | ports:
114 | - containerPort: 5555
115 | env:
116 | - name: CELERY_BROKER_URL
117 | value: "amqp://guest:guest@rabbitmq-service:5672"
118 | - name: CELERY_RESULT_BACKEND
119 | value: "redis://redis:6379/1" # "rpc://"
120 | ---
121 | apiVersion: v1
122 | kind: Service
123 | metadata:
124 | name: celery-flower-service
125 | spec:
126 | type: NodePort
127 | ports:
128 | - port: 5555
129 | selector:
130 | app: celery-flower
131 | ---
132 | apiVersion: v1
133 | kind: ConfigMap
134 | metadata:
135 | name: fitness-worker-config
136 | data:
137 | datasets-config.yaml: |
138 | {{ datasets_config_content | indent( width=4, first=False) }}
139 | ---
140 | apiVersion: apps/v1
141 | kind: Deployment
142 | metadata:
143 | name: fitness-worker
144 | labels:
145 | app: fitness-worker
146 | spec:
147 | replicas: {{ worker_count }}
148 | selector:
149 | matchLabels:
150 | app: fitness-worker
151 | template:
152 | metadata:
153 | annotations:
154 | "sidecar.istio.io/inject": "false"
155 | labels:
156 | app: fitness-worker
157 | spec:
158 | volumes:
159 | - name: dataset
160 | hostPath:
161 | path: {{ host_data_dir }}
162 | type: Directory
163 | - name: config-volume
164 | configMap:
165 | name: fitness-worker-config
166 | - name: mlflow-vol
167 | persistentVolumeClaim:
168 | claimName: mlflow-artifact-store-pvc
169 | containers:
170 | - name: worker
171 | image: {{ image }}
172 | imagePullPolicy: {{ pull_policy }}
173 | volumeMounts:
174 | - name: dataset
175 | mountPath: {{ host_data_dir_mount_path }}
176 | - name: config-volume
177 | mountPath: /etc/fitness/datasets-config.yaml
178 | subPath: datasets-config.yaml
179 | - mountPath: "/var/lib/mlruns"
180 | name: mlflow-vol
181 | env:
182 | - name: CELERY_BROKER_URL
183 | value: "amqp://guest:guest@rabbitmq-service:5672"
184 | - name: CELERY_RESULT_BACKEND
185 | value: "redis://redis:6379/1" # "rpc://"
186 | - name: NUM_PROCESSORS
187 | value: "{{ worker_cpu }}"
188 | - name: DATASETS_CONFIG
189 | value: /etc/fitness/datasets-config.yaml
190 | - name: MLFLOW_TRACKING_URI
191 | value: mysql+pymysql://mlflow:mlflow@mlflow-db:3306/mlflow
192 | # see: https://github.com/mongodb/mongo-python-driver/blob/c8d920a46bfb7b054326b3e983943bfc794cb676/pymongo/mongo_client.py
193 | - name: MONGO_URI
194 | value: mongodb://mongoadmin:secret@mongo-tm-experiments-db:27017
195 | - name: MONGO_COLLECTION
196 | value: "{{ mongo_collection or 'tm_stats' }}"
197 | resources:
198 | requests:
199 | memory: "{{ worker_mem }}"
200 | cpu: "{{ worker_cpu }}"
201 | limits:
202 | memory: "{{ worker_mem }}"
203 | cpu: "{{ worker_cpu }}"
204 |
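After deploy_config_generator.py has rendered this template, a hypothetical deployment sequence (assuming kubectl is pointed at the target namespace and the MLflow persistent volumes are created first, since the workers mount mlflow-artifact-store-pvc) might be:

    kubectl apply -f deploy/mlflow-persistent-volumes.yaml
    kubectl apply -f deploy/kube-fitness-workers.yaml
    kubectl apply -f deploy/kube-fitness-client-job.yaml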
--------------------------------------------------------------------------------
/distributed/deploy/mlflow-minikube-persistent-volumes.yaml:
--------------------------------------------------------------------------------
1 | # TODO: need to be fixed in mongo part
2 | #---
3 | #apiVersion: v1
4 | #kind: PersistentVolume
5 | #metadata:
6 | # name: mlflow-db-pv
7 | #spec:
8 | # storageClassName: manual
9 | # capacity:
10 | # storage: 25Gi
11 | # volumeMode: Filesystem
12 | # accessModes:
13 | # - ReadWriteOnce
14 | # persistentVolumeReclaimPolicy: Retain
15 | # local:
16 | # path: /data/hdd/mlflow-db
17 | # nodeAffinity:
18 | # required:
19 | # nodeSelectorTerms:
20 | # - matchExpressions:
21 | # - key: kubernetes.io/hostname
22 | # operator: In
23 | # values:
24 | # - localhost.localdomain
25 | #---
26 | #apiVersion: v1
27 | #kind: PersistentVolume
28 | #metadata:
29 | # name: mlflow-artifact-store-pv
30 | #spec:
31 | # storageClassName: manual
32 | # capacity:
33 | # storage: 50Gi
34 | # volumeMode: Filesystem
35 | # accessModes:
36 | # - ReadWriteMany
37 | # persistentVolumeReclaimPolicy: Retain
38 | # hostPath:
39 | # path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mlflow-artifact-store
40 | # type: Directory
41 | #---
42 | #apiVersion: v1
43 | #kind: PersistentVolume
44 | #metadata:
45 | # name: mongo-tm-experiments-pv
46 | #spec:
47 | # storageClassName: manual
48 | # capacity:
49 | # storage: 20Gi
50 | # volumeMode: Filesystem
51 | # accessModes:
52 | # - ReadWriteOnce
53 | # persistentVolumeReclaimPolicy: Retain
54 | # hostPath:
55 | # path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mongodb
56 | # type: Directory
57 | #---
58 | #apiVersion: v1
59 | #kind: PersistentVolumeClaim
60 | #metadata:
61 | # name: mlflow-db-pvc
62 | #spec:
63 | # storageClassName: manual
64 | # volumeName: mlflow-db-pv
65 | # accessModes:
66 | # - ReadWriteOnce
67 | # resources:
68 | # requests:
69 | # storage: 25Gi
70 | #---
71 | #apiVersion: v1
72 | #kind: PersistentVolumeClaim
73 | #metadata:
74 | # name: mlflow-artifact-store-pvc
75 | #spec:
76 | # storageClassName: manual
77 | # volumeName: mlflow-artifact-store-pv
78 | # accessModes:
79 | # - ReadWriteMany
80 | # resources:
81 | # requests:
82 | # storage: 25Gi
83 | #---
84 | #apiVersion: v1
85 | #kind: PersistentVolumeClaim
86 | #metadata:
87 | # name: mongo-tm-experiments-pvc
88 | #spec:
89 | # storageClassName: manual
90 | # volumeName: mongo-tm-experiments-pv
91 | # accessModes:
92 | # - ReadWriteOnce
93 | # resources:
94 | # requests:
95 | # storage: 20Gi
96 |
--------------------------------------------------------------------------------
/distributed/deploy/mlflow-persistent-volumes.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: PersistentVolume
4 | metadata:
5 | name: mlflow-db-pv
6 | spec:
7 | storageClassName: manual
8 | capacity:
9 | storage: 25Gi
10 | volumeMode: Filesystem
11 | accessModes:
12 | - ReadWriteOnce
13 | persistentVolumeReclaimPolicy: Retain
14 | local:
15 | path: /data/hdd/mlflow-db
16 | nodeAffinity:
17 | required:
18 | nodeSelectorTerms:
19 | - matchExpressions:
20 | - key: kubernetes.io/hostname
21 | operator: In
22 | values:
23 | - node3.bdcl
24 | ---
25 | apiVersion: v1
26 | kind: PersistentVolume
27 | metadata:
28 | name: mlflow-artifact-store-pv
29 | spec:
30 | storageClassName: manual
31 | capacity:
32 | storage: 50Gi
33 | volumeMode: Filesystem
34 | accessModes:
35 | - ReadWriteMany
36 | persistentVolumeReclaimPolicy: Retain
37 | hostPath:
38 | path: /mnt/ess_storage/DN_1/storage/home/khodorchenko/mlflow-tm-experiments/mlflow-artifact-store
39 | type: Directory
40 | #---
41 | #apiVersion: v1
42 | #kind: PersistentVolume
43 | #metadata:
44 | # name: mongo-tm-experiments-pv
45 | #spec:
46 | # storageClassName: manual
47 | # capacity:
48 | # storage: 50Gi
49 | # volumeMode: Filesystem
50 | # accessModes:
51 | # - ReadWriteOnce
52 | # persistentVolumeReclaimPolicy: Retain
53 | # local:
54 | # path: /data/hdd/mongo-db-tm-experiments
55 | # nodeAffinity:
56 | # required:
57 | # nodeSelectorTerms:
58 | # - matchExpressions:
59 | # - key: kubernetes.io/hostname
60 | # operator: In
61 | # values:
62 | # - node3.bdcl
63 |
64 | ---
65 | apiVersion: v1
66 | kind: PersistentVolume
67 | metadata:
68 | name: mongo-tm-experiments-pv-part3
69 | spec:
70 | storageClassName: manual
71 | capacity:
72 | storage: 50Gi
73 | volumeMode: Filesystem
74 | accessModes:
75 | - ReadWriteOnce
76 | persistentVolumeReclaimPolicy: Retain
77 | local:
78 | path: /data/hdd/mongo-db-tm-experiments-part3
79 | nodeAffinity:
80 | required:
81 | nodeSelectorTerms:
82 | - matchExpressions:
83 | - key: kubernetes.io/hostname
84 | operator: In
85 | values:
86 | - node11.bdcl
87 | ---
88 | apiVersion: v1
89 | kind: PersistentVolumeClaim
90 | metadata:
91 | name: mlflow-db-pvc
92 | spec:
93 | storageClassName: manual
94 | volumeName: mlflow-db-pv
95 | accessModes:
96 | - ReadWriteOnce
97 | resources:
98 | requests:
99 | storage: 25Gi
100 | ---
101 | apiVersion: v1
102 | kind: PersistentVolumeClaim
103 | metadata:
104 | name: mlflow-artifact-store-pvc
105 | spec:
106 | storageClassName: manual
107 | volumeName: mlflow-artifact-store-pv
108 | accessModes:
109 | - ReadWriteMany
110 | resources:
111 | requests:
112 | storage: 25Gi
113 | ---
114 | apiVersion: v1
115 | kind: PersistentVolumeClaim
116 | metadata:
117 | name: mongo-tm-experiments-pvc
118 | spec:
119 | storageClassName: manual
120 | volumeName: mongo-tm-experiments-pv-part3
121 | accessModes:
122 | - ReadWriteOnce
123 | resources:
124 | requests:
125 | storage: 50Gi
126 |
--------------------------------------------------------------------------------
/distributed/deploy/mongo_dev/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: "3.5"
2 |
3 | services:
4 | mongo:
5 | image: mongo:latest
6 | container_name: mongo
7 | environment:
8 | MONGO_INITDB_ROOT_USERNAME: admin
9 | MONGO_INITDB_ROOT_PASSWORD: admin
10 | ports:
11 | - "0.0.0.0:27017:27017"
12 | networks:
13 | - MONGO
14 | volumes:
15 | - type: volume
16 | source: MONGO_DATA
17 | target: /data/db
18 | - type: volume
19 | source: MONGO_CONFIG
20 | target: /data/configdb
21 | mongo-express:
22 | image: mongo-express:latest
23 | container_name: mongo-express
24 | environment:
25 | ME_CONFIG_MONGODB_ADMINUSERNAME: admin
26 | ME_CONFIG_MONGODB_ADMINPASSWORD: admin
27 | ME_CONFIG_MONGODB_SERVER: mongo
28 | ME_CONFIG_MONGODB_PORT: "27017"
29 | # this script was taken from https://github.com/eficode/wait-for
30 | volumes:
31 | - type: bind
32 | source: ./wait-for.sh
33 | target: /wait-for.sh
34 | entrypoint:
35 | - /bin/sh
36 | - /wait-for.sh
37 | - mongo:27017
38 | - --
39 | - tini
40 | - --
41 | - /docker-entrypoint.sh
42 | ports:
43 | - "0.0.0.0:8081:8081"
44 | networks:
45 | - MONGO
46 | depends_on:
47 | - mongo
48 |
49 | networks:
50 | MONGO:
51 | name: MONGO
52 |
53 | volumes:
54 | MONGO_DATA:
55 | name: MONGO_DATA
56 | MONGO_CONFIG:
57 | name: MONGO_CONFIG
58 |
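A hypothetical way to bring up this development MongoDB stack with the v1 compose CLI (any compose-compatible CLI should work):

    docker-compose up -d
    # mongo-express UI is then published on http://localhost:8081,
    # MongoDB itself on mongodb://admin:admin@localhost:27017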
--------------------------------------------------------------------------------
/distributed/deploy/mongo_dev/wait-for.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # The MIT License (MIT)
4 | #
5 | # Copyright (c) 2017 Eficode Oy
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | VERSION="2.2.3"
26 |
27 | set -- "$@" -- "$TIMEOUT" "$QUIET" "$PROTOCOL" "$HOST" "$PORT" "$result"
28 | TIMEOUT=15
29 | QUIET=0
30 | # The protocol to make the request with, either "tcp" or "http"
31 | PROTOCOL="tcp"
32 |
33 | echoerr() {
34 | if [ "$QUIET" -ne 1 ]; then printf "%s\n" "$*" 1>&2; fi
35 | }
36 |
37 | usage() {
38 | exitcode="$1"
39 | cat << USAGE >&2
40 | Usage:
41 | $0 host:port|url [-t timeout] [-- command args]
42 | -q | --quiet Do not output any status messages
43 | -t TIMEOUT | --timeout=timeout Timeout in seconds, zero for no timeout
44 | -v | --version Show the version of this tool
45 | -- COMMAND ARGS Execute command with args after the test finishes
46 | USAGE
47 | exit "$exitcode"
48 | }
49 |
50 | wait_for() {
51 | case "$PROTOCOL" in
52 | tcp)
53 | if ! command -v nc >/dev/null; then
54 | echoerr 'nc command is missing!'
55 | exit 1
56 | fi
57 | ;;
58 | http)
59 | if ! command -v wget >/dev/null; then
60 | echoerr 'wget command is missing!'
61 | exit 1
62 | fi
63 | ;;
64 | esac
65 |
66 | TIMEOUT_END=$(($(date +%s) + TIMEOUT))
67 |
68 | while :; do
69 | case "$PROTOCOL" in
70 | tcp)
71 | nc -w 1 -z "$HOST" "$PORT" > /dev/null 2>&1
72 | ;;
73 | http)
74 | wget --timeout=1 -q "$HOST" -O /dev/null > /dev/null 2>&1
75 | ;;
76 | *)
77 | echoerr "Unknown protocol '$PROTOCOL'"
78 | exit 1
79 | ;;
80 | esac
81 |
82 | result=$?
83 |
84 | if [ $result -eq 0 ] ; then
85 | if [ $# -gt 7 ] ; then
86 | for result in $(seq $(($# - 7))); do
87 | result=$1
88 | shift
89 | set -- "$@" "$result"
90 | done
91 |
92 | TIMEOUT=$2 QUIET=$3 PROTOCOL=$4 HOST=$5 PORT=$6 result=$7
93 | shift 7
94 | exec "$@"
95 | fi
96 | exit 0
97 | fi
98 |
99 | if [ $TIMEOUT -ne 0 -a $(date +%s) -ge $TIMEOUT_END ]; then
100 | echo "Operation timed out" >&2
101 | exit 1
102 | fi
103 |
104 | sleep 1
105 | done
106 | }
107 |
108 | while :; do
109 | case "$1" in
110 | http://*|https://*)
111 | HOST="$1"
112 | PROTOCOL="http"
113 | shift 1
114 | ;;
115 | *:* )
116 | HOST=$(printf "%s\n" "$1"| cut -d : -f 1)
117 | PORT=$(printf "%s\n" "$1"| cut -d : -f 2)
118 | shift 1
119 | ;;
120 | -v | --version)
121 | echo $VERSION
122 | exit
123 | ;;
124 | -q | --quiet)
125 | QUIET=1
126 | shift 1
127 | ;;
128 | -q-*)
129 | QUIET=0
130 | echoerr "Unknown option: $1"
131 | usage 1
132 | ;;
133 | -q*)
134 | QUIET=1
135 | result=$1
136 | shift 1
137 | set -- -"${result#-q}" "$@"
138 | ;;
139 | -t | --timeout)
140 | TIMEOUT="$2"
141 | shift 2
142 | ;;
143 | -t*)
144 | TIMEOUT="${1#-t}"
145 | shift 1
146 | ;;
147 | --timeout=*)
148 | TIMEOUT="${1#*=}"
149 | shift 1
150 | ;;
151 | --)
152 | shift
153 | break
154 | ;;
155 | --help)
156 | usage 0
157 | ;;
158 | -*)
159 | QUIET=0
160 | echoerr "Unknown option: $1"
161 | usage 1
162 | ;;
163 | *)
164 | QUIET=0
165 | echoerr "Unknown argument: $1"
166 | usage 1
167 | ;;
168 | esac
169 | done
170 |
171 | if ! [ "$TIMEOUT" -ge 0 ] 2>/dev/null; then
172 | echoerr "Error: invalid timeout '$TIMEOUT'"
173 | usage 3
174 | fi
175 |
176 | case "$PROTOCOL" in
177 | tcp)
178 | if [ "$HOST" = "" ] || [ "$PORT" = "" ]; then
179 | echoerr "Error: you need to provide a host and port to test."
180 | usage 2
181 | fi
182 | ;;
183 | http)
184 | if [ "$HOST" = "" ]; then
185 | echoerr "Error: you need to provide a host to test."
186 | usage 2
187 | fi
188 | ;;
189 | esac
190 |
191 | wait_for "$@"
192 |
--------------------------------------------------------------------------------
/distributed/deploy/ns.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: Namespace
4 | metadata:
5 | name: mkhodorchenko
6 | labels:
7 | istio-injection: disabled
8 |
--------------------------------------------------------------------------------
/distributed/deploy/test-datasets-config.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | first_dataset:
3 | base_path: "/storage/tiny_20newsgroups"
4 | dictionary_path: "dictionary.txt"
5 | batches_dir: "batches"
6 | mutual_info_dict_path: "mutual_info_dict.pkl"
7 | experiments_path: "/tmp/tm_experiments"
8 | topic_count: 10
9 |
10 | # the rest settings will be the same as for the first dataset but will be added automatically
11 | second_dataset:
12 | base_path: "/storage/tiny_20newsgroups_2"
13 | topic_count: 10
14 |
--------------------------------------------------------------------------------
/distributed/docker/base.dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 | RUN apt-get update && apt-get install -y bash python3 python3-pip
3 | #RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2'
4 | COPY requirements.txt /tmp/
5 | RUN pip3 install -r /tmp/requirements.txt
6 | COPY . /kube-distributed-fitness
7 | RUN pip3 install /kube-distributed-fitness
--------------------------------------------------------------------------------
/distributed/docker/cli.dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 |
3 | RUN DEBIAN_FRONTEND=noninteractive apt-get update -y && \
4 | DEBIAN_FRONTEND=noninteractive apt-get install -y git curl make cmake build-essential libboost-all-dev
5 |
6 | RUN curl -LJO https://github.com/bigartm/bigartm/archive/refs/tags/v0.9.2.tar.gz
7 |
8 | RUN tar -xvf bigartm-0.9.2.tar.gz
9 |
10 | RUN mkdir bigartm-0.9.2/build
11 |
12 | RUN DEBIAN_FRONTEND=noninteractive apt-get update && DEBIAN_FRONTEND=noninteractive apt-get --yes install bash python3.8
13 |
14 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
15 |
16 | RUN python3.8 get-pip.py
17 |
18 | RUN pip3 install protobuf tqdm wheel
19 |
20 | RUN cd bigartm-0.9.2/build && cmake .. -DBUILD_INTERNAL_PYTHON_API=OFF
21 |
22 | RUN cd bigartm-0.9.2/build && make install
23 |
24 | COPY requirements.txt /tmp/
25 |
26 | RUN pip3 install -r /tmp/requirements.txt
27 |
28 | COPY . /kube-distributed-fitness
29 |
30 | RUN pip3 install /kube-distributed-fitness
31 |
--------------------------------------------------------------------------------
/distributed/docker/client.dockerfile:
--------------------------------------------------------------------------------
1 | FROM fitness-base:latest
2 | ENTRYPOINT ["python3", "-u", "/kube-distributed-fitness/kube_fitness/test_app.py"]
--------------------------------------------------------------------------------
/distributed/docker/flower.dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:alpine
2 |
3 | # Get latest root certificates
4 | RUN apk add --no-cache ca-certificates && update-ca-certificates
5 |
6 | # Install the required packages
7 | RUN pip install --no-cache-dir redis flower==1.0.0
8 |
9 | # PYTHONUNBUFFERED: Force stdin, stdout and stderr to be totally unbuffered. (equivalent to `python -u`)
10 | # PYTHONHASHSEED: Enable hash randomization (equivalent to `python -R`)
11 | # PYTHONDONTWRITEBYTECODE: Do not write bytecode files to disk, since we maintain the image as read-only. (equivalent to `python -B`)
12 | ENV PYTHONUNBUFFERED=1 PYTHONHASHSEED=random PYTHONDONTWRITEBYTECODE=1
13 |
14 | # Default port
15 | EXPOSE 5555
16 |
17 | ENV FLOWER_DATA_DIR /data
18 | ENV PYTHONPATH ${FLOWER_DATA_DIR}
19 |
20 | WORKDIR $FLOWER_DATA_DIR
21 |
22 | # Add a user with an explicit UID/GID and create necessary directories
23 | RUN set -eux; \
24 | addgroup -g 1000 flower; \
25 | adduser -u 1000 -G flower flower -D; \
26 | mkdir -p "$FLOWER_DATA_DIR"; \
27 | chown flower:flower "$FLOWER_DATA_DIR"
28 | USER flower
29 |
30 | VOLUME $FLOWER_DATA_DIR
31 |
32 | # for '-A distributed_fitness' see kube_fitness.tasks.make_celery_app
33 | ENTRYPOINT celery flower --address=0.0.0.0 --port=5555
--------------------------------------------------------------------------------
/distributed/docker/health-checker.dockerfile:
--------------------------------------------------------------------------------
1 | FROM boxboat/kubectl:1.21.3
2 |
3 | WORKDIR /
4 | USER root:root
5 |
6 | RUN apk add perl
7 |
8 | ENTRYPOINT [ "/usr/local/bin/kubectl"]
9 |
--------------------------------------------------------------------------------
/distributed/docker/mlflow-webserver.dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | RUN pip install PyMySQL==0.9.3 psycopg2-binary==2.8.5 protobuf==3.20.1 mlflow[extras]==1.18.0
4 |
5 | ENTRYPOINT ["mlflow", "server"]
6 |
--------------------------------------------------------------------------------
/distributed/docker/test.dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 | RUN apt-get update && apt-get install -y bash python3 python3-pip
3 | RUN pip3 install 'celery == 4.4.7' 'bigartm == 0.9.2' 'tqdm == 4.50.2' 'numpy == 1.19.2' 'dataclasses-json == 0.5.2'
--------------------------------------------------------------------------------
/distributed/docker/worker.dockerfile:
--------------------------------------------------------------------------------
1 | FROM fitness-base:latest
2 | ENTRYPOINT python3 /kube-distributed-fitness/kube_fitness/kube_fitness/main.py \
3 | --concurrency 1 \
4 | --queues fitness_tasks \
5 | --loglevel INFO
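
A hypothetical build sequence for these images, run from the distributed/ directory so that requirements.txt and the sources are in the build context; the tags match the FROM line above and the image defaults in deploy_config_generator.py:

    docker build -f docker/base.dockerfile   -t fitness-base:latest .
    docker build -f docker/worker.dockerfile -t fitness-worker:latest .
    docker build -f docker/client.dockerfile -t fitness-client:latest .
    docker build -f docker/flower.dockerfile -t flower:latest .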
--------------------------------------------------------------------------------
/distributed/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "autotm-distributed"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["fonhorst "]
6 | readme = "README.md"
7 | packages = [{include = "autotm_distributed"}]
8 |
9 | [tool.poetry.dependencies]
10 | python = "^3.9"
11 | bigartm = "0.9.2"
12 |
13 |
14 | [build-system]
15 | requires = ["poetry-core"]
16 | build-backend = "poetry.core.masonry.api"
17 |
--------------------------------------------------------------------------------
/distributed/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup
5 |
6 |
7 | def deploy_data_files():
8 | templates = [os.path.join("deploy", f) for f in os.listdir("deploy") if f.endswith(".j2")]
9 | files = templates + ["deploy/test-datasets-config.yaml"]
10 | return files
11 |
12 |
13 | with open('requirements.txt') as f:
14 | strs = f.readlines()
15 |
16 | install_requires = [str(requirement) for requirement in pkg_resources.parse_requirements(strs)]
17 |
18 | setup(
19 | name='kube_fitness',
20 | version='0.1.0',
21 | description='kube fitness',
22 | package_dir={"": "kube_fitness"},
23 | packages=["kube_fitness"],
24 | install_requires=install_requires,
25 | include_package_data=True,
26 | data_files=[
27 | ('share/kube-fitness-data/deploy', deploy_data_files())
28 | ],
29 | scripts=['bin/kube-fitnessctl', 'bin/fitnessctl'],
30 | entry_points = {
31 | 'console_scripts': ['deploy-config-generator=kube_fitness.deploy_config_generator:main'],
32 | }
33 | )
--------------------------------------------------------------------------------
/distributed/test/topic_model_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from kube_fitness import AVG_COHERENCE_SCORE
4 | from kube_fitness.tm import TopicModelFactory, calculate_fitness_of_individual
5 |
6 |
7 | class TopicModelCase(unittest.TestCase):
8 | def test_topic_model(self):
9 | topic_count = 10
10 | dataset_name_1 = "first"
11 | dataset_name_2 = "second"
12 |
13 | param_s = [
14 | 66717.86348968784, 2.0, 98.42286785825902,
15 | 80.18543570807961, 2.0, -19.948347560420373,
16 | -52.28141634493725, 2.0, -92.85597392137976,
17 | -60.49287378084627, 4.0, 3.0, 0.06840138630839943,
18 | 0.556001061599461, 0.9894122432621849, 11679.364068753106
19 | ]
20 |
21 | metrics = [
22 | AVG_COHERENCE_SCORE,
23 | 'perplexityScore', 'backgroundTokensRatioScore', 'contrast',
24 | 'purity', 'kernelSize', 'npmi_50_list',
25 | 'npmi_50', 'sparsity_phi', 'sparsity_theta',
26 | 'topic_significance_uni', 'topic_significance_vacuous', 'topic_significance_back',
27 | 'switchP_list',
28 | 'switchP', 'all_topics',
29 | *(f'coherence_{i}' for i in range(10, 60, 5)),
30 | *(f'coherence_{i}_list' for i in range(10, 60, 5)),
31 | ]
32 |
33 | TopicModelFactory.init_factory_settings(num_processors=2, dataset_settings={
34 | dataset_name_1: {
35 | "base_path": '/home/nikolay/wspace/test_tiny_dataset',
36 | "topic_count": topic_count,
37 | },
38 | dataset_name_2: {
39 | "base_path": '/home/nikolay/wspace/test_tiny_dataset_2',
40 | "topic_count": topic_count,
41 | },
42 | })
43 |
44 | print(f"Calculating dataset {dataset_name_1}")
45 | fitness = calculate_fitness_of_individual(dataset_name_1, param_s, topic_count=topic_count)
46 | self.assertSetEqual(set(fitness.keys()), set(metrics))
47 | for m in metrics:
48 | self.assertIsNotNone(fitness[m])
49 |
50 | print(f"Calculating dataset {dataset_name_2}")
51 | fitness = calculate_fitness_of_individual(dataset_name_2, param_s, topic_count=topic_count)
52 | self.assertSetEqual(set(fitness.keys()), set(metrics))
53 | for m in metrics:
54 | self.assertIsNotNone(fitness[m])
55 |
56 | with self.assertRaises(Exception):
57 | calculate_fitness_of_individual("unknown_dataset", param_s, topic_count=topic_count)
58 |
59 |
60 | if __name__ == '__main__':
61 | unittest.main()
62 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
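For instance, the HTML documentation can be built through the catch-all target above, assuming sphinx-build and the sphinx_rtd_theme referenced in conf.py are installed:

    cd docs && make html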
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content {
2 | max-width: none;
3 | }
4 |
5 | .rst-content code.xref {
6 | /* !important prevents the common CSS stylesheets from overriding
7 | this as on RTD they are loaded after this stylesheet */
8 | color: #E74C3C
9 | }
10 |
11 | html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt {
12 | border-left-color: rgb(9, 183, 14)
13 | }
14 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline}}
7 |
8 | .. autoclass:: {{ name }}
9 | :members:
10 |
11 |
12 | ..
13 | autogenerated from source/_templates/autosummary/class.rst
14 | note it does not have :inherited-members:
15 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/module.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | {{ name | underline }}
5 |
6 | .. automodule:: {{ fullname }}
7 |
8 | {% block classes %}
9 | {% if classes %}
10 | .. rubric:: {{ _('Classes') }}
11 |
12 | .. autosummary::
13 | :toctree: generated
14 | :nosignatures:
15 | :template: classtemplate.rst
16 | {% for item in classes %}
17 | {{ item }}
18 | {%- endfor %}
19 | {% endif %}
20 | {% endblock %}
21 |
22 | {% block functions %}
23 | {% if functions %}
24 | .. rubric:: {{ _('Functions') }}
25 |
26 | .. autosummary::
27 | :toctree: generated
28 | :nosignatures:
29 | :template: functiontemplate.rst
30 | {% for item in functions %}
31 | {{ item }}
32 | {%- endfor %}
33 | {% endif %}
34 | {% endblock %}
35 |
36 |
37 | {% block modules %}
38 | {% if modules %}
39 | .. rubric:: {{ _('Modules') }}
40 |
41 | .. autosummary::
42 | :toctree:
43 | :recursive:
44 | {% for item in modules %}
45 | {{ item }}
46 | {%- endfor %}
47 | {% endif %}
48 | {% endblock %}
49 |
--------------------------------------------------------------------------------
/docs/_templates/classtemplate.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | {{ name | underline }}
7 |
8 | .. autoclass:: {{ name }}
9 | :members:
10 |
11 |
12 | ..
13 | autogenerated from source/_templates/classtemplate.rst
14 | note it does not have :inherited-members:
15 |
--------------------------------------------------------------------------------
/docs/_templates/functiontemplate.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 |
5 | {{ name | underline }}
6 |
7 | .. autofunction:: {{ fullname }}
8 |
9 | ..
10 | autogenerated from source/_templates/functiontemplate.rst
11 | note it does not have :inherited-members:
12 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | import os
10 | import sys
11 |
12 | CURR_PATH = os.path.abspath(os.path.dirname(__file__))
13 | LIB_PATH = os.path.join(CURR_PATH, os.path.pardir)
14 | sys.path.insert(0, LIB_PATH)
15 |
16 | project = "AutoTM"
17 | copyright = "2023, Strong AI Lab"
18 | author = "Khodorchenko Maria, Butakov Nikolay"
19 | release = "0.1.0"
20 |
21 | # -- General configuration ---------------------------------------------------
22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
23 |
24 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.autosummary"]
25 |
26 | # Delete external references
27 | autosummary_mock_imports = [
28 | "numpy",
29 | "pandas",
30 | "artm",
31 | "gensim",
32 | "billiard",
33 | "plotly",
34 | "scipy",
35 | "spacy_langdetect",
36 | "sklearn",
37 | "spacy",
38 | "pymystem3",
39 | "tqdm",
40 | "nltk"
41 | ]
42 |
43 | templates_path = ["_templates"]
44 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
45 |
46 | # -- Options for HTML output -------------------------------------------------
47 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
48 |
49 |
50 | html_theme = "sphinx_rtd_theme"
51 | pygments_style = 'sphinx'
52 |
53 | html_static_path = ['_static']
54 |
55 | napoleon_google_docstring = True
56 | napoleon_numpy_docstring = False
57 |
58 | # Autosummary true if you want to generate it from very beginning
59 | # autosummary_generate = True
60 |
61 | # set_type_checking_flag = True
62 |
--------------------------------------------------------------------------------
/docs/img/MyLogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/MyLogo.png
--------------------------------------------------------------------------------
/docs/img/autotm_arch_v3 (1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/autotm_arch_v3 (1).png
--------------------------------------------------------------------------------
/docs/img/img_library_eng.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/img_library_eng.png
--------------------------------------------------------------------------------
/docs/img/photo_2023-06-29_20-20-57.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/photo_2023-06-29_20-20-57.jpg
--------------------------------------------------------------------------------
/docs/img/photo_2023-06-30_14-17-08.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/photo_2023-06-30_14-17-08.jpg
--------------------------------------------------------------------------------
/docs/img/pipeling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/pipeling.png
--------------------------------------------------------------------------------
/docs/img/strategy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/img/strategy.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. AutoTM documentation master file, created by
2 | sphinx-quickstart on Wed Apr 26 11:07:27 2023.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to AutoTM's documentation!
7 | ==================================
8 |
9 | `AutoTM <https://github.com/aimclub/AutoTM>`_ is an open-source Python library for topic modeling on texts.
10 |
11 | The goal of the library is to provide a simple-to-use interface to perform fast EDA or create interpretable text embeddings.
12 |
13 | .. toctree::
14 | :maxdepth: 1
15 | :caption: Contents:
16 |
17 |
18 |    Installation <pages/installation>
19 |    Python api <pages/api/index>
20 |    User guide <pages/userguide/index>
21 |
22 |
23 | Indices and tables
24 | ==================
25 |
26 | * :ref:`genindex`
27 | * :ref:`modindex`
28 | * :ref:`search`
29 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/pages/algorithms_for_tuning.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.algorithms_for_tuning
5 | ============================
6 |
7 | Algorithms for tuning.
8 |
9 | .. .. currentmodule:: autotm.algorithms_for_tuning.bayesian_optimization
10 | .. currentmodule:: autotm.algorithms_for_tuning.genetic_algorithm
11 |
12 | .. autosummary::
13 | :toctree: ./generated
14 | :nosignatures:
15 | :template: autosummary/class.rst
16 |
17 | crossover.crossover_pmx
18 | crossover.crossover_one_point
19 | crossover.crossover_blend_new
20 | crossover.crossover_blend
21 | crossover.crossover
22 |
23 |
--------------------------------------------------------------------------------
/docs/pages/api/algorithms_for_tuning.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.algorithms_for_tuning
5 | ============================
6 |
7 | Provides an internal interface for working with optimization algorithms.
8 |
9 | Genetic Algorithm
10 | ---------------------
11 |
12 | .. currentmodule:: autotm.algorithms_for_tuning.genetic_algorithm
13 |
14 | .. autosummary::
15 | :toctree: ./generated
16 | :nosignatures:
17 | :template: autosummary/class.rst
18 |
19 | crossover.crossover_pmx
20 | crossover.crossover_one_point
21 | crossover.crossover_blend
22 | crossover.crossover
23 | mutation.mutation_one_param
24 | mutation.positioning_mutation
25 | mutation.mutation_combined
26 | mutation.mutation_psm
27 | mutation.mutation
28 |
29 | .. .. currentmodule:: autotm.algorithms_for_tuning.bayesian_optimization
--------------------------------------------------------------------------------
/docs/pages/api/fitness.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.fitness
5 | ======================
6 |
7 | Fitness.
8 |
9 | .. currentmodule:: autotm.fitness.tm
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 | :template: autosummary/class.rst
15 |
16 | Dataset
17 | TopicModelFactory
18 | FitnessCalculatorWrapper
19 | calculate_fitness_of_individual
20 | TopicModel
21 |
--------------------------------------------------------------------------------
/docs/pages/api/index.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :maxdepth: 2
3 |
4 | algorithms_for_tuning
5 | fitness
6 | preprocessing
7 |    visualization
--------------------------------------------------------------------------------
/docs/pages/api/preprocessing.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.preprocessing
5 | ======================
6 |
7 | Preprocessing.
8 |
9 | .. currentmodule:: autotm.preprocessing
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 |
15 | dictionaries_preparation.get_words_dict
16 |
--------------------------------------------------------------------------------
/docs/pages/api/visualization.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.visualization
5 | ======================
6 |
7 | Visualization.
8 |
9 | .. currentmodule:: autotm.visualization.dynamic_tracker
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 | :template: autosummary/class.rst
15 |
16 | MetricsCollector
17 |
18 |
--------------------------------------------------------------------------------
/docs/pages/fitness.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.fitness
5 | ======================
6 |
7 | Fitness.
8 |
9 | .. currentmodule:: autotm.fitness.tm
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 | :template: autosummary/class.rst
15 |
16 | Dataset
17 | TopicModelFactory
18 | FitnessCalculatorWrapper
19 | calculate_fitness_of_individual
20 | TopicModel
21 |
--------------------------------------------------------------------------------
/docs/pages/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ==================
3 |
4 |
5 | Pip installation
6 | ----------------
7 |
8 | You can install the `AutoTM` library from PyPI.
9 |
10 | .. code-block:: bash
11 |
12 | pip install autotm
13 |
14 |
15 | Development
16 | -----------
17 |
18 | You can also clone the repository and install it with poetry, for example as sketched below.
19 |
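For example, a minimal sketch assuming a standard poetry setup and the public GitHub repository:

.. code-block:: bash

    git clone https://github.com/aimclub/AutoTM.git
    cd AutoTM
    poetry install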
--------------------------------------------------------------------------------
/docs/pages/preprocessing.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.preprocessing
5 | ======================
6 |
7 | Preprocessing.
8 |
9 | .. currentmodule:: autotm.preprocessing
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 |
15 | dictionaries_preparation.get_words_dict
16 |
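17 | A minimal usage sketch (the paths and parameters follow the example scripts shipped in the repository and may need adjusting):
18 |
19 | .. code-block:: python
20 |
21 |     from autotm.preprocessing.text_preprocessing import process_dataset
22 |     from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts
23 |
24 |     # preprocess the raw corpus (a csv file with a text column) ...
25 |     process_dataset(
26 |         "data/sample_corpora/sample_dataset_lenta.csv",
27 |         "text",
28 |         "data/processed_sample_corpora",
29 |         "ru",
30 |         min_tokens_count=3,
31 |     )
32 |
33 |     # ... and then build dictionaries, cooc values and other artifacts in the same directory
34 |     prepare_all_artifacts("data/processed_sample_corpora")
35 |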
--------------------------------------------------------------------------------
/docs/pages/userguide/index.rst:
--------------------------------------------------------------------------------
1 | .. _python_userguide:
2 |
3 | Python Guide
4 | ============
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | regularizers
10 | metrics
11 |
--------------------------------------------------------------------------------
/docs/pages/userguide/metrics.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/docs/pages/userguide/metrics.rst
--------------------------------------------------------------------------------
/docs/pages/userguide/regularizers.rst:
--------------------------------------------------------------------------------
1 | Regularizers
2 | ================================
3 |
4 | The AutoTM library uses a set of regularizers while optimizing the learning strategy.
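5 |
6 | As an illustration only, a minimal sketch of the kind of BigARTM regularizers whose coefficients such a strategy controls (the regularizer names and tau values below are placeholders, not the strategy found by AutoTM):
7 |
8 | .. code-block:: python
9 |
10 |     import artm
11 |
12 |     # my_dictionary is a placeholder for an artm.Dictionary built from the corpus
13 |     model = artm.ARTM(num_topics=10, dictionary=my_dictionary)
14 |
15 |     # typical regularizers whose coefficients (tau) a learning strategy adjusts over iterations
16 |     model.regularizers.add(artm.SmoothSparsePhiRegularizer(name='sparse_phi', tau=-0.1))
17 |     model.regularizers.add(artm.SmoothSparseThetaRegularizer(name='sparse_theta', tau=-0.5))
18 |     model.regularizers.add(artm.DecorrelatorPhiRegularizer(name='decorrelator_phi', tau=1.5e5))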
--------------------------------------------------------------------------------
/docs/pages/visualization.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | autotm.visualization
5 | ======================
6 |
7 | Visualization.
8 |
9 | .. currentmodule:: autotm.visualization.dynamic_tracker
10 |
11 | .. autosummary::
12 | :toctree: ./generated
13 | :nosignatures:
14 | :template: autosummary/class.rst
15 |
16 | MetricsCollector
17 |
18 |
--------------------------------------------------------------------------------
/examples/examples_autotm_fit_predict.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import os
4 | import uuid
5 | from typing import Dict, Any, Optional
6 |
7 | import pandas as pd
8 | from sklearn.model_selection import train_test_split
9 |
10 | from autotm.base import AutoTM
11 |
12 | logging.basicConfig(
13 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
14 | level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S'
15 | )
16 | logger = logging.getLogger()
17 |
18 | # "base" configuration is an example of default configuration of the algorithm
19 | # "base_en" is showing how to specify dataset parameters and pass it directly to the algorithm
20 | # "static_chromosome" is a previous version of optimization algorithm, may work faster comparing to the "pipeline" version, but results are slightly worse
21 | # "surrogate" is showing how to start training with surrogates option (increase optimization speed)
22 | # "llm" is an example of how to run optimization with gpt-base quality function
23 | # "bayes" is an option with Bayesian Optimization
24 | CONFIGURATIONS = {
25 | "base": {
26 | "alg_name": "ga",
27 | "num_iterations": 50,
28 | "use_pipeline": True
29 | },
30 | "base_en": {
31 | "alg_name": "ga",
32 | "dataset": {
33 | "lang": "en",
34 | "dataset_path": "data/sample_corpora/imdb_100.csv",
35 | "dataset_name": "imdb_100"
36 | },
37 | "num_iterations": 20,
38 | "use_pipeline": True
39 | },
40 | "static_chromosome": {
41 | "alg_name": "ga",
42 | "num_iterations": 20,
43 | "use_pipeline": False
44 | },
45 | "gost_example": {
46 | "topic_count": 50,
47 | "alg_name": "ga",
48 | "num_iterations": 10,
49 | "use_pipeline": True,
50 | "dataset": {
51 | "lang": "ru",
52 | "col_to_process": 'paragraph',
53 | "dataset_path": "data/sample_corpora/clean_docs_v17_gost_only.csv",
54 | "dataset_name": "gost"
55 | },
56 | # "individual_type": "llm"
57 | },
58 | "surrogate": {
59 | "alg_name": "ga",
60 | "num_iterations": 20,
61 | "use_pipeline": True,
62 | "surrogate_name": "random-forest-regressor"
63 | },
64 | "llm": {
65 | "alg_name": "ga",
66 | "num_iterations": 20,
67 | "use_pipeline": True,
68 | "individual_type": "llm"
69 | },
70 | "bayes": {
71 | "alg_name": "bayes",
72 | "num_evaluations": 150,
73 | }
74 | }
75 |
76 |
77 | def run(alg_name: str,
78 | alg_params: Dict[str, Any],
79 | dataset: Optional[Dict[str, Any]] = None,
80 | col_to_process: Optional[str] = None,
81 | topic_count: int = 20):
82 | if not dataset:
83 | dataset = {
84 | "lang": "ru",
85 | "dataset_path": "data/sample_corpora/sample_dataset_lenta.csv",
86 | "dataset_name": "lenta_ru"
87 | }
88 |
89 | df = pd.read_csv(dataset['dataset_path'])
90 | train_df, test_df = train_test_split(df, test_size=0.1)
91 |
92 | working_dir_path = f"./autotm_workdir_{uuid.uuid4()}"
93 | model_path = os.path.join(working_dir_path, "autotm_model")
94 |
95 | autotm = AutoTM(
96 | topic_count=topic_count,
97 | preprocessing_params={
98 | "lang": dataset['lang']
99 | },
100 | texts_column_name=col_to_process,
101 | alg_name=alg_name,
102 | alg_params=alg_params,
103 | working_dir_path=working_dir_path,
104 | exp_dataset_name=dataset["dataset_name"]
105 | )
106 | mixtures = autotm.fit_predict(train_df)
107 |
108 | logger.info(f"Calculated train mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}")
109 |
110 | # saving the model
111 | autotm.save(model_path)
112 |
113 | # loading and checking if everything is fine with predicting
114 | autotm_loaded = AutoTM.load(model_path)
115 | mixtures = autotm_loaded.predict(test_df)
116 |
117 | logger.info(f"Calculated test mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}")
118 |
119 |
120 | def main(conf_name: str = "base"):
121 | if conf_name not in CONFIGURATIONS:
122 | raise ValueError(
123 | f"Unknown configuration {conf_name}. Available configurations: {sorted(CONFIGURATIONS.keys())}"
124 | )
125 |
126 | conf = CONFIGURATIONS[conf_name]
127 | alg_name = conf['alg_name']
128 | del conf['alg_name']
129 |
130 | dataset = None
131 | col_to_process = None
132 | if 'dataset' in conf:
133 | dataset = conf['dataset']
134 | col_to_process = conf['dataset'].get('col_to_process', None)
135 | del conf['dataset']
136 |
137 | topic_count = 20
138 | if 'topic_count' in conf:
139 | topic_count = conf['topic_count']
140 | del conf['topic_count']
141 |
142 | run(
143 | alg_name=alg_name,
144 | alg_params=conf,
145 | dataset=dataset,
146 | col_to_process=col_to_process,
147 | topic_count=topic_count
148 | )
149 |
150 |
151 | if __name__ == "__main__":
152 | main(conf_name="gost_example")
153 |
--------------------------------------------------------------------------------
/examples/graph_building_for_stroitelstvo.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import os
4 | import uuid
5 | import pandas as pd
6 | import numpy as np
7 | from datasets import load_dataset
8 | from sklearn.model_selection import train_test_split
9 | from autotm.base import AutoTM
10 | from autotm.ontology.ontology_extractor import build_graph
11 | import networkx as nx
12 |
13 |
14 | df = pd.read_csv('../data/sample_corpora/dataset_books_stroitelstvo.csv')
15 |
16 | working_dir_path = 'autotm_artifacts'
17 |
18 | autotm = AutoTM(
19 | topic_count=10,
20 | texts_column_name='text',
21 | preprocessing_params={
22 | "lang": "ru",
23 | },
24 | alg_params={
25 | "num_iterations": 10,
26 | },
27 | working_dir_path=working_dir_path
28 | )
29 |
30 | mixtures = autotm.fit_predict(df)
31 |
32 | # let's look at the resulting topics
33 | autotm.print_topics()
34 |
35 | # Build a dictionary of topic names! Note that a new run may produce different topic mixtures
36 | labels_dict = {'main0':'Кровля',
37 | 'main1': 'Фундамент',
38 | 'main2':'Техника строительства',
39 | 'main3': 'Электросбережение',
40 | 'main4':'Генплан',
41 | 'main5': 'Участок',
42 | 'main6': 'Лестница',
43 | 'main7':'Внешняя отделка',
44 | 'main8': 'Стены',
45 | 'main9': 'Отделка (покраска)',
46 | }
47 |
48 | res_dict, nodes = build_graph(autotm, topic_labels=labels_dict)
49 |
50 | # visualize the graph
51 | g = nx.DiGraph()
52 | g.add_nodes_from(nodes)
53 | for k, v in res_dict.items():
54 | # res_dict is assumed to map each topic to the nodes connected to it
55 | g.add_edges_from([(k, item) for item in v])
56 | nx.draw(g, with_labels=True)
--------------------------------------------------------------------------------
/examples/graph_profile_example.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import os
4 | import uuid
5 | import pandas as pd
6 | import numpy as np
7 | from datasets import load_dataset
8 | from sklearn.model_selection import train_test_split
9 | from autotm.base import AutoTM
10 | from autotm.ontology.ontology_extractor import build_graph
11 | from autotm.clustering import cluster_phi
12 | import networkx as nx
13 |
14 |
15 | df = load_dataset('zloelias/lenta-ru') # https://huggingface.co/datasets
16 | text_sample = df['train']['text']
17 | df = pd.DataFrame({'text': text_sample})
18 |
19 | working_dir_path = 'autotm_artifacts'
20 |
21 | # Initialize the model and specify the name
22 | # of the surrogate model to speed up computations
23 | autotm = AutoTM(
24 | topic_count=20,
25 | texts_column_name='text',
26 | preprocessing_params={
27 | "lang": "ru",
28 | },
29 | # alg_name='ga',
30 | alg_params={
31 | "num_iterations": 15,
32 | "surrogate_alg_name": "GPR"
33 | },
34 | working_dir_path=working_dir_path
35 | )
36 |
37 | mixtures = autotm.fit_predict(df)
38 |
39 | # extract the phi matrix - the distribution of words over topics
40 | phi_df = autotm._model.get_phi()
41 |
42 | # cluster the obtained data; the plot shows that the clusters separate reasonably well
43 | # (cluster_phi is assumed here to return the fitted clustering model, e.g. a KMeans instance)
44 | y_kmeans = cluster_phi(phi_df, n_clusters=10, plot_img=True)
45 |
46 | # inspect which documents fall into the clusters
47 | mixtures['labels'] = y_kmeans.labels_
48 | res_df = df.join(mixtures)
49 | print(res_df[res_df['labels'] == 8].text.tolist()[:3])
50 |
51 | # now build the graph; as we already know,
52 | # for better interpretability it is best to prepare
53 | # a dictionary of topic names
54 | autotm.print_topics()
55 | labels_dict = {'main0':'Устройства',
56 | 'main1': 'Спорт',
57 | 'main2':'Экономика',
58 | 'main3': 'Чемпионат',
59 | 'main4':'Исследование',
60 | 'main5': 'Награждение',
61 | 'main6': 'Суд',
62 | 'main7':'Общее',
63 | 'main8': 'Авиаперелеты',
64 | 'main9': 'Музеи',
65 | 'main10': 'Правительство',
66 | 'main11': 'Интернет',
67 | 'main12': 'Искусство',
68 | 'main13': 'Война',
69 | 'main14': 'Нефть',
70 | 'main15': 'Космос',
71 | 'main16': 'Соревнования',
72 | 'main17': 'Биржа',
73 | 'main18': 'Финансы',
74 | 'main19': 'Концерт'
75 | }
76 |
77 | res_dict, nodes = build_graph(autotm, topic_labels=labels_dict)
78 |
79 | # look at the processing result - a dictionary with the connections
80 | print(res_dict)
81 |
82 | # Visualize the resulting graph
83 | g = nx.DiGraph()
84 | g.add_nodes_from(nodes)
85 | for k, v in res_dict.items():
86 | g.add_edges_from([(k, item) for item in v])
87 | nx.draw(g, with_labels=True)
88 |
89 | # Expand one of the nodes
90 | mixtures_subset = mixtures.sort_values('main0', ascending=False).head(1000)
91 |
92 | subset_df = df.join(mixtures_subset)
93 |
94 | autotm_subs = AutoTM(
95 | topic_count=6,
96 | texts_column_name='text',
97 | preprocessing_params={
98 | "lang": "ru",
99 | },
100 | alg_params={
101 | "num_iterations": 10,
102 | "surrogate_alg_name": "GPR"
103 | },
104 | working_dir_path=working_dir_path
105 | )
106 |
107 | autotm_subs.fit_predict(subset_df)
108 | autotm_subs.print_topics()  # fit first, then inspect the topics before naming them
109 | subs_labels = {
110 | 'main0': 'Технологии',
111 | 'main1': 'Разработка',
112 | 'main2': 'Игры',
113 | 'main3': 'Компьютеры',
114 | 'main4': 'Мобильные',
115 | 'main5': 'Периферия'
116 | }
117 |
118 | res_dict_subs, nodes_subs = build_graph(autotm_subs, topic_labels=subs_labels)
119 |
120 | g = nx.DiGraph()
121 | g.add_nodes_from(nodes_subs)
122 | for k, v in res_dict_subs.items():
123 | g.add_edges_from([(k, item) for item in v])
124 | nx.draw(g, with_labels=True)
--------------------------------------------------------------------------------
/examples/topic_modeling_of_corporative_data.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import os
4 | import uuid
5 | from typing import Dict, Any, Optional
6 |
7 | import pandas as pd
8 | from sklearn.model_selection import train_test_split
9 |
10 | from autotm.base import AutoTM
11 |
12 | logging.basicConfig(
13 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
14 | level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S'
15 | )
16 | logger = logging.getLogger()
17 |
18 |
19 | def main():
20 | df = pd.read_csv('data/sample_corpora/clean_docs_v17_gost_only.csv')
21 | train_df, test_df = train_test_split(df, test_size=0.1)
22 |
23 | working_dir_path = f"./autotm_workdir_{uuid.uuid4()}"
24 | model_path = os.path.join(working_dir_path, "autotm_model")
25 |
26 | autotm = AutoTM(
27 | topic_count=50,
28 | texts_column_name='paragraph',
29 | preprocessing_params={
30 | "lang": "ru"
31 | },
32 | alg_name="ga",
33 | alg_params={
34 | "num_iterations": 10,
35 | "use_pipeline": True
36 | },
37 | individual_type="llm",
38 | working_dir_path=working_dir_path,
39 | exp_dataset_name="gost"
40 | )
41 | mixtures = autotm.fit_predict(train_df)
42 |
43 | logger.info(f"Calculated train mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}")
44 |
45 | # saving the model
46 | autotm.save(model_path)
47 |
48 | # loading and checking if everything is fine with predicting
49 | autotm_loaded = AutoTM.load(model_path)
50 | mixtures = autotm_loaded.predict(test_df)
51 |
52 | logger.info(f"Calculated test mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}")
53 |
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/logging.config.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | console_log:
3 | level: DEBUG
4 | formatters:
5 | simple:
6 | class: logging.Formatter
7 | format: "%(asctime)s %(name)s %(levelname)s %(message)s"
8 | datefmt: "%Y-%m-%d %H:%M:%S"
9 | handlers:
10 | file_handler:
11 | class: logging.FileHandler
12 | filename: example.log
13 | level: DEBUG
14 | formatter: simple
15 | stream_handler:
16 | class: logging.StreamHandler
17 | stream: ext://sys.stderr
18 | level: cfg://console_log.level
19 | formatter: simple
20 | loggers:
21 | logging_example:
22 | level: DEBUG
23 | handlers: [file_handler]
24 | propagate: yes
25 | root:
26 | level: DEBUG
27 | handlers: [stream_handler]
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "autotm"
3 | version = "0.2.3.4"
4 | description = "Automatic hyperparameters tuning for topic models (ARTM approach) using evolutionary algorithms"
5 | authors = [
6 | "Khodorchenko Maria ",
7 | "Nikolay Butakov alipoov.nb@gmail.com"
8 | ]
9 | readme = "README.md"
10 | license = "Apache-2.0"
11 | homepage = "https://autotm.readthedocs.io/en/latest/"
12 | repository = "https://github.com/ngc436/AutoTM"
13 | packages = [
14 | {include = "autotm"}
15 | ]
16 |
17 |
18 | [tool.poetry.dependencies]
19 | python = ">=3.9, <3.12"
20 | bigartm = "0.9.2"
21 | protobuf = "<=3.20.0"
22 | tqdm = "4.50.2"
23 | numpy = "*"
24 | PyYAML = "5.3.1"
25 | dataclasses-json = "*"
26 | mlflow = "*"
27 | click = "8.0.1"
28 | scipy = "<=1.10.1"
29 | hyperopt = "*"
30 | pymystem3 = "*"
31 | nltk = "*"
32 | plotly = "*"
33 | spacy = ">=3.5"
34 | spacy-langdetect = "*"
35 | gensim = "^4.1.2"
36 | pandas = "*"
37 | billiard = "*"
38 | dill = "*"
39 | pytest = "*"
40 | celery = ">=4.4.7"
41 | redis = "3.5.3"
42 | jinja2 = "3.0"
43 | PyMySQL = "*"
44 | psycopg2-binary = "*"
45 | pymongo = "3.11.3"
46 | scikit-learn = "^1.1.1"
47 | #pydantic = "1.10.8"
48 | pydantic = "2.6.0"
49 | openai = "^1.31.0"
50 | seaborn = "^0.13.2"
51 |
52 | [tool.poetry.dev-dependencies]
53 | black = "*"
54 | sphinx = "*"
55 | flake8 = "*"
56 |
57 | [tool.pytest.ini_options]
58 | log_cli = true
59 | log_cli_level = "INFO"
60 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
61 | log_cli_date_format = "%Y-%m-%d %H:%M:%S"
62 |
63 | [tool.poetry.scripts]
64 | autotmctl = 'autotm.main:cli'
65 | fitness-worker = 'autotm.main_fitness_worker:main'
66 |
67 | [build-system]
68 | requires = ["poetry-core"]
69 | build-backend = "poetry.core.masonry.api"
70 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/algs/full_pipeline_example.py:
--------------------------------------------------------------------------------
1 | # input: example data, a pandas df with a 'text' column
2 | import logging
3 | import time
4 |
5 | import pandas as pd
6 |
7 | from autotm.algorithms_for_tuning.genetic_algorithm.genetic_algorithm import (
8 | run_algorithm,
9 | )
10 | from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts
11 | from autotm.preprocessing.text_preprocessing import process_dataset
12 |
13 |
14 | logging.basicConfig(level=logging.DEBUG)
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | PATH_TO_DATASET = "../../data/sample_corpora/sample_dataset_lenta.csv" # dataset with corpora to be processed
19 | SAVE_PATH = (
20 | "../data/processed_sample_corpora" # place where all the artifacts will be stored
21 | )
22 |
23 | dataset = pd.read_csv(PATH_TO_DATASET)
24 | col_to_process = "text"
25 | dataset_name = "lenta_sample"
26 | lang = "ru" # available languages: ru, en
27 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result
28 | num_iterations = 2
29 | topic_count = 10
30 | exp_id = int(time.time())
31 | print(exp_id)
32 |
33 | use_nelder_mead_in_mutation = False
34 | use_nelder_mead_in_crossover = False
35 | use_nelder_mead_in_selector = False
36 | train_option = "offline"
37 |
38 | if __name__ == "__main__":
39 | logger.info("Stage 1: Dataset preparation")
40 | process_dataset(
41 | PATH_TO_DATASET,
42 | col_to_process,
43 | SAVE_PATH,
44 | lang,
45 | min_tokens_count=min_tokens_num,
46 | )
47 |
48 | logger.info("Stage 2: Prepare all artefacts")
49 | prepare_all_artifacts(SAVE_PATH)
50 |
51 | logger.info("Stage 3: Tuning the topic model")
52 | # exp_id and dataset_name will be needed further to store results in mlflow
53 | best_result = run_algorithm(
54 | data_path=SAVE_PATH,
55 | dataset=dataset_name,
56 | exp_id=exp_id,
57 | topic_count=topic_count,
58 | log_file="./log_file_test.txt",
59 | num_iterations=num_iterations,
60 | use_nelder_mead_in_mutation=use_nelder_mead_in_mutation,
61 | use_nelder_mead_in_crossover=use_nelder_mead_in_crossover,
62 | use_nelder_mead_in_selector=use_nelder_mead_in_selector,
63 | train_option=train_option,
64 | )
65 | logger.info("All finished")
66 |
--------------------------------------------------------------------------------
/scripts/algs/nelder_mead_experiments.py:
--------------------------------------------------------------------------------
1 | from autotm.algorithms_for_tuning.nelder_mead_optimization.nelder_mead import (
2 | NelderMeadOptimization,
3 | )
4 | import time
5 | import pandas as pd
6 | import os
7 |
8 | # 2
9 |
10 | fnames = [
11 | "20newsgroups_sample",
12 | "amazon_food_sample",
13 | "banners_sample",
14 | "hotel-reviews_sample",
15 | "lenta_ru_sample",
16 | ]
17 | data_lang = ["en", "en", "ru", "en", "ru"]
18 | dataset_id = 0
19 |
20 | DATA_PATH = os.path.join("/ess_data/GOTM/datasets_TM_scoring", fnames[dataset_id])
21 |
22 | PATH_TO_DATASET = os.path.join(
23 | DATA_PATH, "dataset_processed.csv"
24 | ) # dataset with corpora to be processed
25 | SAVE_PATH = DATA_PATH # place where all the artifacts will be stored
26 |
27 | dataset = pd.read_csv(PATH_TO_DATASET)
28 | dataset_name = fnames[dataset_id] + "sample_with_nm"
29 | lang = data_lang[dataset_id] # available languages: ru, en
30 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result
31 | num_iterations = 200
32 | topic_count = 10
33 | exp_id = int(time.time())
34 | print(exp_id)
35 | train_option = "offline"
36 |
37 | if __name__ == "__main__":
38 | nelder_opt = NelderMeadOptimization(
39 | data_path=SAVE_PATH,
40 | dataset=dataset_name,
41 | exp_id=exp_id,
42 | topic_count=topic_count,
43 | train_option=train_option,
44 | )
45 |
46 | res = nelder_opt.run_algorithm(num_iterations=num_iterations)
47 | print(res)
48 | print(-res.fun)
49 |
--------------------------------------------------------------------------------
/scripts/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/__init__.py
--------------------------------------------------------------------------------
/scripts/experiments/plot_boxplot_Final_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_boxplot_Final_results.png
--------------------------------------------------------------------------------
/scripts/experiments/plot_boxplot_Start_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_boxplot_Start_results.png
--------------------------------------------------------------------------------
/scripts/experiments/plot_progress_hotel-reviews_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/plot_progress_hotel-reviews_sample.png
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124458_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.6016707242187753
2 | hotel-reviews_sample,False,5,0.6016707242187753
3 | hotel-reviews_sample,False,6,0.6016707242187753
4 | hotel-reviews_sample,False,7,0.6016707242187753
5 | hotel-reviews_sample,False,8,0.6016707242187753
6 | hotel-reviews_sample,False,9,0.6016707242187753
7 | hotel-reviews_sample,False,10,0.6222956395585306
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124539_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.6364207556768978
2 | hotel-reviews_sample,False,5,0.6364207556768978
3 | hotel-reviews_sample,False,6,0.6364207556768978
4 | hotel-reviews_sample,False,7,0.6364207556768978
5 | hotel-reviews_sample,False,8,0.6364207556768978
6 | hotel-reviews_sample,False,9,0.6364207556768978
7 | hotel-reviews_sample,False,10,0.6364207556768978
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124633_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.6364207556768978
2 | hotel-reviews_sample,False,5,0.6364207556768978
3 | hotel-reviews_sample,False,6,0.6364207556768978
4 | hotel-reviews_sample,False,7,0.6364207556768978
5 | hotel-reviews_sample,False,8,0.6364207556768978
6 | hotel-reviews_sample,False,9,0.6364207556768978
7 | hotel-reviews_sample,False,10,0.6364207556768978
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124719_hotel-reviews_sample_pipeline_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,True,4,0.5368079177080817
2 | hotel-reviews_sample,True,5,0.5368079177080817
3 | hotel-reviews_sample,True,6,0.5368079177080817
4 | hotel-reviews_sample,True,7,0.5368079177080817
5 | hotel-reviews_sample,True,8,0.5368079177080817
6 | hotel-reviews_sample,True,9,0.5368079177080817
7 | hotel-reviews_sample,True,10,0.8378543293795918
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124816_hotel-reviews_sample_pipeline_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,True,4,0.7293230550311838
2 | hotel-reviews_sample,True,5,0.7293230550311838
3 | hotel-reviews_sample,True,6,0.7293230550311838
4 | hotel-reviews_sample,True,7,0.7293230550311838
5 | hotel-reviews_sample,True,8,0.7293230550311838
6 | hotel-reviews_sample,True,9,0.7293230550311838
7 | hotel-reviews_sample,True,10,0.7388121081634285
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240602-124950_hotel-reviews_sample_pipeline_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,True,4,0.6729334700766527
2 | hotel-reviews_sample,True,5,0.6729334700766527
3 | hotel-reviews_sample,True,6,0.6729334700766527
4 | hotel-reviews_sample,True,7,0.6729334700766527
5 | hotel-reviews_sample,True,8,0.6729334700766527
6 | hotel-reviews_sample,True,9,0.6729334700766527
7 | hotel-reviews_sample,True,10,0.6729334700766527
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-231051_hotel-reviews_sample_fixed_parameters.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/statistics/240603-231051_hotel-reviews_sample_fixed_parameters.txt
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-231154_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.6783270373502042
2 | hotel-reviews_sample,False,5,0.6783270373502042
3 | hotel-reviews_sample,False,6,0.6783270373502042
4 | hotel-reviews_sample,False,7,0.6783270373502042
5 | hotel-reviews_sample,False,8,0.682302624412245
6 | hotel-reviews_sample,False,9,0.682302624412245
7 | hotel-reviews_sample,False,10,0.682302624412245
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-231313_hotel-reviews_sample_pipeline_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,True,4,0.7047850565368978
2 | hotel-reviews_sample,True,5,0.7047850565368978
3 | hotel-reviews_sample,True,6,0.7047850565368978
4 | hotel-reviews_sample,True,7,0.7047850565368978
5 | hotel-reviews_sample,True,8,0.7258411930803266
6 | hotel-reviews_sample,True,9,0.7258411930803266
7 | hotel-reviews_sample,True,10,0.7258411930803266
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-232451_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.652324503842612
2 | hotel-reviews_sample,False,5,0.652324503842612
3 | hotel-reviews_sample,False,6,0.652324503842612
4 | hotel-reviews_sample,False,7,0.652324503842612
5 | hotel-reviews_sample,False,8,0.652324503842612
6 | hotel-reviews_sample,False,9,0.652324503842612
7 | hotel-reviews_sample,False,10,0.6837248912219591
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-232546_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.7727277367346939
2 | hotel-reviews_sample,False,5,0.7727277367346939
3 | hotel-reviews_sample,False,6,0.7727277367346939
4 | hotel-reviews_sample,False,7,0.7727277367346939
5 | hotel-reviews_sample,False,8,0.7727277367346939
6 | hotel-reviews_sample,False,9,0.7727277367346939
7 | hotel-reviews_sample,False,10,0.7727277367346939
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-232733_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.8807936593885717
2 | hotel-reviews_sample,False,5,0.8807936593885717
3 | hotel-reviews_sample,False,6,0.8807936593885717
4 | hotel-reviews_sample,False,7,0.8807936593885717
5 | hotel-reviews_sample,False,8,0.8807936593885717
6 | hotel-reviews_sample,False,9,0.8807936593885717
7 | hotel-reviews_sample,False,10,0.8807936593885717
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-232847_hotel-reviews_sample_pipeline_parameters.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/scripts/experiments/statistics/240603-232847_hotel-reviews_sample_pipeline_parameters.txt
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-232916_hotel-reviews_sample_fixed_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,False,4,0.8121951826857141
2 | hotel-reviews_sample,False,5,0.8121951826857141
3 | hotel-reviews_sample,False,6,0.8121951826857141
4 | hotel-reviews_sample,False,7,0.8121951826857141
5 | hotel-reviews_sample,False,8,0.8121951826857141
6 | hotel-reviews_sample,False,9,0.8121951826857141
7 | hotel-reviews_sample,False,10,0.8121951826857141
8 |
--------------------------------------------------------------------------------
/scripts/experiments/statistics/240603-233132_hotel-reviews_sample_pipeline_progress.txt:
--------------------------------------------------------------------------------
1 | hotel-reviews_sample,True,4,0.6630215941235102
2 | hotel-reviews_sample,True,5,0.6630215941235102
3 | hotel-reviews_sample,True,6,0.6630215941235102
4 | hotel-reviews_sample,True,7,0.6630215941235102
5 | hotel-reviews_sample,True,8,0.6630215941235102
6 | hotel-reviews_sample,True,9,0.6995545635590203
7 | hotel-reviews_sample,True,10,0.6995545635590203
8 |
--------------------------------------------------------------------------------
/scripts/other/preparation_pipeline.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os.path
3 |
4 | from autotm.preprocessing import PREPOCESSED_DATASET_FILENAME
5 | from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts, prepearing_cooc_dict
6 | from autotm.preprocessing.text_preprocessing import process_dataset
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | PATH_TO_DATASET = "../../data/sample_corpora/sample_dataset_lenta.csv" # dataset with corpora to be processed
12 | SAVE_PATH = "../../data/processed_sample_corpora" # place where all the artifacts will be stored
13 |
14 | col_to_process = "text"
15 | dataset_name = "lenta_sample"
16 | lang = "ru" # available languages: ru, en
17 | min_tokens_num = 3 # the minimal amount of tokens after processing to save the result
18 |
19 |
20 | if __name__ == "__main__":
21 | logger.info("Stage 1: Dataset preparation")
22 | if not os.path.exists(SAVE_PATH):
23 | process_dataset(
24 | PATH_TO_DATASET,
25 | col_to_process,
26 | SAVE_PATH,
27 | lang,
28 | min_tokens_count=min_tokens_num,
29 | )
30 | else:
31 | logger.info("The preprocessed dataset already exists. Found files on path: %s" % SAVE_PATH)
32 |
33 | logger.info("Stage 2: Prepare all artefacts")
34 | prepare_all_artifacts(SAVE_PATH)
35 |
36 | logger.info("All finished")
37 |
--------------------------------------------------------------------------------
/scripts/topic_modeling_of_corporative_data.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import os
4 | import uuid
5 | import pandas as pd
6 | import numpy as np
7 | from datasets import load_dataset
8 | from sklearn.model_selection import train_test_split
9 | from autotm.base import AutoTM
10 | from autotm.ontology.ontology_extractor import build_graph
11 | import networkx as nx
12 |
13 |
14 | df = pd.read_csv('../data/sample_corpora/clean_docs_v17_gost_only.csv')
15 |
16 | working_dir_path = 'autotm_artifacts'
17 |
18 | autotm = AutoTM(
19 | topic_count=50,
20 | texts_column_name='paragraph',
21 | preprocessing_params={
22 | "lang": "ru",
23 | },
24 | alg_params={
25 | "num_iterations": 10,
26 | },
27 | working_dir_path=working_dir_path
28 | )
29 |
30 | mixtures = autotm.fit_predict(df)
31 |
32 | # let's look at the resulting topics
33 | autotm.print_topics()
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/tests/integration/__init__.py
--------------------------------------------------------------------------------
/tests/integration/test_fit_predict.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import tempfile
4 |
5 | import pandas as pd
6 | import pytest
7 | from numpy.typing import ArrayLike
8 | from sklearn.model_selection import train_test_split
9 |
10 | from autotm.base import AutoTM
11 |
12 | logging.basicConfig(level=logging.INFO)
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def check_predictions(autotm: AutoTM, df: pd.DataFrame, mixtures: ArrayLike):
17 | n_samples, n_samples_mixture = df.shape[0], mixtures.shape[0]
18 | n_topics, n_topics_mixture = len(autotm.topics), mixtures.shape[1]
19 |
20 | assert n_samples_mixture == n_samples
21 | assert n_topics_mixture >= n_topics
22 | assert (~mixtures.isna()).all().all()
23 | assert (~mixtures.isnull()).all().all()
24 |
25 |
26 | # two calls to AutoTM in the same process are not supported due to problems with ARTM deadlocks
27 | @pytest.mark.parametrize('lang,dataset_path', [
28 | ('en', 'data/sample_corpora/imdb_100.csv'),
29 | # ('ru', 'data/sample_corpora/sample_dataset_lenta.csv'),
30 | ], ids=['imdb_100'])
31 | # ], ids=['imdb_100', 'lenta_ru'])
32 | def test_fit_predict(pytestconfig, lang, dataset_path):
33 | # dataset with corpora to be processed
34 | path_to_dataset = os.path.join(pytestconfig.rootpath, dataset_path)
35 | alg_name = "ga"
36 |
37 | df = pd.read_csv(path_to_dataset)
38 | train_df, test_df = train_test_split(df, test_size=0.1)
39 |
40 | with tempfile.TemporaryDirectory(prefix="fp_tmp_working_dir_") as tmp_working_dir:
41 | model_path = os.path.join(tmp_working_dir, "autotm_model")
42 |
43 | autotm = AutoTM(
44 | preprocessing_params={
45 | "lang": lang
46 | },
47 | alg_name=alg_name,
48 | alg_params={
49 | "num_iterations": 2,
50 | "num_individuals": 4,
51 | },
52 | working_dir_path=tmp_working_dir
53 | )
54 | mixtures = autotm.fit_predict(train_df)
55 | check_predictions(autotm, train_df, mixtures)
56 |
57 | # saving the model
58 | autotm.save(model_path)
59 |
60 | # loading and checking if everything is fine with predicting
61 | autotm_loaded = AutoTM.load(model_path)
62 | mixtures = autotm_loaded.predict(test_df)
63 | check_predictions(autotm, test_df, mixtures)
64 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/AutoTM/6c2339afd52101a8cf01ea260ef124282d3b17d9/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/conftest.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os.path
3 | from typing import Dict
4 |
5 | import pytest
6 |
7 | from autotm.fitness.tm import ENV_AUTOTM_LLM_API_KEY
8 |
9 |
10 | def parse_vw(path: str) -> Dict[str, Dict[str, float]]:
11 | result = dict()
12 |
13 | with open(path, "r") as f:
14 | for line in f.readlines():
15 | elements = line.split(" ")
16 | word, tokens = elements[0], elements[1:]
17 | if word in result:
18 | raise ValueError("The word is repeated")
19 | result[word] = {token.split(':')[0]: float(token.split(':')[1]) for token in tokens}
20 |
21 | pairs = ((min(word_1, word_2), max(word_1, word_2), value) for word_1, pairs in result.items() for word_2, value in pairs.items())
22 | pairs = sorted(pairs, key=lambda x: x[0])
23 | gpairs = itertools.groupby(pairs, key=lambda x: x[0])
24 | ordered_result = {word_1: {word_2: value for _, word_2, value in pps} for word_1, pps in gpairs}
25 | return ordered_result
26 |
27 |
28 | @pytest.fixture(scope="session")
29 | def test_corpora_path(pytestconfig: pytest.Config) -> str:
30 | return os.path.join(pytestconfig.rootpath, "data", "processed_lenta_ru_sample_corpora")
31 |
32 | @pytest.fixture(scope="session")
33 | def openai_api_key() -> str:
34 | if ENV_AUTOTM_LLM_API_KEY not in os.environ:
35 | raise ValueError(f"Env var {ENV_AUTOTM_LLM_API_KEY} with openai API key is not set")
36 | return os.environ[ENV_AUTOTM_LLM_API_KEY]
37 |
--------------------------------------------------------------------------------
/tests/unit/test_dictionaries_preparation.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from autotm.preprocessing.dictionaries_preparation import (
5 | _add_word_to_dict)
6 |
7 | DATASET_PROCESSED_TINY = pd.DataFrame(
8 | {
9 | "processed_text": [
10 | "this is text for testing purposes",
11 | "test the testing test to test the test",
12 | "the text is a good example of",
13 | ]
14 | }
15 | )
16 |
17 |
18 | @pytest.fixture()
19 | def tiny_dataset_(tmpdir):
20 | dataset_processed = tmpdir.join("dataset.txt")
21 | dataset_processed.write(DATASET_PROCESSED_TINY)
22 | return dataset_processed
23 |
24 |
25 | def test__add_word_to_dict():
26 | test_dict = {}
27 | assert _add_word_to_dict('test', test_dict) == {'test': 1}
28 |
--------------------------------------------------------------------------------
/tests/unit/test_llm_fitness.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from autotm.fitness.tm import estimate_topics_with_llm, ENV_AUTOTM_LLM_API_KEY
6 |
7 |
8 | @pytest.mark.skipif(ENV_AUTOTM_LLM_API_KEY not in os.environ, reason=F"ChatGPT API key is not available. {ENV_AUTOTM_LLM_API_KEY} is not set.")
9 | def test_llm_fitness_estimation(openai_api_key: str):
10 | topics = {
11 | "main0": 'entry launch remark rule satellite european build contest player test francis given author canadian cipher',
12 | "main1": 'engineer newspaper holland chapter staff douglas princeton tempest colostate senior executive jose nixon phoenix chemical',
13 | "main3": 'religion atheist belief religious atheism faith christianity existence strong universe follow theist become accept statement',
14 | "main4": 'population bullet arab israeli border village muslim thousand slaughter policy daughter wife authorities switzerland religious',
15 | "main6": 'unix directory comp package library email linux workstation graphics vendor editor user hardware export product',
16 | "main7": 'woman building city left child home face helmet apartment kill wife azerbaijani live father later',
17 | "main10": 'attack lebanese muslim hernlem israeli left troops peace fire away quite stop religion israel especially',
18 | "main11": 'science holland cult compass study tempest investigation methodology nixon psychology department star left colostate scientific',
19 | "main12": 'создавать мужчина премьер добавлять причина оставаться клуб александр сергей закон идти комитет безопасность национальный предлагать',
20 | "main20": 'победа россиянин чемпион американец ассоциация встречаться завершаться килограмм побеждать карьера поражение состояться всемирный категория боец одерживать поедино суметь соперник проигрывать',
21 | "back0": 'garbage garbage garbage',
22 | "back1": 'trash trash trash'
23 | }
24 | topics = {k: v.split(' ') for k, v in topics.items()}
25 |
26 | fitness = estimate_topics_with_llm(
27 | model_id="test_model",
28 | topics=topics,
29 | api_key=openai_api_key,
30 | max_estimated_topics=4,
31 | estimations_per_topic=3
32 | )
33 |
34 | print(f"Fitness: {fitness}")
35 |
36 | assert fitness > 0
37 |
--------------------------------------------------------------------------------
/tests/unit/test_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from autotm.preprocessing.text_preprocessing import (remove_html,
4 | process_punkt,
5 | get_lemma, lemmatize_text_ru)
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "input,expected_output",
10 | [
11 | (
12 | '<p>Text size using points.</p>'  # the HTML tags here are illustrative placeholders
13 | '<p>Text size using pixels.</p><p>Text size using relative sizes.</p>',
14 | 'Text size using points.Text size using pixels.Text size using relative sizes.'
15 | ),
16 | (
17 | '', ''
18 | )
19 | ]
20 | )
21 | def test_html_removal(input, expected_output):
22 | """Test html is removed from text"""
23 | assert remove_html(input) == expected_output
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "input,expected_output",
28 | [
29 | (
30 | " lots of space ", "lots of space"
31 | ),
32 | (
33 | "No#, punctuation? id4321 removed", "No punctuation removed"
34 | )
35 | ]
36 | )
37 | def test_process_punct(input, expected_output):
38 | """Test punctuation removal is correct"""
39 | assert process_punkt(input) == expected_output
40 |
41 |
42 | @pytest.mark.parametrize(
43 | "input,expected_output",
44 | [
45 | ("id19898", "id19898"),
46 | ("нормальные", "нормальные"),
47 | ("dogs", "dog"),
48 | ("toughest", "tough")
49 | ]
50 | )
51 | def test_get_lemma(input, expected_output):
52 | """Test lemmatization works as intended"""
53 | assert get_lemma(input) == expected_output
54 |
55 |
56 | @pytest.mark.parametrize(
57 | "input,expected_output",
58 | [
59 | ("Эти тексты для пользователя id198890, потому что он пес", "текст пользователь пес"),
60 | ("Оболочка работает с CPython 2.6+/3.3+ и PyPy 1.9+", "оболочка работать cpython pypy")
61 | ]
62 | )
63 | def test_full_russian_processing(input, expected_output):
64 | """Test preprocessing for russian"""
65 | assert lemmatize_text_ru(input) == expected_output
66 |
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/input-data.json:
--------------------------------------------------------------------------------
1 | {
2 | "exp_id": {
3 | "type": "string",
4 | "hidden": true,
5 | "required": true
6 | },
7 | "wordset": {
8 | "type": "string",
9 | "hidden": false,
10 | "required": true
11 | },
12 | "model_id": {
13 | "type": "string",
14 | "hidden": true,
15 | "required": true
16 | },
17 | "topic_id": {
18 | "type": "string",
19 | "hidden": true,
20 | "required": true
21 | },
22 | "dataset_name": {
23 | "type": "string",
24 | "hidden": true,
25 | "required": true
26 | },
27 | "correct_bad_words": {
28 | "type": "json",
29 | "hidden": false,
30 | "required": false
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/instructions.md:
--------------------------------------------------------------------------------
1 | Read the set of words. Try to understand whether they correspond to a single common topic (perhaps with the exception of a few words that don't). If they do, give the topic a name with one word or a short sentence.
2 | Optionally, identify the words that don't belong to the common topic among the wordset.
3 |
4 | Examples:
5 |
6 | * Good topic: run, swimming, athlete, ski, training, exercises, paddle, referee. Topic name: sport.
7 | * Good topic (but with "bad" words that don't belong to the common topic): greenhouse, chicken, tomato, eggs. Topic name:
8 | * Bad topic (cannot be identified)
9 | * Bad topic (mixing of two or more topics)
10 |
11 | * Advertising and spam. This category includes texts that ask users to go to an external resource or buy a product, offer earnings, or advertise dating sites. Such messages often contain shortened links to various sites.
12 | * Nonsense. The text is a meaningless set of characters or words. Emoticons and emojis, nicknames and hashtags don't belong to this category.
13 | * Insults. This category includes insults and threats that are clearly targeted at a user.
14 | * Violation of the law. The text incites extremist activities, violence, criminal activity, or hatred based on gender, race, nationality, or belonging to a social group; or the text promotes suicide, drugs, or the sale of weapons.
15 | * Profanity. This category includes comments that contain obscenities or profanity.
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/output-data.json:
--------------------------------------------------------------------------------
1 | {
2 | "quality": {
3 | "type": "string",
4 | "hidden": false,
5 | "required": true
6 | },
7 | "bad_words": {
8 | "type": "json",
9 | "hidden": false,
10 | "required": true
11 | },
12 | "topic_name": {
13 | "type": "string",
14 | "hidden": false,
15 | "required": false
16 | },
17 | "golden_bad_words": {
18 | "type": "boolean",
19 | "hidden": false,
20 | "required": true
21 | },
22 | "golden_binary_quality": {
23 | "type": "boolean",
24 | "hidden": false,
25 | "required": true
26 | }
27 | }
28 |
29 |
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/task.css:
--------------------------------------------------------------------------------
1 | /* Task on the page */
2 | .task {
3 | border: 1px solid #ccc;
4 | width: 500px;
5 | padding: 15px;
6 | display: inline-block;
7 | }
8 |
9 | .tsk-block {
10 | border-radius: 3px;
11 | margin-bottom: 10px;
12 | }
13 |
14 | .obj-text {
15 | border: 1px solid #ccc;
16 | padding: 15px 15px 15px 71px;
17 | position: relative;
18 | background-color: #e6f7dc;
19 | }
20 |
21 | /* Quotation mark */
22 | .quote-sign {
23 | background-image: url('data:image/svg+xml;utf8, ');
24 | background-position: center center;
25 | background-repeat: no-repeat;
26 | background-size: contain;
27 | width: 36px;
28 | height: 36px;
29 | position: absolute;
30 | top: 7px;
31 | left: 15px;
32 | }
33 |
34 | .tsk-block fieldset {
35 | padding: 10px 20px;
36 | border-radius: 3px;
37 | border: 1px solid #ccc;
38 | margin: 0;
39 | }
40 |
41 | .tsk-block legend {
42 | font-weight: bold;
43 | padding: 0 6px;
44 | }
45 |
46 | .field_type_checkbox {
47 | display: block;
48 | }
49 |
50 | .task__error {
51 | border-radius: 3px;
52 | }
53 |
54 | .second_scale {
55 | display: none;
56 | }
57 |
58 | /* Displaying task content on mobile devices */
59 | @media screen and (max-width: 600px) {
60 | .task-suite {
61 | padding: 0;
62 | }
63 |
64 | .task {
65 | width: 100%;
66 | margin: 0;
67 | }
68 |
69 | .task-suite div:not(:last-child) {
70 | margin-bottom: 10px;
71 | }
72 |
73 | .hint_label,
74 | .field__hotkey {
75 | display: none;
76 | }
77 |
78 | .field_type_checkbox {
79 | white-space: normal;
80 | }
81 |
82 | .quote-sign {
83 | width: 20px;
84 | height: 20px;
85 | top: 13px;
86 | }
87 |
88 | .obj-text {
89 | padding: 15px 10px 15px 41px;
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/task.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {{{wordset}}}
5 |
6 |
7 |
8 |
9 |
10 | Can you name a single topic that is represented by this set of words?
11 | {{field type="radio" name="quality" value="good" size="L" label="Yes" hotkey="1" class="yes"}}
12 | {{field type="radio" name="quality" value="rather_good" size="L" label="Rather yes" hotkey="1" class="yes"}}
13 | {{field type="radio" name="quality" value="rather_bad" size="L" label="Rather no" hotkey="3" class="no"}}
14 | {{field type="radio" name="quality" value="bad" size="L" label="No" hotkey="4" class="no"}}
15 |
16 |
17 |
18 |
19 |
20 | Give a name in one or few words to the topic described with the wordset
21 | {{field type="input" name="topic_name" value="" placeholder="Topic name" validation-show="right-center"}}
22 |
23 |
24 | Which words are out of the topic?
25 | {{#each words}}
26 | {{field type="checkbox" name=(concat "bad_words." this.name) label=this.title}}
27 | {{/each}}
28 |
29 |
30 |
--------------------------------------------------------------------------------
/toloka/Estimate_topics_interpretability/task.js:
--------------------------------------------------------------------------------
1 | exports.Task = extend(TolokaHandlebarsTask, function(options) {
2 | TolokaHandlebarsTask.call(this, options);
3 | }, {
4 | is_good_topic: function(solution) {
5 | const positives = ['good', 'rather_good']
6 | return positives.includes(solution.output_values['quality'])
7 | },
8 |
9 | intersection: function(setA, setB) {
10 | return new Set(
11 | [...setA].filter(element => setB.has(element))
12 | );
13 | },
14 |
15 | union: function(setA, setB) {
16 | return new Set(
17 | [...setA, ...setB]
18 | );
19 | },
20 |
21 | bad_words_set: function(obj) {
22 | let badWordsSet = new Set();
23 | for (var prop in obj) {
24 | if (Object.prototype.hasOwnProperty.call(obj, prop)) {
25 | let word = prop;
26 | let is_bad = obj[word];
27 | if (is_bad) {
28 | badWordsSet.add(word);
29 | }
30 | }
31 | }
32 |
33 | return badWordsSet;
34 | },
35 |
36 | setSolution: function(solution) {
37 | TolokaHandlebarsTask.prototype.setSolution.apply(this, arguments);
38 | var workspaceOptions = this.getWorkspaceOptions();
39 |
40 | var tname = solution.output_values['topic_name'] || "";
41 | this.setSolutionOutputValue("topic_name", tname);
42 |
43 | if (this.rendered) {
44 | if (!workspaceOptions.isReviewMode && !workspaceOptions.isReadOnly) {
45 | // Show a set of checkboxes if the answer "There are violations" (BAD) is selected. Otherwise, hide it
46 | if (solution.output_values['quality']) {
47 |
48 | var row = this.getDOMElement().querySelector('.second_scale');
49 | row.style.display = this.is_good_topic(solution) ? 'block' : 'none';
50 |
51 | if (!this.is_good_topic(solution)) {
52 | let data = this.getTemplateData();
53 | let words_out = {};
54 | for (let i = 0; i < data.words.length; i++) {
55 | words_out[data.words[i].name] = false;
56 | }
57 |
58 | this.setSolutionOutputValue("bad_words", words_out);
59 |
60 | this.setSolutionOutputValue("topic_name", "");
61 |
62 | }
63 | }
64 | }
65 | }
66 | },
67 |
68 | getTemplateData: function() {
69 | let data = TolokaHandlebarsTask.prototype.getTemplateData.call(this);
70 |
71 | const words = data.wordset.split(" ");
72 | let word_outs = [];
73 | for (let i = 0; i < words.length; i++) {
74 | word_outs.push({'name': words[i], 'title': words[i]});
75 | }
76 |
77 | data.words = word_outs;
78 |
79 | return data;
80 | },
81 |
82 | // Error message processing
83 | addError: function(message, field, errors) {
84 | errors || (errors = {
85 | task_id: this.getOptions().task.id,
86 | errors: {}
87 | });
88 | errors.errors[field] = {
89 | message: message
90 | };
91 |
92 | return errors;
93 | },
94 |
95 | // Checking the answers: if the answer "There are violations" is selected, at least one checkbox must be checked
96 | validate: function(solution) {
97 | var errors = null;
98 | var topic_name = solution.output_values.topic_name;
99 | topic_name = typeof topic_name !== 'undefined' ? topic_name.trim() : "";
100 | let bad_topic_name = topic_name.length < 3 || topic_name.length > 50
101 |
102 | if (this.is_good_topic(solution) && bad_topic_name) {
103 | errors = this.addError("Topic name is less than 3 symbols or more than 50", '__TASK__', errors);
104 | }
105 |
106 | var correctBadWords = this.getTask().input_values.correct_bad_words;
107 | var golden;
108 | if (!correctBadWords) {
109 | golden = false;
110 | } else {
111 | var badWords = solution.output_values.bad_words;
112 |
113 | let correctBadWordsSet = this.bad_words_set(correctBadWords);
114 | let badWordsSet = this.bad_words_set(badWords);
115 |
116 | var intersection = this.intersection(correctBadWordsSet, badWordsSet);
117 | var union = this.union(correctBadWordsSet, badWordsSet);
118 | golden = intersection.size / union.size >= 0.8;
119 | }
120 | this.setSolutionOutputValue("golden_bad_words", golden);
121 |
122 | var goldenBinaryQuality = this.is_good_topic(solution);
123 | this.setSolutionOutputValue("golden_binary_quality", goldenBinaryQuality);
124 |
125 | return errors || TolokaHandlebarsTask.prototype.validate.apply(this, arguments);
126 | },
127 |
128 | // Open the second question block in verification mode to see the checkboxes marked by the performer
129 | onRender: function() {
130 | var workspaceOptions = this.getWorkspaceOptions();
131 |
132 | if (workspaceOptions.isReviewMode || workspaceOptions.isReadOnly || this.is_good_topic(this.getSolution())){
133 | var row = this.getDOMElement().querySelector('.second_scale');
134 | row.style.display = 'block';
135 | }
136 |
137 | this.rendered = true;
138 | }
139 | });
140 |
141 | function extend(ParentClass, constructorFunction, prototypeHash) {
142 | constructorFunction = constructorFunction || function() {
143 | };
144 | prototypeHash = prototypeHash || {};
145 | if (ParentClass) {
146 | constructorFunction.prototype = Object.create(ParentClass.prototype);
147 | }
148 | for (var i in prototypeHash) {
149 | constructorFunction.prototype[i] = prototypeHash[i];
150 | }
151 | return constructorFunction;
152 | }
153 |
--------------------------------------------------------------------------------