├── .gitignore
├── README.md
├── assets
│   └── Team_AI-it_Malicious_Comments_Collecting_Service.pdf
├── automl
│   ├── configs
│   │   └── model
│   │       ├── example.yaml
│   │       └── mobilenetv3.yaml
│   ├── prediction
│   │   └── sample_submission.csv
│   ├── proj_dataloader.py
│   ├── proj_utils.py
│   ├── src
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── base_generator.py
│   │   │   ├── bert.py
│   │   │   ├── conv.py
│   │   │   ├── dwconv.py
│   │   │   ├── electra.py
│   │   │   ├── flatten.py
│   │   │   ├── linear.py
│   │   │   ├── lstm.py
│   │   │   ├── mbconv.py
│   │   │   └── poolings.py
│   │   ├── trainer.py
│   │   └── utils
│   │       ├── common.py
│   │       ├── data.py
│   │       ├── pytransform
│   │       │   └── __init__.py
│   │       └── torch_utils.py
│   ├── tests
│   │   ├── test_model_conversion.py
│   │   └── test_model_parser.py
│   ├── train.py
│   └── tune.py
├── base
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
├── config.json
├── config_automl_test.json
├── data_loader
│   ├── data_loaders.py
│   └── kd_data_loaders.py
├── kd_config.json
├── kd_train.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   └── logger_config.json
├── model
│   ├── loss.py
│   ├── lr_scheduler.py
│   ├── metric.py
│   └── model.py
├── parse_config.py
├── pkm_config.json
├── prototype
│   ├── fullstack
│   │   ├── .DS_Store
│   │   ├── Makefile
│   │   ├── __init__.py
│   │   └── app
│   │       ├── .DS_Store
│   │       ├── __init__.py
│   │       ├── base
│   │       │   ├── __init__.py
│   │       │   ├── base_data_loader.py
│   │       │   ├── base_model.py
│   │       │   └── base_trainer.py
│   │       ├── config.json
│   │       ├── confirm_button_hack.py
│   │       ├── database.py
│   │       ├── frontend.py
│   │       ├── load_data.py
│   │       ├── main.py
│   │       ├── model
│   │       │   ├── __init__.py
│   │       │   └── model.py
│   │       ├── predict.py
│   │       ├── service
│   │       │   ├── api_response.py
│   │       │   └── error_handler.py
│   │       ├── test
│   │       │   ├── db_test.py
│   │       │   └── exp.ipynb
│   │       └── utils.py
│   └── streamlit
│       ├── .gitignore
│       ├── app.py
│       ├── base
│       │   ├── __init__.py
│       │   ├── base_data_loader.py
│       │   ├── base_model.py
│       │   └── base_trainer.py
│       ├── config.json
│       ├── confirm_button_hack.py
│       ├── load_data.py
│       ├── model
│       │   └── model.py
│       ├── pipeline_test.py
│       ├── predict.py
│       ├── service
│       │   ├── api_response.py
│       │   └── error_handler.py
│       └── utils.py
├── requirements.txt
├── simple_test.py
├── test.py
├── test_automl.py
├── tokenizer
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.txt
├── train.py
├── trainer
│   ├── __init__.py
│   ├── kd_trainer.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── api_response.py
    ├── error_handler.py
    ├── memory.py
    ├── query.py
    ├── util.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # input data, saved log, checkpoints
104 | data/
105 | input/
106 | saved/
107 | datasets/
108 | wandb/
109 | jh_test/
110 |
111 | # editor, os cache directory
112 | .vscode/
113 | .idea/
114 | __MACOSX/
115 | *.pt
116 | *.pkl
117 | *.ipynb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Malicious Comments Collection System
2 |
3 | ## 1. Introduction
4 |
5 | 
6 |
7 | With the growth of the internet, indiscriminate malicious comments toward specific individuals have been tormenting people. Collecting evidence is essential to report or sue these commenters, but gathering it takes a long time. In particular, the current process is inefficient and reactive: collection is done manually by agencies or individuals, or relies on tips from fans. We started this project to improve this situation.
8 |
9 | The **Malicious Comments Collection System** aims to automate collecting malicious comments and reviewing them. The collected material is intended to be used later as evidence for legal action.
10 |
11 | ### Team AI-it
12 |
13 | > We picked the team name because its pronunciation, "AI-it" ("ah-it"), sounded kitschy and fun to us.
14 |
15 | #### Members
16 |
17 | | 이연걸 | 김재현 | 박진영 | 조범준 | 진혜원 | 안성민 | 양재욱 |
18 | | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
19 | |  |  |  |  |  |  |  |
20 | | [GitHub](https://github.com/LeeYeonGeol) | [GitHub](https://github.com/CozyKim) | [GitHub](https://github.com/nazzang49) | [GitHub](https://github.com/goattier) | [GitHub](https://github.com/hyewon11) | [GitHub](https://github.com/tttangmin) | [GitHub](https://github.com/didwodnr123) |
21 |
22 | ### Contribution
23 |
24 | - [`이연걸`](https://github.com/LeeYeonGeol) Project Management • Service Dataset • Front-end & Back-end Update • EDA
25 | - [`김재현`](https://github.com/CozyKim) Modeling • Model Optimization • AutoML • EDA
26 | - [`박진영`](https://github.com/nazzang49) Model Optimization • Application Cloud Release (GKE) • Service Architecture
27 | - [`조범준`](https://github.com/goattier) Baseline Code • Modeling • Model Optimization • EDA
28 | - [`진혜원`](https://github.com/hyewon11) Service Dataset • EDA • Front-end & Back-end Update
29 | - [`안성민`](https://github.com/tttangmin) EDA • Modeling
30 | - [`양재욱`](https://github.com/didwodnr123) Front-end (Streamlit) • Back-end (FastAPI) • MongoDB • EDA
31 |
32 | ## 2. Model
33 |
34 | ### KcELECTRA Backbone Model + CNN & RNN Based Classifier (Best LB f1-score: 64.856)
35 | 
36 |
37 | ### Clustering with Triplet Loss + KNN (Best LB f1-score: 66.192)
38 | 
39 |
40 | ### 2nd Place / 67 Teams (as of 21.12.23)
41 | 
42 |
43 |
44 | ## 3. Flow Chart
45 |
46 | ### System Architecture
47 |
48 | 
49 |
50 | ### Pipeline
51 |
52 | 
53 |
54 | ## 4. How to Use
55 |
56 | ### Install Requirements
57 |
58 | ```bash
59 | pip install -r requirements.txt
60 | ```
61 |
62 | ### Project Tree
63 |
64 | ```
65 | |-- assets
66 | |-- automl
67 | |-- base
68 | | |-- __init__.py
69 | | |-- base_data_loader.py
70 | | |-- base_model.py
71 | | └-- base_trainer.py
72 | |-- data_loader
73 | | └-- data_loaders.py
74 | |-- logger
75 | | |-- __init__.py
76 | | |-- logger.py
77 | | └-- logger_config.json
78 | |-- model
79 | | |-- loss.py
80 | | |-- lr_scheduler.py
81 | | |-- metric.py
82 | | └-- model.py
83 | |-- prototype
84 | |-- tokenizer
85 | | |-- special_tokens_map.json
86 | | |-- tokenizer_config.json
87 | | └-- vocab.txt
88 | |-- trainer
89 | | |-- __init__.py
90 | | |-- kd_trainer.py
91 | | └-- trainer.py
92 | |-- config.json
93 | |-- config_automl_test.json
94 | |-- kd_config.json
95 | |-- kd_train.py
96 | |-- parse_config.py
97 | |-- pkm_config.json
98 | |-- requirements.txt
99 | |-- simple_test.py
100 | |-- test.py
101 | |-- test_automl.py
102 | |-- train.py
103 | └-- utils
104 | |-- __init__.py
105 | |-- api_response.py
106 | |-- error_handler.py
107 | |-- memory.py
108 | |-- query.py
109 | |-- util.py
110 | └-- utils.py
111 | ```
112 |
113 | ### Getting Started
114 | - Train & Validation
115 | ```bash
116 | python train.py -c config.json
117 | ```
118 | - Inference
119 | ```bash
120 | python test.py -c config.json # test_config.json
121 | ```
122 | - Train (Knowledge Distillation)
123 | ```bash
124 | python kd_train.py -c kd_config.json
125 | ```
126 |
127 | ## 5. Demo (TODO)
128 |
129 | ## 6. Reference
130 | - [Korean HateSpeech Detection Kaggle Competition](https://www.kaggle.com/c/korean-hate-speech-detection/data)
131 | - [Korean HateSpeech Dataset](https://github.com/kocohub/korean-hate-speech)
132 | - [BEEP! Korean Corpus of Online News Comments for Toxic Speech Detection](https://aclanthology.org/2020.socialnlp-1.4/)
133 | - [PyTorch Template Project By victoresque](https://github.com/victoresque/pytorch-template)
134 |
--------------------------------------------------------------------------------
/assets/Team_AI-it_Malicious_Comments_Collecting_Service.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/assets/Team_AI-it_Malicious_Comments_Collecting_Service.pdf
--------------------------------------------------------------------------------
/automl/configs/model/example.yaml:
--------------------------------------------------------------------------------
1 | input_channel: 3
2 |
3 | depth_multiple: 1.0
4 | width_multiple: 1.0
5 |
6 | backbone:
7 | # Example model in PyTorch Tutorial (https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
8 | # [repeat, module, args]
9 | [
10 | [1, Conv, [6, 5, 1, 0]],
11 | [1, MaxPool, [2]],
12 | [1, Conv, [16, 5, 1, 0]],
13 | [1, MaxPool, [2]],
14 | [1, GlobalAvgPool, []],
15 | [1, Flatten, []],
16 | [1, Linear, [120, ReLU]],
17 | [1, Linear, [84, ReLU]],
18 | [1, Linear, [9]]
19 | ]
20 |
--------------------------------------------------------------------------------
/automl/configs/model/mobilenetv3.yaml:
--------------------------------------------------------------------------------
1 | input_channel: 3
2 |
3 | depth_multiple: 1.0
4 | width_multiple: 1.0
5 |
6 | backbone:
7 | # [repeat, module, args]
8 | [
9 | # Conv argument: [out_channel, kernel_size, stride, padding_size]
10 | # if padding_size is not given or null, the padding_size will be auto adjusted as padding='SAME' in TensorFlow
11 | [1, Conv, [16, 3, 2, null, 1, "HardSwish"]],
12 | # k t c SE HS s
13 | [1, InvertedResidualv3, [3, 1, 16, 0, 0, 1]],
14 | [1, InvertedResidualv3, [3, 4, 24, 0, 0, 2]], # 2-P2/4, 24 # stride 1 for cifar, 2 for others
15 | [1, InvertedResidualv3, [3, 3, 24, 0, 0, 1]],
16 | [1, InvertedResidualv3, [5, 3, 40, 1, 0, 2]], # 4-P3/8, 40
17 | [1, InvertedResidualv3, [5, 3, 40, 1, 0, 1]],
18 | [1, InvertedResidualv3, [5, 3, 40, 1, 0, 1]],
19 | [1, InvertedResidualv3, [3, 6, 80, 0, 1, 2]], # 7-P4/16, 80
20 | [1, InvertedResidualv3, [3, 2.5, 80, 0, 1, 1]],
21 | [1, InvertedResidualv3, [3, 2.3, 80, 0, 1, 1]],
22 | [1, InvertedResidualv3, [3, 2.3, 80, 0, 1, 1]],
23 | [1, InvertedResidualv3, [3, 6, 112, 1, 1, 1]],
24 | [1, InvertedResidualv3, [3, 6, 112, 1, 1, 1]], # 12 -P5/32, 112
25 | [1, InvertedResidualv3, [5, 6, 160, 1, 1, 2]],
26 | [1, InvertedResidualv3, [5, 6, 160, 1, 1, 1]],
27 | [1, InvertedResidualv3, [5, 6, 160, 1, 1, 1]],
28 | [1, Conv, [960, 1, 1]],
29 | [1, GlobalAvgPool, []],
30 | [1, Conv, [1280, 1, 1]],
31 | [1, Flatten, []],
32 | [1, Linear, [6]]
33 | ]
34 |
--------------------------------------------------------------------------------
/automl/proj_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import torch
4 | import emoji
5 | import wandb
6 | import pandas as pd
7 | from pathlib import Path
8 | from itertools import repeat
9 | from collections import OrderedDict
10 | from soynlp.normalizer import repeat_normalize
11 |
12 |
13 | def ensure_dir(dirname):
14 | dirname = Path(dirname)
15 | if not dirname.is_dir():
16 | dirname.mkdir(parents=True, exist_ok=False)
17 |
18 | def read_json(fname):
19 | fname = Path(fname)
20 | with fname.open('rt') as handle:
21 | return json.load(handle, object_hook=OrderedDict)
22 |
23 | def write_json(content, fname):
24 | fname = Path(fname)
25 | with fname.open('wt') as handle:
26 | json.dump(content, handle, indent=4, sort_keys=False)
27 |
28 | def inf_loop(data_loader):
29 | ''' wrapper function for endless data loader. '''
30 | for loader in repeat(data_loader):
31 | yield from loader
32 |
33 | def prepare_device(n_gpu_use):
34 | """
35 | setup GPU device if available. get gpu device indices which are used for DataParallel
36 | """
37 | n_gpu = torch.cuda.device_count()
38 | if n_gpu_use > 0 and n_gpu == 0:
39 |         print("Warning: There's no GPU available on this machine, "
40 |               "training will be performed on CPU.")
41 | n_gpu_use = 0
42 | if n_gpu_use > n_gpu:
43 |         print(f"Warning: The number of GPUs configured to use is {n_gpu_use}, but only {n_gpu} are "
44 | "available on this machine.")
45 | n_gpu_use = n_gpu
46 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
47 | list_ids = list(range(n_gpu_use))
48 | return device, list_ids
49 |
50 |
51 | class MetricTracker:
52 | def __init__(self, *keys):
53 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
54 | self.reset()
55 |
56 | def reset(self):
57 | for col in self._data.columns:
58 | self._data[col].values[:] = 0
59 |
60 | def update(self, key, value, n=1):
61 | self._data.total[key] += value * n
62 | self._data.counts[key] += n
63 | self._data.average[key] = self._data.total[key] / self._data.counts[key]
64 |
65 | def avg(self, key):
66 | return self._data.average[key]
67 |
68 | def result(self):
69 | return dict(self._data.average)
70 |
71 |
72 |
73 | def preprocess(sents):
74 | """
75 | kcELECTRA-base preprocess procedure + modification
76 | """
77 | preprocessed_sents = []
78 |
79 | emojis = set()
80 | for k in emoji.UNICODE_EMOJI.keys():
81 | emojis.update(emoji.UNICODE_EMOJI[k].keys())
82 |
83 |     punc_bracket_pattern = re.compile(r'[\'\"\[\]\(\)]')
84 |     base_pattern = re.compile(f'[^.,?!/@$%~%·∼()\x00-\x7Fㄱ-힣{"".join(emojis)}]+')
85 | url_pattern = re.compile(
86 | r'(http|ftp|https)?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
87 | )
88 |
89 | for sent in sents:
90 | sent = punc_bracket_pattern.sub(' ', sent)
91 | sent = base_pattern.sub(' ', sent)
92 | sent = url_pattern.sub('', sent)
93 | sent = sent.strip()
94 | sent = repeat_normalize(sent, num_repeats=2)
95 | preprocessed_sents.append(sent)
96 |
97 | return preprocessed_sents
98 |
99 |
100 | class Preprocess():
101 |     '''A class for preprocessing contexts from the train data and Wikipedia.
102 |     Args:
103 |         sents (list): list of context sentences to preprocess
104 |         langs (list): languages that should be removed from the sentences (see PERMIT_REMOVE_LANGS)
105 |     '''
106 |
107 | PERMIT_REMOVE_LANGS = [
108 | 'arabic',
109 | 'russian',
110 | ]
111 |
112 | def __init__(self, sents: list):
113 | self.sents = sents
114 |
115 | def proc_preprocessing(self):
116 | """
117 |         Run the full preprocessing pipeline and return the cleaned sentences.
118 | """
119 | self.remove_hashtag()
120 | self.remove_user_mention()
121 | self.remove_bad_char()
122 | self.clean_punc()
123 | self.remove_useless_char()
124 | self.remove_linesign()
125 | self.remove_repeated_spacing()
126 |
127 | return self.sents
128 |
129 | def remove_hashtag(self):
130 | """
131 | A function for removing hashtag
132 | """
133 | preprocessed_sents = []
134 | for sent in self.sents:
135 | sent = re.sub(r"#\S+", "", sent).strip()
136 | if sent:
137 | preprocessed_sents.append(sent)
138 | self.sents = preprocessed_sents
139 |
140 | def remove_user_mention(self):
141 | """
142 | A function for removing mention tag
143 | """
144 | preprocessed_sents = []
145 | for sent in self.sents:
146 | sent = re.sub(r"@\w+", "", sent).strip()
147 | if sent:
148 | preprocessed_sents.append(sent)
149 | self.sents = preprocessed_sents
150 |
151 | def remove_bad_char(self):
152 | """
153 | A function for removing raw unicode including unk
154 | """
155 | bad_chars = {"\u200b": "", "…": " ... ", "\ufeff": ""}
156 |         preprocessed_sents = []
157 |         for sent in self.sents:
158 |             for bad_char in bad_chars:
159 |                 sent = sent.replace(bad_char, bad_chars[bad_char])
160 |             sent = re.sub(r"[\+á?\xc3\xa1]", "", sent)
161 |             if sent:
162 |                 preprocessed_sents.append(sent)
163 |         self.sents = preprocessed_sents
164 |
165 | def clean_punc(self):
166 | """
167 | A function for removing useless punctuation
168 | """
169 | punct_mapping = {"‘": "'", "₹": "e", "´": "'", "°": "", "€": "e", "™": "tm", "√": " sqrt ", "×": "x", "²": "2",
170 | "—": "-", "–": "-", "’": "'", "_": "-", "`": "'", '“': '"', '”': '"', '“': '"', "£": "e",
171 | '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-',
172 | 'β': 'beta', '∅': '', '³': '3', 'π': 'pi', 'ㅂㅅ': '병신', 'ㄲㅈ': '꺼져', 'ㅂㄷ': '부들', 'ㅆㄹㄱ': '쓰레기', 'ㅆㅂ': '씨발',
173 | 'ㅈㅅ': '죄송', 'ㅈㄹ': '지랄', 'ㅈㄴ': '정말'}
174 |
175 | preprocessed_sents = []
176 | for sent in self.sents:
177 | for p in punct_mapping:
178 | sent = sent.replace(p, punct_mapping[p])
179 | sent = sent.strip()
180 | if sent:
181 | preprocessed_sents.append(sent)
182 | self.sents = preprocessed_sents
183 |
184 | def remove_useless_char(self):
185 | preprocessed_sents = []
186 | re_obj = re.compile('[^가-힣a-z0-9\x20]+')
187 |
188 | for sent in self.sents:
189 | temp = re_obj.findall(sent)
190 | if temp != []:
191 | for ch in temp:
192 | sent = sent.replace(ch, " ")
193 | sent = sent.strip()
194 | if sent:
195 | preprocessed_sents.append(sent)
196 |
197 | self.sents = preprocessed_sents
198 |
199 | def remove_repeated_spacing(self):
200 | """
201 | A function for reducing whitespaces into one
202 | """
203 | preprocessed_sents = []
204 | for sent in self.sents:
205 | sent = re.sub(r"\s+", " ", sent).strip()
206 | if sent:
207 | preprocessed_sents.append(sent)
208 | self.sents = preprocessed_sents
209 |
210 | def spacing_sent(self):
211 | """
212 | A function for spacing properly
213 | """
214 | preprocessed_sents = []
215 | for sent in self.sents:
216 | sent = self.spacing(sent)
217 | if sent:
218 | preprocessed_sents.append(sent)
219 | self.sents = preprocessed_sents
220 |
221 | def remove_linesign(self):
222 | """
223 |         A function for removing line break characters such as \n
224 | """
225 | preprocessed_sents = []
226 | for sent in self.sents:
227 | sent = re.sub(r"[\n\t\r\v\f\\\\n\\t\\r\\v\\f]", "", sent)
228 | if sent:
229 | preprocessed_sents.append(sent)
230 | self.sents = preprocessed_sents
231 |
--------------------------------------------------------------------------------
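
A minimal usage sketch for the preprocessing helpers above (hypothetical input; assumes the `emoji`/`soynlp` versions pinned in requirements.txt and that it is run from the `automl/` directory so `proj_utils` is importable):

```python
# Hypothetical example: normalize a raw comment with the helpers from automl/proj_utils.py.
from proj_utils import preprocess, Preprocess

raw = ["@user123 진짜 너무 웃기다 ㅋㅋㅋㅋㅋ #유머 https://example.com"]
sents = preprocess(raw)                         # KcELECTRA-style cleanup: URL removal, emoji filter, repeat_normalize
sents = Preprocess(sents).proc_preprocessing()  # hashtag/mention removal, punctuation mapping, whitespace collapse
print(sents)
```
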
/automl/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/automl/src/__init__.py
--------------------------------------------------------------------------------
/automl/src/dataloader.py:
--------------------------------------------------------------------------------
1 | """Tune Model.
2 |
3 | - Author: Junghoon Kim, Jongkuk Lim, Jimyeong Kim
4 | - Contact: placidus36@gmail.com, lim.jeikei@gmail.com, wlaud1001@snu.ac.kr
5 | - Reference
6 | https://github.com/j-marple-dev/model_compression
7 | """
8 | import glob
9 | import os
10 | from typing import Any, Dict, List, Tuple, Union
11 |
12 | import torch
13 | import yaml
14 | from torch.utils.data import DataLoader, random_split
15 | from torchvision.datasets import ImageFolder, VisionDataset
16 |
17 | from src.utils.data import weights_for_balanced_classes
18 | from src.utils.torch_utils import split_dataset_index
19 |
20 |
21 | def create_dataloader(
22 | config: Dict[str, Any],
23 | ) -> Tuple[DataLoader, DataLoader, DataLoader]:
24 | """Simple dataloader.
25 |
26 | Args:
27 |         config: configuration dictionary describing the dataset, augmentations, and batch size.
28 |
29 | Returns:
30 | train_loader
31 | valid_loader
32 | test_loader
33 | """
34 | # Data Setup
35 | train_dataset, val_dataset, test_dataset = get_dataset(
36 | data_path=config["DATA_PATH"],
37 | dataset_name=config["DATASET"],
38 | img_size=config["IMG_SIZE"],
39 | val_ratio=config["VAL_RATIO"],
40 | transform_train=config["AUG_TRAIN"],
41 | transform_test=config["AUG_TEST"],
42 | transform_train_params=config["AUG_TRAIN_PARAMS"],
43 | transform_test_params=config.get("AUG_TEST_PARAMS"),
44 | )
45 |
46 | return get_dataloader(
47 | train_dataset=train_dataset,
48 | val_dataset=val_dataset,
49 | test_dataset=test_dataset,
50 | batch_size=config["BATCH_SIZE"],
51 | )
52 |
53 |
54 | def get_dataset(
55 | data_path: str = "./save/data",
56 | dataset_name: str = "CIFAR10",
57 | img_size: float = 32,
58 | val_ratio: float=0.2,
59 | transform_train: str = "simple_augment_train",
60 | transform_test: str = "simple_augment_test",
61 | transform_train_params: Dict[str, int] = None,
62 | transform_test_params: Dict[str, int] = None,
63 | ) -> Tuple[VisionDataset, VisionDataset, VisionDataset]:
64 | """Get dataset for training and testing."""
65 | if not transform_train_params:
66 | transform_train_params = dict()
67 | if not transform_test_params:
68 | transform_test_params = dict()
69 |
70 | # preprocessing policies
71 | transform_train = getattr(
72 | __import__("src.augmentation.policies", fromlist=[""]),
73 | transform_train,
74 | )(dataset=dataset_name, img_size=img_size, **transform_train_params)
75 | transform_test = getattr(
76 | __import__("src.augmentation.policies", fromlist=[""]),
77 | transform_test,
78 | )(dataset=dataset_name, img_size=img_size, **transform_test_params)
79 |
80 | label_weights = None
81 | # pytorch dataset
82 | if dataset_name == "TACO":
83 | train_path = os.path.join(data_path, "train")
84 | val_path = os.path.join(data_path, "val")
85 | test_path = os.path.join(data_path, "test")
86 |
87 | train_dataset = ImageFolder(root=train_path, transform=transform_train)
88 | val_dataset = ImageFolder(root=val_path, transform=transform_test)
89 | test_dataset = ImageFolder(root=test_path, transform=transform_test)
90 |
91 | else:
92 | Dataset = getattr(
93 | __import__("torchvision.datasets", fromlist=[""]), dataset_name
94 | )
95 | train_dataset = Dataset(
96 | root=data_path, train=True, download=True, transform=transform_train
97 | )
98 | # from train dataset, train: 80%, val: 20%
99 | train_length = int(len(train_dataset) * (1.0-val_ratio))
100 | train_dataset, val_dataset = random_split(
101 | train_dataset, [train_length, len(train_dataset) - train_length]
102 | )
103 | test_dataset = Dataset(
104 | root=data_path, train=False, download=False, transform=transform_test
105 | )
106 | return train_dataset, val_dataset, test_dataset
107 |
108 |
109 | def get_dataloader(
110 | train_dataset: VisionDataset,
111 | val_dataset: VisionDataset,
112 | test_dataset: VisionDataset,
113 | batch_size: int,
114 | ) -> Tuple[DataLoader, DataLoader, DataLoader]:
115 | """Get dataloader for training and testing."""
116 |
117 | train_loader = DataLoader(
118 | dataset=train_dataset,
119 | pin_memory=(torch.cuda.is_available()),
120 | shuffle=True,
121 | batch_size=batch_size,
122 | num_workers=10,
123 | drop_last=True
124 | )
125 | valid_loader = DataLoader(
126 | dataset=val_dataset,
127 | pin_memory=(torch.cuda.is_available()),
128 | shuffle=False,
129 | batch_size=batch_size,
130 | num_workers=5
131 | )
132 | test_loader = DataLoader(
133 | dataset=test_dataset,
134 | pin_memory=(torch.cuda.is_available()),
135 | shuffle=False,
136 | batch_size=batch_size,
137 | num_workers=5
138 | )
139 | return train_loader, valid_loader, test_loader
140 |
--------------------------------------------------------------------------------
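
A hedged sketch of driving `create_dataloader` with a config dictionary; the values mirror the defaults of `get_dataset` above, and the augmentation policy names are assumed to exist in `src.augmentation.policies` (not shown in this dump):

```python
# Illustrative config for create_dataloader (keys taken from the function body above).
from src.dataloader import create_dataloader

data_config = {
    "DATA_PATH": "./save/data",
    "DATASET": "CIFAR10",
    "IMG_SIZE": 32,
    "VAL_RATIO": 0.2,
    "AUG_TRAIN": "simple_augment_train",   # assumed policy name (default of get_dataset)
    "AUG_TEST": "simple_augment_test",
    "AUG_TRAIN_PARAMS": None,
    "AUG_TEST_PARAMS": None,
    "BATCH_SIZE": 64,
}
train_loader, valid_loader, test_loader = create_dataloader(data_config)
print(len(train_loader), len(valid_loader), len(test_loader))
```
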
/automl/src/loss.py:
--------------------------------------------------------------------------------
1 | """Custom loss for long tail problem.
2 |
3 | - Author: Junghoon Kim
4 | - Email: placidus36@gmail.com
5 | """
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class CustomCriterion:
13 | """Custom Criterion."""
14 |
15 | def __init__(self, samples_per_cls, device, fp16=False, loss_type="softmax"):
16 | if not samples_per_cls:
17 | loss_type = "softmax"
18 | else:
19 | self.samples_per_cls = samples_per_cls
20 | self.frequency_per_cls = samples_per_cls / np.sum(samples_per_cls)
21 | self.no_of_classes = len(samples_per_cls)
22 | self.device = device
23 | self.fp16 = fp16
24 |
25 | if loss_type == "softmax":
26 | self.criterion = nn.CrossEntropyLoss()
27 | elif loss_type == "logit_adjustment_loss":
28 | tau = 1.0
29 | self.logit_adj_val = (
30 | torch.tensor(tau * np.log(self.frequency_per_cls))
31 | .float()
32 | .to(self.device)
33 | )
34 | self.logit_adj_val = (
35 | self.logit_adj_val.half() if fp16 else self.logit_adj_val.float()
36 | )
37 | self.logit_adj_val = self.logit_adj_val.to(device)
38 | self.criterion = self.logit_adjustment_loss
39 |
40 | def __call__(self, logits, labels):
41 | """Call criterion."""
42 | return self.criterion(logits, labels)
43 |
44 | def logit_adjustment_loss(self, logits, labels):
45 | """Logit adjustment loss."""
46 | logits_adjusted = logits + self.logit_adj_val.repeat(labels.shape[0], 1)
47 | loss = F.cross_entropy(input=logits_adjusted, target=labels)
48 | return loss
49 |
--------------------------------------------------------------------------------
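
A small worked sketch (made-up class counts) of the logit adjustment path: every logit is shifted by `tau * log(class_frequency)` before cross-entropy, so frequent classes do not dominate the loss on a long-tailed dataset. Assumes it is run from `automl/` so that `src` is importable.

```python
import torch

from src.loss import CustomCriterion

samples_per_cls = [900, 90, 10]   # hypothetical long-tailed class counts
criterion = CustomCriterion(samples_per_cls, device="cpu", loss_type="logit_adjustment_loss")

logits = torch.randn(4, 3)        # batch of 4 samples, 3 classes
labels = torch.tensor([0, 2, 1, 0])
print(criterion(logits, labels))  # cross-entropy on the frequency-adjusted logits
```
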
/automl/src/model.py:
--------------------------------------------------------------------------------
1 | """Model parser and model.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 |
7 | from typing import Dict, List, Type, Union
8 |
9 | import torch
10 | import torch.nn as nn
11 | import yaml
12 |
13 | from .modules import ModuleGenerator
14 |
15 |
16 | class Model(nn.Module):
17 | """Base model class."""
18 |
19 | def __init__(
20 | self,
21 | cfg: Union[str, Dict[str, Type]] = "./model_configs/show_case.yaml",
22 | verbose: bool = False,
23 | ) -> None:
24 | """Parse model from the model config file.
25 |
26 | Args:
27 | cfg: yaml file path or dictionary type of the model.
28 | verbose: print the model parsing information.
29 | """
30 | super().__init__()
31 | self.model_parser = ModelParser(cfg=cfg, verbose=verbose)
32 | self.model = self.model_parser.model
33 |
34 | def forward(self, x: torch.Tensor) -> torch.Tensor:
35 | """Forward."""
36 | return self.forward_one(x)
37 |
38 | def forward_one(self, x: torch.Tensor) -> torch.Tensor:
39 | """Forward onetime."""
40 |
41 | return self.model(x)
42 |
43 |
44 | class ModelParser:
45 | """Generate PyTorch model from the model yaml file."""
46 |
47 | def __init__(
48 | self,
49 | cfg: Union[str, Dict[str, Type]] = "./model_configs/show_case.yaml",
50 | verbose: bool = False,
51 | ) -> None:
52 | """Generate PyTorch model from the model yaml file.
53 |
54 | Args:
55 | cfg: model config file or dict values read from the model config file.
56 | verbose: print the parsed model information.
57 | """
58 |
59 | self.verbose = verbose
60 | if isinstance(cfg, dict):
61 | self.cfg = cfg
62 | else:
63 | with open(cfg) as f:
64 | self.cfg = yaml.load(f, Loader=yaml.FullLoader)
65 |
66 | # self.in_channel = self.cfg["input_channel"]
67 |
68 | # self.depth_multiply = self.cfg["depth_multiple"]
69 | # self.width_multiply = self.cfg["width_multiple"]
70 |
71 | # error: Incompatible types in assignment (expression has type "Type[Any]",
72 | # variable has type "List[Union[int, str, float]]")
73 | self.model_cfg: List[Union[int, str, float]] = self.cfg["backbone"] # type: ignore
74 |
75 | self.model = self._parse_model()
76 |
77 | def log(self, msg: str):
78 | """Log."""
79 | if self.verbose:
80 | print(msg)
81 |
82 | def _parse_model(self) -> nn.Sequential:
83 | """Parse model."""
84 | layers: List[nn.Module] = []
85 | log: str = (
86 | f"{'idx':>3} | {'n':>3} | {'params':>10} "
87 | f"| {'module':>15} | {'arguments':>20} | {'in_channel':>12} | {'out_channel':>13}"
88 | )
89 | self.log(log)
90 | self.log(len(log) * "-") # type: ignore
91 |
92 | # in_channel = self.in_channel
93 | for i, (repeat, module, args) in enumerate(self.model_cfg): # type: ignore
94 | repeat = (
95 | max(round(repeat * self.depth_multiply), 1) if repeat > 1 else repeat
96 | )
97 |
98 | module_generator = ModuleGenerator(module)( # type: ignore
99 | *args,
100 | # width_multiply=self.width_multiply,
101 | )
102 | m = module_generator(repeat=repeat)
103 |
104 | layers.append(m)
105 | # in_channel = module_generator.out_channel
106 |
107 | log = (
108 | f"{i:3d} | {repeat:3d} | "
109 | f"{m.n_params:10,d} | {m.type:>15} | {str(args):>20} | "
110 | # f"{str(module_generator.in_channel):>12}"
111 | # f"{str(module_generator.out_channel):>13}"
112 | )
113 |
114 | self.log(log)
115 |
116 | parsed_model = nn.Sequential(*layers)
117 | n_param = sum([x.numel() for x in parsed_model.parameters()])
118 | n_grad = sum([x.numel() for x in parsed_model.parameters() if x.requires_grad])
119 | # error: Incompatible return value type (got "Tuple[Sequential, List[int]]",
120 | # expected "Tuple[Module, List[Optional[int]]]")
121 | self.log(
122 | f"Model Summary: {len(list(parsed_model.modules())):,d} "
123 | f"layers, {n_param:,d} parameters, {n_grad:,d} gradients"
124 | )
125 |
126 | return parsed_model
127 |
--------------------------------------------------------------------------------
/automl/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | """PyTorch Module and ModuleGenerator."""
2 |
3 | from src.modules.base_generator import GeneratorAbstract, ModuleGenerator
4 | from src.modules.bottleneck import Bottleneck, BottleneckGenerator
5 | from src.modules.conv import Conv, ConvGenerator, FixedConvGenerator
6 | from src.modules.dwconv import DWConv, DWConvGenerator
7 | from src.modules.flatten import FlattenGenerator
8 | from src.modules.invertedresidualv2 import (
9 | InvertedResidualv2,
10 | InvertedResidualv2Generator,
11 | )
12 | from src.modules.invertedresidualv3 import (
13 | InvertedResidualv3,
14 | InvertedResidualv3Generator,
15 | )
16 | from src.modules.linear import Linear, LinearGenerator
17 | from src.modules.poolings import (
18 | AvgPoolGenerator,
19 | GlobalAvgPool,
20 | GlobalAvgPoolGenerator,
21 | MaxPoolGenerator,
22 | )
23 | from src.modules.bert import Bert, BertGenerator
24 | from src.modules.electra import Electra, ElectraGenerator
25 | from src.modules.lstm import Lstm, LstmGenerator
26 | from src.modules.electra_lstm import ElectraWithLSTM, ElectraWithLSTMGenerator
27 | from src.modules.bert_lstm import BertWithLSTM, BertWithLSTMGenerator
28 |
29 | __all__ = [
30 | "ModuleGenerator",
31 | "GeneratorAbstract",
32 | "Bottleneck",
33 | "Conv",
34 | "DWConv",
35 | "Linear",
36 | "GlobalAvgPool",
37 | "InvertedResidualv2",
38 | "InvertedResidualv3",
39 | "BottleneckGenerator",
40 | "FixedConvGenerator",
41 | "ConvGenerator",
42 | "LinearGenerator",
43 | "DWConvGenerator",
44 | "FlattenGenerator",
45 | "MaxPoolGenerator",
46 | "AvgPoolGenerator",
47 | "GlobalAvgPoolGenerator",
48 | "InvertedResidualv2Generator",
49 | "InvertedResidualv3Generator",
50 | "Bert",
51 | "BertGenerator",
52 | "Electra",
53 | "ElectraGenerator" "Lstm",
54 | "LstmGenerator",
55 | "ElectraWithLSTM",
56 | "ElectraWithLSTMGenerator",
57 | "BertWithLSTM",
58 | "BertWithLSTMGenerator",
59 | ]
60 |
--------------------------------------------------------------------------------
/automl/src/modules/activations.py:
--------------------------------------------------------------------------------
1 | """Custom activation to work with onnx.
2 |
3 | Reference:
4 | https://github.com/rwightman/pytorch-image-models/blob/9a25fdf3ad0414b4d66da443fe60ae0aa14edc84/timm/models/layers/activations.py
5 | - Author: Junghoon Kim
6 | - Contact: placidus36@gmail.com
7 | """
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 |
12 |
13 | def hard_sigmoid(x: torch.Tensor, inplace: bool = False):
14 | """Hard sigmoid."""
15 | if inplace:
16 | return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)
17 | else:
18 | return F.relu6(x + 3.0) / 6.0
19 |
20 |
21 | class HardSigmoid(nn.Module):
22 | """Hard sigmoid."""
23 |
24 | def __init__(self, inplace: bool = False):
25 | """Initialize."""
26 | super().__init__()
27 | self.inplace = inplace
28 |
29 | def forward(self, x: torch.Tensor):
30 | """Forward."""
31 | return hard_sigmoid(x, self.inplace)
32 |
33 |
34 | def hard_swish(x: torch.Tensor, inplace: bool = False):
35 | """Hard swish."""
36 | inner = F.relu6(x + 3.0).div_(6.0)
37 | return x.mul_(inner) if inplace else x.mul(inner)
38 |
39 |
40 | class HardSwish(nn.Module):
41 | """Custom hardswish to work with onnx."""
42 |
43 | def __init__(self, inplace: bool = False):
44 | """Initialize."""
45 | super().__init__()
46 | self.inplace = inplace
47 |
48 | def forward(self, x: torch.Tensor):
49 | """Forward."""
50 | return hard_swish(x, self.inplace)
51 |
52 |
53 | def swish(x: torch.Tensor, inplace: bool = False):
54 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
55 | and also as Swish (https://arxiv.org/abs/1710.05941).
56 | TODO Rename to SiLU with addition to PyTorch
57 | Adopted to handle onnx conversion
58 | """
59 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
60 |
61 |
62 | class Swish(nn.Module):
63 | """Swish."""
64 |
65 | def __init__(self, inplace: bool = False):
66 | """Initialize."""
67 | super().__init__()
68 | self.inplace = inplace
69 |
70 | def forward(self, x: torch.Tensor):
71 | """Forward."""
72 | return swish(x, self.inplace)
73 |
--------------------------------------------------------------------------------
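
A quick, illustrative sanity check that the ONNX-friendly activations above match PyTorch's built-in formulations (`hardswish(x) = x * relu6(x + 3) / 6`):

```python
import torch
import torch.nn.functional as F

from src.modules.activations import HardSigmoid, HardSwish

x = torch.linspace(-4.0, 4.0, steps=9)
print(HardSwish()(x))    # custom, export-friendly implementation
print(F.hardswish(x))    # built-in reference; the values should match
print(HardSigmoid()(x))  # relu6(x + 3) / 6
```
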
/automl/src/modules/base_generator.py:
--------------------------------------------------------------------------------
1 | """Base Module Generator.
2 |
3 | This module is responsible for GeneratorAbstract and ModuleGenerator.
4 |
5 | - Author: Jongkuk Lim
6 | - Contact: lim.jeikei@gmail.com
7 | """
8 | from abc import ABC, abstractmethod
9 | from typing import List, Union
10 |
11 | from torch import nn as nn
12 |
13 | from src.utils.torch_utils import make_divisible
14 |
15 |
16 | class GeneratorAbstract(ABC):
17 | """Abstract Module Generator."""
18 |
19 | CHANNEL_DIVISOR: int = 8
20 |
21 | def __init__(
22 | self,
23 | *args,
24 | from_idx: Union[int, List[int]] = -1,
25 | ):
26 | """Initialize module generator.
27 |
28 | Args:
29 | *args: Module arguments
30 | from_idx: Module input index
31 | """
32 | self.args = tuple(args)
33 | self.from_idx = from_idx
34 |
35 | @property
36 | def name(self) -> str:
37 | """Module name."""
38 | return self.__class__.__name__.replace("Generator", "")
39 |
40 | def _get_module(self, module: Union[nn.Module, List[nn.Module]]) -> nn.Module:
41 | """Get module from __call__ function."""
42 | if isinstance(module, list):
43 | module = nn.Sequential(*module)
44 |
45 | # error: Incompatible types in assignment (expression has type "Union[Tensor, Module, int]",
46 | # variable has type "Union[Tensor, Module]")
47 | # error: List comprehension has incompatible type List[int];
48 | # expected List[Union[Tensor, Module]]
49 | module.n_params = sum([x.numel() for x in module.parameters()]) # type: ignore
50 | # error: Cannot assign to a method
51 | module.type = self.name # type: ignore
52 |
53 | return module
54 |
55 | # @classmethod
56 | # def _get_divisible_channel(cls, n_channel: int) -> int:
57 | # """Get divisible channel by default divisor.
58 |
59 | # Args:
60 | # n_channel: number of channel.
61 |
62 | # Returns:
63 | # Ex) given {n_channel} is 52 and {GeneratorAbstract.CHANNEL_DIVISOR} is 8.,
64 | # return channel is 56 since ceil(52/8) = 7 and 7*8 = 56
65 | # """
66 | # return make_divisible(n_channel, divisor=cls.CHANNEL_DIVISOR)
67 |
68 | # @property
69 | # @abstractmethod
70 | # def out_channel(self) -> int:
71 | # """Out channel of the module."""
72 |
73 | @abstractmethod
74 | def __call__(self, repeat: int = 1):
75 | """Returns nn.Module component"""
76 |
77 |
78 | class ModuleGenerator:
79 | """Module generator class."""
80 |
81 | def __init__(self, module_name: str):
82 | """Generate module based on the {module_name}
83 |
84 | Args:
85 | module_name: {module_name}Generator class must have been implemented.
86 | """
87 | self.module_name = module_name
88 | # self.in_channel = in_channel
89 |
90 | def __call__(self, *args, **kwargs):
91 | # replace getattr
92 | return getattr(
93 | __import__("src.modules", fromlist=[""]),
94 | f"{self.module_name}Generator",
95 | )(*args, **kwargs)
96 |
--------------------------------------------------------------------------------
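
A sketch of the name-to-generator lookup performed by `ModuleGenerator`: a module name from a backbone yaml row (here `Electra`, defined later in this dump) is resolved to its `*Generator` class, which then builds the wrapped `nn.Module`. Assumes `transformers` can fetch the pretrained KcELECTRA weights.

```python
from src.modules.base_generator import ModuleGenerator

generator = ModuleGenerator("Electra")()  # resolves src.modules.ElectraGenerator by name
module = generator(repeat=1)              # builds the module (downloads pretrained weights)
print(module.type, f"{module.n_params:,d} params")
```
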
/automl/src/modules/bert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 | from transformers import AutoModel
5 | from src.modules.base_generator import GeneratorAbstract
6 |
7 |
8 | class BaseModel(nn.Module):
9 | """
10 | Base class for all models
11 | """
12 |
13 | @abstractmethod
14 | def forward(self, *inputs):
15 | """
16 | Forward pass logic
17 |
18 | :return: Model output
19 | """
20 | raise NotImplementedError
21 |
22 | def __str__(self):
23 | """
24 | Model prints with number of trainable parameters
25 | """
26 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
27 | params = sum([np.prod(p.size()) for p in model_parameters])
28 | return super().__str__() + "\nTrainable parameters: {}".format(params)
29 |
30 |
31 | class Bert(BaseModel):
32 | def __init__(self, name="klue/bert-base"):
33 | super().__init__()
34 | # self.config = AutoConfig.from_pretrained(name)
35 | self.model = AutoModel.from_pretrained(name)
36 |
37 | def forward(self, inputs):
38 | return self.model(**inputs)[0]
39 |
40 |
41 | class BertGenerator(GeneratorAbstract):
42 | """Pretrained Bert block generator."""
43 |
44 | def __init__(self, *args, **kwargs):
45 | super().__init__(*args, **kwargs)
46 |
47 | @property
48 | def base_module(self) -> nn.Module:
49 | """Returns module class from src.common_modules based on the class name."""
50 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
51 |
52 | def __call__(self, repeat: int = 1):
53 | """call method.
54 | Build Bert Model
55 | """
56 | module = []
57 | args = self.args
58 | for i in range(repeat):
59 | module.append(self.base_module(*args))
60 | return self._get_module(module)
61 |
--------------------------------------------------------------------------------
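
A hypothetical smoke test for the `Bert` wrapper above; it downloads `klue/bert-base`, and using the tokenizer of the same name is an assumption:

```python
from transformers import AutoTokenizer

from src.modules.bert import Bert

tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")
encoder = Bert()                                 # wraps AutoModel.from_pretrained("klue/bert-base")
inputs = tokenizer(["예시 문장입니다."], return_tensors="pt")
hidden = encoder(inputs)                         # last_hidden_state of the backbone
print(hidden.shape)                              # [1, seq_len, 768]
```
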
/automl/src/modules/conv.py:
--------------------------------------------------------------------------------
1 | """Conv module, generator.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 | # pylint: disable=useless-super-delegation
7 | from typing import Union
8 |
9 | import torch
10 | from torch import nn as nn
11 |
12 | from src.modules.base_generator import GeneratorAbstract
13 | from src.utils.torch_utils import Activation, autopad
14 |
15 |
16 | class Conv(nn.Module):
17 | """Standard convolution with batch normalization and activation."""
18 |
19 | def __init__(
20 | self,
21 | in_channel: int,
22 | out_channels: int,
23 | kernel_size: int,
24 | stride: int = 1,
25 | padding: Union[int, None] = None,
26 | groups: int = 1,
27 | activation: Union[str, None] = "ReLU",
28 | ) -> None:
29 | """Standard convolution with batch normalization and activation.
30 |
31 | Args:
32 | in_channel: input channels.
33 | out_channels: output channels.
34 | kernel_size: kernel size.
35 | stride: stride.
36 | padding: input padding. If None is given, autopad is applied
37 | which is identical to padding='SAME' in TensorFlow.
38 | groups: group convolution.
39 | activation: activation name. If None is given, nn.Identity is applied
40 | which is no activation.
41 | """
42 | super().__init__()
43 | # error: Argument "padding" to "Conv2d" has incompatible type "Union[int, List[int]]";
44 | # expected "Union[int, Tuple[int, int]]"
45 | self.conv = nn.Conv2d(
46 | in_channel,
47 | out_channels,
48 | kernel_size,
49 | stride,
50 | padding=autopad(kernel_size, padding), # type: ignore
51 | groups=groups,
52 | bias=False,
53 | )
54 | self.bn = nn.BatchNorm2d(out_channels)
55 | self.act = Activation(activation)()
56 |
57 | def forward(self, x: torch.Tensor) -> torch.Tensor:
58 | """Forward."""
59 | return self.act(self.bn(self.conv(x)))
60 |
61 | def fusefoward(self, x: torch.Tensor) -> torch.Tensor:
62 | """Fuse forward."""
63 | return self.act(self.conv(x))
64 |
65 |
66 | class ConvGenerator(GeneratorAbstract):
67 | """Conv2d generator for parsing module."""
68 |
69 | def __init__(self, *args, **kwargs):
70 | super().__init__(*args, **kwargs)
71 |
72 | @property
73 | def out_channel(self) -> int:
74 | """Get out channel size."""
75 | return self._get_divisible_channel(self.args[0] * self.width_multiply)
76 |
77 | @property
78 | def base_module(self) -> nn.Module:
79 | """Returns module class from src.common_modules based on the class name."""
80 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
81 |
82 | def __call__(self, repeat: int = 1):
83 | args = [self.in_channel, self.out_channel, *self.args[1:]]
84 | if repeat > 1:
85 | stride = 1
86 | # Important!: stride only applies at the end of the repeat.
87 | if len(args) > 2:
88 | stride = args[3]
89 | args[3] = 1
90 |
91 | module = []
92 | for i in range(repeat):
93 | if len(args) > 1 and stride > 1 and i == repeat - 1:
94 | args[3] = stride
95 |
96 | module.append(self.base_module(*args))
97 | args[0] = self.out_channel
98 | else:
99 | module = self.base_module(*args)
100 |
101 | return self._get_module(module)
102 |
103 |
104 | class FixedConvGenerator(GeneratorAbstract):
105 | """FixedConv2d generator for parsing module.
106 | Fixed Conv doesn't change out channel
107 | """
108 |
109 | def __init__(self, *args, **kwargs):
110 | super().__init__(*args, **kwargs)
111 |
112 | @property
113 | def out_channel(self) -> int:
114 | """Get out channel size."""
115 | return int(self.args[0])
116 |
117 | @property
118 | def base_module(self) -> nn.Module:
119 | """Returns module class from src.common_modules based on the class name."""
120 | return getattr(
121 | __import__("src.modules", fromlist=[""]), self.name.replace("Fixed", "")
122 | )
123 |
124 | def __call__(self, repeat: int = 1):
125 | args = [self.in_channel, self.out_channel, *self.args[1:]]
126 | if repeat > 1:
127 | stride = 1
128 | # Important!: stride only applies at the end of the repeat.
129 | if len(args) > 2:
130 | stride = args[3]
131 | args[3] = 1
132 |
133 | module = []
134 | for i in range(repeat):
135 | if len(args) > 1 and stride > 1 and i == repeat - 1:
136 | args[3] = stride
137 |
138 | module.append(self.base_module(*args))
139 | args[0] = self.out_channel
140 | else:
141 | module = self.base_module(*args)
142 |
143 | return self._get_module(module)
144 |
--------------------------------------------------------------------------------
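
A shape-check sketch for the `Conv` block (illustrative sizes; assumes `Activation` in `src.utils.torch_utils` resolves the `"HardSwish"` name, as the mobilenetv3 config suggests):

```python
import torch

from src.modules.conv import Conv

conv = Conv(in_channel=3, out_channels=16, kernel_size=3, stride=2, activation="HardSwish")
x = torch.randn(1, 3, 32, 32)
print(conv(x).shape)  # torch.Size([1, 16, 16, 16]); autopad gives 'SAME'-style padding
```
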
/automl/src/modules/dwconv.py:
--------------------------------------------------------------------------------
1 | """DWConv module, generator.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 | import math
7 | # pylint: disable=useless-super-delegation
8 | from typing import Union
9 |
10 | import torch
11 | from torch import nn as nn
12 |
13 | from src.modules.base_generator import GeneratorAbstract
14 | from src.utils.torch_utils import Activation, autopad
15 |
16 |
17 | class DWConv(nn.Module):
18 | """Depthwise convolution with batch normalization and activation."""
19 |
20 | def __init__(
21 | self,
22 | in_channel: int,
23 | out_channels: int,
24 | kernel_size: int,
25 | stride: int = 1,
26 | padding: Union[int, None] = None,
27 | activation: Union[str, None] = "ReLU",
28 | ) -> None:
29 | """Depthwise convolution with batch normalization and activation.
30 |
31 | Args:
32 | in_channel: input channels.
33 | out_channels: output channels.
34 | kernel_size: kernel size.
35 | stride: stride.
36 | padding: input padding. If None is given, autopad is applied
37 | which is identical to padding='SAME' in TensorFlow.
38 | activation: activation name. If None is given, nn.Identity is applied
39 | which is no activation.
40 | """
41 | super().__init__()
42 | # error: Argument "padding" to "Conv2d" has incompatible type "Union[int, List[int]]";
43 | # expected "Union[int, Tuple[int, int]]"
44 | self.conv = nn.Conv2d(
45 | in_channel,
46 | out_channels,
47 | kernel_size,
48 | stride,
49 | padding=autopad(kernel_size, padding), # type: ignore
50 | groups=math.gcd(in_channel, out_channels),
51 | bias=False,
52 | )
53 | self.bn = nn.BatchNorm2d(out_channels)
54 | self.act = Activation(activation)()
55 |
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | """Forward."""
58 | return self.act(self.bn(self.conv(x)))
59 |
60 | def fusefoward(self, x: torch.Tensor) -> torch.Tensor:
61 | """Fuse forward."""
62 | return self.act(self.conv(x))
63 |
64 |
65 | class DWConvGenerator(GeneratorAbstract):
66 | """Depth-wise convolution generator for parsing module."""
67 |
68 | def __init__(self, *args, **kwargs):
69 | super().__init__(*args, **kwargs)
70 |
71 | @property
72 | def out_channel(self) -> int:
73 | """Get out channel size."""
74 | return self._get_divisible_channel(self.args[0] * self.width_multiply)
75 |
76 | @property
77 | def base_module(self) -> nn.Module:
78 | """Returns module class from src.common_modules based on the class name."""
79 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
80 |
81 | def __call__(self, repeat: int = 1):
82 | args = [self.in_channel, self.out_channel, *self.args[1:]]
83 | if repeat > 1:
84 | stride = 1
85 | # Important!: stride only applies at the end of the repeat.
86 | if len(args) > 2:
87 | stride = args[3]
88 | args[3] = 1
89 |
90 | module = []
91 | for i in range(repeat):
92 | if len(args) > 1 and stride > 1 and i == repeat - 1:
93 | args[3] = stride
94 |
95 | module.append(self.base_module(*args))
96 | args[0] = self.out_channel
97 | else:
98 | module = self.base_module(*args)
99 |
100 | return self._get_module(module)
101 |
--------------------------------------------------------------------------------
/automl/src/modules/electra.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from abc import abstractmethod
5 | from transformers import AutoConfig, AutoModel
6 | from src.modules.base_generator import GeneratorAbstract
7 |
8 |
9 | class BaseModel(nn.Module):
10 | """
11 | Base class for all models
12 | """
13 |
14 | @abstractmethod
15 | def forward(self, *inputs):
16 | """
17 | Forward pass logic
18 |
19 | :return: Model output
20 | """
21 | raise NotImplementedError
22 |
23 | def __str__(self):
24 | """
25 | Model prints with number of trainable parameters
26 | """
27 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
28 | params = sum([np.prod(p.size()) for p in model_parameters])
29 | return super().__str__() + "\nTrainable parameters: {}".format(params)
30 |
31 |
32 | class Electra(BaseModel):
33 | def __init__(self, name="beomi/beep-KcELECTRA-base-hate"):
34 | super().__init__()
35 | self.model = AutoModel.from_pretrained(name)
36 |
37 | def forward(self, inputs):
38 | with torch.no_grad():
39 | outputs = self.model(**inputs)
40 | return outputs[0]
41 |
42 |
43 | class ElectraGenerator(GeneratorAbstract):
44 | """Pretrained Electra block generator."""
45 |
46 | def __init__(self, *args, **kwargs):
47 | super().__init__(*args, **kwargs)
48 |
49 | @property
50 | def base_module(self) -> nn.Module:
51 | """Returns module class from src.common_modules based on the class name."""
52 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
53 |
54 | def __call__(self, repeat: int = 1):
55 | """call method.
56 | Build module
57 | """
58 | module = []
59 | args = self.args
60 | for i in range(repeat):
61 | module.append(self.base_module(*args))
62 | return self._get_module(module)
63 |
--------------------------------------------------------------------------------
/automl/src/modules/flatten.py:
--------------------------------------------------------------------------------
1 | """Flatten module, generator.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 | from torch import nn as nn
7 |
8 | from src.modules.base_generator import GeneratorAbstract
9 |
10 |
11 | class FlattenGenerator(GeneratorAbstract):
12 | """Flatten module generator."""
13 |
14 | def __init__(self, *args, **kwargs):
15 | super().__init__(*args, **kwargs)
16 |
17 | @property
18 | def out_channel(self) -> int:
19 | return self.in_channel
20 |
21 | def __call__(self, repeat: int = 1):
22 | return self._get_module(nn.Flatten())
23 |
--------------------------------------------------------------------------------
/automl/src/modules/linear.py:
--------------------------------------------------------------------------------
1 | """Linear module, generator.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 | from typing import Union
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 | from src.modules.base_generator import GeneratorAbstract
12 | from src.utils.torch_utils import Activation
13 |
14 |
15 | class Linear(nn.Module):
16 | """Linear module."""
17 |
18 | def __init__(self, in_channel: int, out_channel: int, activation: Union[str, None]):
19 | """
20 |
21 | Args:
22 | in_channel: input channels.
23 | out_channel: output channels.
24 | activation: activation name. If None is given, nn.Identity is applied
25 | which is no activation.
26 | """
27 | super().__init__()
28 | self.linear = nn.Linear(in_channel, out_channel)
29 | self.activation = Activation(activation)()
30 |
31 | def forward(self, x: torch.Tensor) -> torch.Tensor:
32 | """Forward."""
33 | return self.activation(self.linear(x))
34 |
35 |
36 | class LinearGenerator(GeneratorAbstract):
37 | """Linear (fully connected) module generator for parsing."""
38 |
39 | def __init__(self, *args, **kwargs):
40 | """Initailize."""
41 | super().__init__(*args, **kwargs)
42 |
43 | @property
44 | def out_channel(self) -> int:
45 | """Get out channel size."""
46 | return self.args[0]
47 |
48 | def __call__(self, repeat: int = 1):
49 | # TODO: Apply repeat
50 | act = self.args[1] if len(self.args) > 1 else None
51 |
52 | return self._get_module(
53 | Linear(self.in_channel, self.out_channel, activation=act)
54 | )
55 |
--------------------------------------------------------------------------------
/automl/src/modules/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from abc import abstractmethod
5 | from src.modules.base_generator import GeneratorAbstract
6 |
7 |
8 | class BaseModel(nn.Module):
9 | """
10 | Base class for all models
11 | """
12 |
13 | @abstractmethod
14 | def forward(self, *inputs):
15 | """
16 | Forward pass logic
17 |
18 | :return: Model output
19 | """
20 | raise NotImplementedError
21 |
22 | def __str__(self):
23 | """
24 | Model prints with number of trainable parameters
25 | """
26 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
27 | params = sum([np.prod(p.size()) for p in model_parameters])
28 | return super().__str__() + "\nTrainable parameters: {}".format(params)
29 |
30 |
31 | class Lstm(BaseModel):
32 | def __init__(self, name="rnn", xdim=28, hdim=256, ydim=3, n_layer=3, dropout=0):
33 | super(Lstm, self).__init__()
34 | self.name = name
35 | self.xdim = xdim
36 | self.hdim = hdim
37 | self.ydim = ydim
38 | self.n_layer = n_layer # K
39 | self.dropout = dropout
40 |
41 | self.rnn = nn.LSTM(
42 | input_size=xdim,
43 | hidden_size=hdim,
44 | num_layers=n_layer,
45 | batch_first=True,
46 | dropout=dropout,
47 | )
48 | self.lin = nn.Linear(self.hdim, self.ydim)
49 |
50 | def forward(self, x):
51 | # Set initial hidden and cell states
52 | device = (
53 | torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
54 | )
55 | # print(device)
56 | h0 = torch.zeros(self.n_layer, x.size(0), self.hdim).to(device)
57 | c0 = torch.zeros(self.n_layer, x.size(0), self.hdim).to(device)
58 |
59 | # RNN
60 | rnn_out, (hn, cn) = self.rnn(x, (h0, c0))
61 | # x:[N x L x Q] => rnn_out:[N x L x D]
62 | # Linear
63 | out = self.lin(rnn_out[:, -1, :]).view([-1, self.ydim])
64 | return out
65 |
66 |
67 | class LstmGenerator(GeneratorAbstract):
68 | """Pretrained Bert block generator."""
69 |
70 | def __init__(self, *args, **kwargs):
71 | super().__init__(*args, **kwargs)
72 |
73 | @property
74 | def base_module(self) -> nn.Module:
75 | """Returns module class from src.common_modules based on the class name."""
76 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
77 |
78 | def __call__(self, repeat: int = 1):
79 | """call method."""
80 | module = []
81 | args = self.args
82 | for i in range(repeat):
83 | module.append(self.base_module(*args))
84 | return self._get_module(module)
85 |
--------------------------------------------------------------------------------
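
A shape sketch for the `Lstm` head (illustrative dimensions, e.g. mapping a 768-dim encoder output to 3 classes); the module and input are moved to the same device that `forward` selects internally for its hidden states:

```python
import torch

from src.modules.lstm import Lstm

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # matches the device choice inside forward()
head = Lstm(xdim=768, hdim=256, ydim=3, n_layer=3).to(device)
x = torch.randn(2, 40, 768, device=device)                 # [batch, seq_len, feature]
print(head(x).shape)                                       # torch.Size([2, 3])
```
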
/automl/src/modules/mbconv.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from src.modules.base_generator import GeneratorAbstract
7 |
8 |
9 | class MBConv(nn.Module):
10 | """MBConvBlock used in Efficientnet.
11 |
12 | Reference:
13 | https://github.com/narumiruna/efficientnet-pytorch/blob/master/efficientnet/models/efficientnet.py
14 | Note:
15 | Drop connect rate is disabled.
16 | """
17 |
18 | def __init__(
19 | self,
20 | in_planes,
21 | out_planes,
22 | expand_ratio,
23 | kernel_size,
24 | stride,
25 | reduction_ratio=4,
26 | drop_connect_rate=0.0,
27 | ):
28 | super(MBConv, self).__init__()
29 | self.drop_connect_rate = drop_connect_rate
30 | self.use_residual = in_planes == out_planes and stride == 1
31 | assert stride in [1, 2]
32 | assert kernel_size in [3, 5]
33 |
34 | hidden_dim = in_planes * expand_ratio
35 | reduced_dim = max(1, in_planes // reduction_ratio)
36 |
37 | layers = []
38 | # pw
39 | if in_planes != hidden_dim:
40 | layers.append(ConvBNReLU(in_planes, hidden_dim, 1))
41 |
42 | layers.extend(
43 | [
44 | # dw
45 | ConvBNReLU(
46 | hidden_dim,
47 | hidden_dim,
48 | kernel_size,
49 | stride=stride,
50 | groups=hidden_dim,
51 | ),
52 | # se
53 | SqueezeExcitation(hidden_dim, reduced_dim),
54 | # pw-linear
55 | nn.Conv2d(hidden_dim, out_planes, 1, bias=False),
56 | nn.BatchNorm2d(out_planes),
57 | ]
58 | )
59 | self.conv = nn.Sequential(*layers)
60 |
61 | def _drop_connect(self, x):
62 | if not self.training:
63 | return x
64 | if self.drop_connect_rate >= 1.0:
65 | return x
66 | keep_prob = 1.0 - self.drop_connect_rate
67 | batch_size = x.size(0)
68 | random_tensor = keep_prob
69 | random_tensor += torch.rand(batch_size, 1, 1, 1, device=x.device)
70 | binary_tensor = random_tensor.floor()
71 | return x.div(keep_prob) * binary_tensor
72 |
73 | def forward(self, x):
74 | if self.use_residual:
75 | return x + self._drop_connect(self.conv(x))
76 | else:
77 | return self.conv(x)
78 |
79 |
80 | class ConvBNReLU(nn.Sequential):
81 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
82 | padding = self._get_padding(kernel_size, stride)
83 | super(ConvBNReLU, self).__init__(
84 | nn.ZeroPad2d(padding),
85 | nn.Conv2d(
86 | in_planes,
87 | out_planes,
88 | kernel_size,
89 | stride,
90 | padding=0,
91 | groups=groups,
92 | bias=False,
93 | ),
94 | nn.BatchNorm2d(out_planes),
95 | Swish(),
96 | )
97 |
98 | def _get_padding(self, kernel_size, stride):
99 | p = max(kernel_size - stride, 0)
100 | return [p // 2, p - p // 2, p // 2, p - p // 2]
101 |
102 |
103 | class SwishImplementation(torch.autograd.Function):
104 | @staticmethod
105 | def forward(ctx, i):
106 | result = i * torch.sigmoid(i)
107 | ctx.save_for_backward(i)
108 | return result
109 |
110 | @staticmethod
111 | def backward(ctx, grad_output):
112 |         i = ctx.saved_tensors[0]
113 | sigmoid_i = torch.sigmoid(i)
114 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
115 |
116 |
117 | class Swish(nn.Module):
118 | def forward(self, x):
119 | return SwishImplementation.apply(x)
120 |
121 |
122 | def _round_repeats(repeats, depth_mult):
123 | if depth_mult == 1.0:
124 | return repeats
125 | return int(math.ceil(depth_mult * repeats))
126 |
127 |
128 | class SqueezeExcitation(nn.Module):
129 | """Squeeze-Excitation layer used in MBConv."""
130 |
131 | def __init__(self, in_planes, reduced_dim):
132 | super(SqueezeExcitation, self).__init__()
133 | self.se = nn.Sequential(
134 | nn.AdaptiveAvgPool2d(1),
135 | nn.Conv2d(in_planes, reduced_dim, 1),
136 | Swish(),
137 | nn.Conv2d(reduced_dim, in_planes, 1),
138 | nn.Sigmoid(),
139 | )
140 |
141 | def forward(self, x):
142 | return x * self.se(x)
143 |
144 |
145 | class MBConvGenerator(GeneratorAbstract):
146 | """Bottleneck block generator."""
147 |
148 | def __init__(self, *args, **kwargs):
149 | super().__init__(*args, **kwargs)
150 |
151 | @property
152 | def out_channel(self) -> int:
153 | """Get out channel size."""
154 | return self._get_divisible_channel(self.args[0] * self.width_multiply)
155 |
156 | @property
157 | def base_module(self) -> nn.Module:
158 |         """Returns module class from src.modules based on the class name."""
159 | return getattr(__import__("src.modules", fromlist=[""]), self.name)
160 |
161 | def __call__(self, repeat: int = 1):
162 | """call method.
163 |
164 |         MBConv args consist of
165 |         repeat(=n), [t, c, s, k]  // original notation in the paper is [t, c, n, s]
166 | """
167 | module = []
168 | t, c, s, k = self.args # c is equivalent as self.out_channel
169 | inp, oup = self.in_channel, self.out_channel
170 | for i in range(repeat):
171 | stride = s if i == 0 else 1
172 | module.append(
173 | self.base_module(
174 | in_planes=inp,
175 | out_planes=oup,
176 | expand_ratio=t,
177 | stride=stride,
178 | kernel_size=k,
179 | )
180 | )
181 | inp = oup
182 | return self._get_module(module)
183 |
--------------------------------------------------------------------------------
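A short usage sketch of the MBConv block defined above (import path assumed from this repo's layout): with stride 1 and matching channels the residual branch is taken, otherwise the block simply projects and downsamples.

import torch
from src.modules.mbconv import MBConv

block = MBConv(in_planes=16, out_planes=16, expand_ratio=6, kernel_size=3, stride=1)
x = torch.randn(2, 16, 32, 32)
print(block(x).shape)    # torch.Size([2, 16, 32, 32]); residual path is used

down = MBConv(in_planes=16, out_planes=24, expand_ratio=6, kernel_size=5, stride=2)
print(down(x).shape)     # torch.Size([2, 24, 16, 16]); no residual (channels/stride differ)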
/automl/src/modules/poolings.py:
--------------------------------------------------------------------------------
1 | """Module generator related to pooling operations.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 | # pylint: disable=useless-super-delegation
7 | from torch import nn
8 |
9 | from src.modules.base_generator import GeneratorAbstract
10 |
11 |
12 | class MaxPoolGenerator(GeneratorAbstract):
13 | """Max pooling module generator."""
14 |
15 | def __init__(self, *args, **kwargs):
16 | super().__init__(*args, **kwargs)
17 |
18 | @property
19 | def out_channel(self) -> int:
20 | """Get out channel size."""
21 |         # pooling keeps the channel dimension unchanged
22 | return self.in_channel
23 |
24 | @property
25 | def base_module(self) -> nn.Module:
26 | """Base module."""
27 | return getattr(nn, f"{self.name}2d")
28 |
29 | def __call__(self, repeat: int = 1):
30 | module = (
31 | [self.base_module(*self.args) for _ in range(repeat)]
32 | if repeat > 1
33 | else self.base_module(*self.args)
34 | )
35 | return self._get_module(module)
36 |
37 |
38 | class AvgPoolGenerator(MaxPoolGenerator):
39 | """Average pooling module generator."""
40 |
41 | def __init__(self, *args, **kwargs):
42 | super().__init__(*args, **kwargs)
43 |
44 |
45 | class GlobalAvgPool(nn.AdaptiveAvgPool2d):
46 | """Global average pooling module."""
47 |
48 | def __init__(self, output_size=1):
49 | """Initialize."""
50 | super().__init__(output_size=output_size)
51 |
52 |
53 | class GlobalAvgPoolGenerator(GeneratorAbstract):
54 | """Global average pooling module generator."""
55 |
56 | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
57 | super().__init__(*args, **kwargs)
58 | self.output_size = 1
59 | if len(args) > 1:
60 | self.output_size = args[1]
61 |
62 | @property
63 | def out_channel(self) -> int:
64 | """Get out channel size."""
65 | return self.in_channel
66 |
67 | def __call__(self, repeat: int = 1):
68 | return self._get_module(GlobalAvgPool(self.output_size))
69 |
--------------------------------------------------------------------------------
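GlobalAvgPool above is just nn.AdaptiveAvgPool2d(1); a tiny sketch of what the generator ultimately wires in (import path assumed):

import torch
from src.modules.poolings import GlobalAvgPool

pool = GlobalAvgPool()
x = torch.randn(8, 128, 7, 7)
print(pool(x).shape)     # torch.Size([8, 128, 1, 1]); spatial dims collapse to 1x1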
/automl/src/utils/common.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Union
2 |
3 | import numpy as np
4 | import yaml
5 | from torchvision.datasets import ImageFolder, VisionDataset
6 | import os
7 | import json
8 | from pathlib import Path
9 |
10 |
11 | def read_yaml(cfg: Union[str, Dict[str, Any]]):
12 | if not isinstance(cfg, dict):
13 | with open(cfg) as f:
14 | config = yaml.load(f, Loader=yaml.FullLoader)
15 | else:
16 | config = cfg
17 | return config
18 |
19 |
20 | def write_yaml(cfg: Union[str, Dict[str, Any]], name, path=""):
21 | if isinstance(cfg, dict):
22 | if not os.path.exists(path):
23 | os.mkdir(path)
24 | with open(os.path.join(path, name + ".yaml"), "w") as f:
25 | yaml.dump(cfg, f)
26 |     else:
27 |         raise ValueError("write_yaml expects `cfg` to be a dict")
28 |
29 |
30 | def get_label_counts(dataset_path: str):
31 | """Counts for each label."""
32 | if not dataset_path:
33 | return None
34 | td = ImageFolder(root=dataset_path)
35 | # get label distribution
36 | label_counts = [0] * len(td.classes)
37 | for p, l in td.samples:
38 | label_counts[l] += 1
39 | return label_counts
40 |
41 |
42 | def write_json(content, fname):
43 | fname = Path(fname)
44 | with fname.open("wt") as handle:
45 | json.dump(content, handle, indent=4, sort_keys=False)
46 |
--------------------------------------------------------------------------------
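A minimal round trip with the helpers above (the directory and file names are illustrative, not used elsewhere in the repo):

from src.utils.common import read_yaml, write_yaml, write_json

cfg = {"EPOCHS": 2, "BATCH_SIZE": 32}
write_yaml(cfg, name="data_config", path="tmp_cfg")     # -> tmp_cfg/data_config.yaml
loaded = read_yaml("tmp_cfg/data_config.yaml")
assert loaded == cfg

write_json(loaded, "tmp_cfg/data_config.json")          # pretty-printed JSON copy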
/automl/src/utils/data.py:
--------------------------------------------------------------------------------
1 | """Utils for model compression.
2 |
3 | - Author: wlaud1001
4 | - Email: wlaud1001@snu.ac.kr
5 | - Reference:
6 | https://github.com/j-marple-dev/model_compression
7 | """
8 |
9 | import random
10 | from multiprocessing import Pool
11 | from typing import Tuple
12 |
13 |
14 | def get_rand_bbox_coord(
15 | w: int, h: int, len_ratio: float
16 | ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
17 | """Get a coordinate of random box."""
18 | size_hole_w = int(len_ratio * w)
19 | size_hole_h = int(len_ratio * h)
20 | x = random.randint(0, w) # [0, w]
21 | y = random.randint(0, h) # [0, h]
22 |
23 | x0 = max(0, x - size_hole_w // 2)
24 | y0 = max(0, y - size_hole_h // 2)
25 | x1 = min(w, x + size_hole_w // 2)
26 | y1 = min(h, y + size_hole_h // 2)
27 | return (x0, y0), (x1, y1)
28 |
29 | def weights_for_balanced_classes(subset, nclasses):
30 | count = [0] * nclasses
31 | for i in subset:
32 | count[i[1]] += 1
33 | weight_per_class = [0.] * nclasses
34 | N = float(sum(count))
35 | for i in range(nclasses):
36 | weight_per_class[i] = N/float(count[i])
37 |     weight = [0] * len(subset)
38 |     for idx, val in enumerate(subset):
39 |         weight[idx] = weight_per_class[val[1]]
40 |     return weight
--------------------------------------------------------------------------------
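The balanced-class weights above are meant to feed a WeightedRandomSampler; a minimal sketch with a toy imbalanced dataset (the TensorDataset here is purely illustrative):

import torch
from torch.utils.data import TensorDataset, WeightedRandomSampler, DataLoader
from src.utils.data import weights_for_balanced_classes, get_rand_bbox_coord

features = torch.randn(10, 4)
labels = torch.tensor([0] * 8 + [1] * 2)          # imbalanced two-class toy set
dataset = TensorDataset(features, labels)

weights = weights_for_balanced_classes(dataset, nclasses=2)
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

(x0, y0), (x1, y1) = get_rand_bbox_coord(w=32, h=32, len_ratio=0.5)   # random cutout box corners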
/automl/src/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | """Common utility functions.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 |
7 | import math
8 | import os
9 | from typing import List, Optional, Tuple, Union
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 | from torch.utils.data import Subset
15 | import random
16 | from .common import write_yaml
17 |
18 |
19 | def convert_model_to_torchscript(
20 | model: nn.Module, path: Optional[str] = None
21 | ) -> torch.jit.ScriptModule:
22 | """Convert PyTorch Module to TorchScript.
23 |
24 | Args:
25 | model: PyTorch Module.
26 |
27 | Return:
28 | TorchScript module.
29 | """
30 | model.eval()
31 | jit_model = torch.jit.script(model)
32 |
33 | if path:
34 | jit_model.save(path)
35 |
36 | return jit_model
37 |
38 |
39 | def split_dataset_index(
40 | train_dataset: torch.utils.data.Dataset, n_data: int, split_ratio: float = 0.1
41 | ) -> Tuple[Subset, Subset]:
42 | """Split dataset indices with split_ratio.
43 |
44 | Args:
45 | n_data: number of total data
46 | split_ratio: split ratio (0.0 ~ 1.0)
47 |
48 | Returns:
49 |         train Subset ({split_ratio} ~ 1.0 of the indices)
50 |         valid Subset (0.0 ~ {split_ratio} of the indices)
51 | """
52 | indices = np.arange(n_data)
53 | split = int(split_ratio * indices.shape[0])
54 |
55 | train_idx = indices[split:]
56 | valid_idx = indices[:split]
57 |
58 | train_subset = Subset(train_dataset, train_idx)
59 | valid_subset = Subset(train_dataset, valid_idx)
60 |
61 | return train_subset, valid_subset
62 |
63 |
64 | def save_model(model, path, data, device):
65 |     """Save model state_dict and a TorchScript copy."""
66 | try:
67 | torch.save(model.state_dict(), f=path)
68 | ts_path = os.path.splitext(path)[:-1][0] + ".ts"
69 | convert_model_to_torchscript(model, ts_path)
70 | except Exception:
71 | print("Failed to save torch")
72 |
73 |
74 | def save_model2(model, path, data, device, model_config):
75 |     """Save model state_dict and its model config yaml."""
76 | try:
77 | if not os.path.exists(path):
78 | os.mkdir(path)
79 | torch.save(model.state_dict(), f=os.path.join(path, "result_model.pt"))
80 | write_yaml(model_config, "model_config", path=path)
81 | except Exception:
82 | print("Failed to save torch")
83 |
84 |
85 | def model_info(model, verbose=False):
86 | """Print out model info."""
87 | n_p = sum(x.numel() for x in model.parameters()) # number parameters
88 | n_g = sum(
89 | x.numel() for x in model.parameters() if x.requires_grad
90 | ) # number gradients
91 | if verbose:
92 | print(
93 | "%5s %40s %9s %12s %20s %10s %10s"
94 | % ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
95 | )
96 | for i, (name, p) in enumerate(model.named_parameters()):
97 | name = name.replace("module_list.", "")
98 | print(
99 | "%5g %40s %9s %12g %20s %10.3g %10.3g"
100 | % (
101 | i,
102 | name,
103 | p.requires_grad,
104 | p.numel(),
105 | list(p.shape),
106 | p.mean(),
107 | p.std(),
108 | )
109 | )
110 |
111 | print(
112 | f"Model Summary: {len(list(model.modules()))} layers, "
113 | f"{n_p:,d} parameters, {n_g:,d} gradients"
114 | )
115 |
116 |
117 | @torch.no_grad()
118 | def check_runtime(
119 |     model: nn.Module, word_length: int, device: torch.device, repeat: int = 100
120 | ) -> float:
121 | # test part
122 | # device = "cpu"
123 | # model.to(device)
124 | # test part
125 |
126 | repeat = min(repeat, 20)
127 | inputs = {
128 | "input_ids": torch.randint(0, 30000, [1, word_length]).to(device),
129 | "token_type_ids": torch.randint(0, 1, [1, word_length]).to(device),
130 | "attention_mask": torch.randint(0, 1, [1, word_length]).to(device),
131 | }
132 | measure = []
133 | start = torch.cuda.Event(enable_timing=True)
134 | end = torch.cuda.Event(enable_timing=True)
135 |
136 | model.eval()
137 | for _ in range(repeat):
138 | start.record()
139 | _ = model(inputs)
140 | end.record()
141 | # Waits for everything to finish running
142 | torch.cuda.synchronize()
143 | measure.append(start.elapsed_time(end))
144 |
145 | measure.sort()
146 | n = len(measure)
147 | k = int(round(n * (0.2) / 2))
148 | trimmed_measure = measure[k + 1 : n - k]
149 |
150 | with torch.autograd.profiler.profile(use_cuda=True) as prof:
151 | _ = model(inputs)
152 | print(prof)
153 | print("measured time(ms)", np.mean(trimmed_measure))
154 | model.train()
155 | return np.mean(trimmed_measure)
156 |
157 |
158 | def make_divisible(v: float, divisor: int = 8, min_value: Optional[int] = None) -> int:
159 | """
160 | This function is taken from the original tf repo.
161 | It ensures that all layers have a channel number that is divisible by 8
162 | It can be seen here:
163 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
164 | """
165 | if min_value is None:
166 | min_value = divisor
167 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
168 | # Make sure that round down does not go down by more than 10%.
169 | if new_v < 0.9 * v:
170 | new_v += divisor
171 | return new_v
172 |
173 |
174 | def autopad(
175 | kernel_size: Union[int, List[int]], padding: Union[int, None] = None
176 | ) -> Union[int, List[int]]:
177 | """Auto padding calculation for pad='same' in TensorFlow."""
178 | # Pad to 'same'
179 | if isinstance(kernel_size, int):
180 | kernel_size = [kernel_size]
181 |
182 | return padding or [x // 2 for x in kernel_size]
183 |
184 |
185 | class Activation:
186 | """Convert string activation name to the activation class."""
187 |
188 | def __init__(self, act_type: Union[str, None]) -> None:
189 | """Convert string activation name to the activation class.
190 |
191 | Args:
192 | type: Activation name.
193 |
194 | Returns:
195 | nn.Identity if {type} is None.
196 | """
197 | self.type = act_type
198 | self.args = [1] if self.type == "Softmax" else []
199 |
200 | def __call__(self) -> nn.Module:
201 | if self.type is None:
202 | return nn.Identity()
203 | elif hasattr(nn, self.type):
204 | return getattr(nn, self.type)(*self.args)
205 | else:
206 | return getattr(
207 | __import__("src.modules.activations", fromlist=[""]), self.type
208 | )()
209 |
210 |
211 | if __name__ == "__main__":
212 |     # simple smoke test: replace None with a real nn.Module before running
213 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
214 |     check_runtime(None, 32, device)
--------------------------------------------------------------------------------
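Worked examples for the two small helpers at the end of the file above: make_divisible rounds a channel count to the nearest multiple of `divisor` without shrinking it by more than 10%, and autopad reproduces TensorFlow-style 'same' padding.

from src.utils.torch_utils import make_divisible, autopad

print(make_divisible(37))              # 40  (nearest multiple of 8)
print(make_divisible(62, divisor=8))   # 64
print(make_divisible(7))               # 8   (never below min_value, which defaults to divisor)

print(autopad(3))                      # [1]
print(autopad([3, 5]))                 # [1, 2]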
/automl/tests/test_model_conversion.py:
--------------------------------------------------------------------------------
1 | """Unit test for model conversion to TorchScript.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 |
7 |
8 | import os
9 |
10 | import torch
11 |
12 | from src.model import Model
13 | from src.utils.torch_utils import convert_model_to_torchscript
14 |
15 |
16 | class TestModelConversion:
17 | """Test model conversion."""
18 |
19 | # pylint: disable=no-self-use
20 |
21 | INPUT1 = torch.rand(1, 3, 128, 128)
22 | INPUT2 = torch.rand(8, 3, 256, 256)
23 | SAVE_PATH = "tests/.test_model.ts"
24 |
25 | def _convert_evaluation(self, path: str) -> None:
26 | """Model conversion test."""
27 | model = Model(path)
28 | convert_model_to_torchscript(model, path=TestModelConversion.SAVE_PATH)
29 |
30 | ts_model = torch.jit.load(TestModelConversion.SAVE_PATH)
31 |
32 | out_tensor1 = ts_model(TestModelConversion.INPUT1)
33 | out_tensor2 = ts_model(TestModelConversion.INPUT2)
34 |
35 | os.remove(TestModelConversion.SAVE_PATH)
36 | assert out_tensor1.shape == torch.Size((1, 9))
37 | assert out_tensor2.shape == torch.Size((8, 9))
38 |
39 | def test_mobilenetv3(self):
40 | """Test convert mobilenetv3 model to TorchScript."""
41 | self._convert_evaluation(os.path.join("configs", "model", "mobilenetv3.yaml"))
42 |
43 | def test_example(self):
44 | """Test convert example model to TorchScript."""
45 | self._convert_evaluation(os.path.join("configs", "model", "example.yaml"))
46 |
47 |
48 | if __name__ == "__main__":
49 | test = TestModelConversion()
50 | test.test_mobilenetv3()
51 | test.test_example()
52 |
--------------------------------------------------------------------------------
/automl/tests/test_model_parser.py:
--------------------------------------------------------------------------------
1 | """Model parse test.
2 |
3 | - Author: Jongkuk Lim
4 | - Contact: lim.jeikei@gmail.com
5 | """
6 |
7 | import os
8 |
9 | import torch
10 |
11 | from src.model import Model
12 |
13 |
14 | class TestModelParser:
15 | """Test model parser."""
16 |
17 | # pylint: disable=no-self-use
18 |
19 | INPUT = torch.rand(8, 3, 256, 256)
20 |
21 | def test_mobilenetv3(self):
22 | """Test mobilenetv3 model."""
23 | model = Model(os.path.join("configs", "model", "mobilenetv3.yaml"))
24 | assert model(TestModelParser.INPUT).shape == torch.Size([8, 9])
25 |
26 | def test_example(self):
27 | """Test example model."""
28 | model = Model(os.path.join("configs", "model", "example.yaml"))
29 | assert model(TestModelParser.INPUT).shape == torch.Size([8, 9])
30 |
31 |
32 | if __name__ == "__main__":
33 | test = TestModelParser()
34 |
35 | test.test_mobilenetv3()
36 | test.test_example()
37 |
--------------------------------------------------------------------------------
/automl/train.py:
--------------------------------------------------------------------------------
1 | """Baseline train
2 | - Author: Junghoon Kim
3 | - Contact: placidus36@gmail.com
4 | """
5 |
6 | import argparse
7 | import os
8 | from datetime import datetime
9 | from typing import Any, Dict, Tuple, Union
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | import yaml
15 |
16 | from src.dataloader import create_dataloader
17 | from src.loss import CustomCriterion
18 | from src.model import Model
19 | from src.trainer import TorchTrainer
20 | from src.utils.common import get_label_counts, read_yaml
21 | from src.utils.torch_utils import check_runtime, model_info
22 |
23 |
24 | def train(
25 | model_config: Dict[str, Any],
26 | data_config: Dict[str, Any],
27 | log_dir: str,
28 | fp16: bool,
29 | device: torch.device,
30 | ) -> Tuple[float, float, float]:
31 | """Train."""
32 | # save model_config, data_config
33 | with open(os.path.join(log_dir, "data.yml"), "w") as f:
34 | yaml.dump(data_config, f, default_flow_style=False)
35 | with open(os.path.join(log_dir, "model.yml"), "w") as f:
36 | yaml.dump(model_config, f, default_flow_style=False)
37 |
38 | model_instance = Model(model_config, verbose=True)
39 | model_path = os.path.join(log_dir, "best.pt")
40 | print(f"Model save path: {model_path}")
41 | if os.path.isfile(model_path):
42 | model_instance.model.load_state_dict(
43 | torch.load(model_path, map_location=device)
44 | )
45 | model_instance.model.to(device)
46 |
47 | # Create dataloader
48 | train_dl, val_dl, test_dl = create_dataloader(data_config)
49 |
50 | # Create optimizer, scheduler, criterion
51 | optimizer = torch.optim.SGD(
52 | model_instance.model.parameters(), lr=data_config["INIT_LR"], momentum=0.9
53 | )
54 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
55 | optimizer=optimizer,
56 | max_lr=data_config["INIT_LR"],
57 | steps_per_epoch=len(train_dl),
58 | epochs=data_config["EPOCHS"],
59 | pct_start=0.05,
60 | )
61 | criterion = CustomCriterion(
62 | samples_per_cls=get_label_counts(data_config["DATA_PATH"])
63 | if data_config["DATASET"] == "TACO"
64 | else None,
65 | device=device,
66 | )
67 | # Amp loss scaler
68 | scaler = (
69 | torch.cuda.amp.GradScaler() if fp16 and device != torch.device("cpu") else None
70 | )
71 |
72 | # Create trainer
73 | trainer = TorchTrainer(
74 | model=model_instance.model,
75 | criterion=criterion,
76 | optimizer=optimizer,
77 | scheduler=scheduler,
78 | scaler=scaler,
79 | device=device,
80 | model_path=model_path,
81 | verbose=1,
82 | )
83 | best_acc, best_f1 = trainer.train(
84 | train_dataloader=train_dl,
85 | n_epoch=data_config["EPOCHS"],
86 | val_dataloader=val_dl if val_dl else test_dl,
87 | )
88 |
89 | # evaluate model with test set
90 | model_instance.model.load_state_dict(torch.load(model_path))
91 | test_loss, test_f1, test_acc = trainer.test(
92 | model=model_instance.model, test_dataloader=val_dl if val_dl else test_dl
93 | )
94 | return test_loss, test_f1, test_acc
95 |
96 |
97 | if __name__ == "__main__":
98 | parser = argparse.ArgumentParser(description="Train model.")
99 | parser.add_argument(
100 | "--model",
101 | default="configs/model/mobilenetv3.yaml",
102 | type=str,
103 | help="model config",
104 | )
105 | parser.add_argument(
106 | "--data", default="configs/data/taco.yaml", type=str, help="data config"
107 | )
108 | args = parser.parse_args()
109 |
110 | model_config = read_yaml(cfg=args.model)
111 | data_config = read_yaml(cfg=args.data)
112 |
113 | data_config["DATA_PATH"] = os.environ.get("SM_CHANNEL_TRAIN", data_config["DATA_PATH"])
114 |
115 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
116 | log_dir = os.environ.get("SM_MODEL_DIR", os.path.join("exp", 'latest'))
117 |
118 | if os.path.exists(log_dir):
119 | modified = datetime.fromtimestamp(os.path.getmtime(log_dir + '/best.pt'))
120 | new_log_dir = os.path.dirname(log_dir) + '/' + modified.strftime("%Y-%m-%d_%H-%M-%S")
121 | os.rename(log_dir, new_log_dir)
122 |
123 | os.makedirs(log_dir, exist_ok=True)
124 |
125 | test_loss, test_f1, test_acc = train(
126 | model_config=model_config,
127 | data_config=data_config,
128 | log_dir=log_dir,
129 | fp16=data_config["FP16"],
130 | device=device,
131 | )
132 |
133 |
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 | """
9 | Base class for all data loaders
10 | """
11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
12 | self.validation_split = validation_split
13 | self.shuffle = shuffle
14 |
15 | self.batch_idx = 0
16 | self.n_samples = len(dataset)
17 |
18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
19 |
20 | self.init_kwargs = {
21 | 'dataset': dataset,
22 | 'batch_size': batch_size,
23 | 'shuffle': self.shuffle,
24 | 'collate_fn': collate_fn,
25 | 'num_workers': num_workers
26 | }
27 | super().__init__(sampler=self.sampler, **self.init_kwargs)
28 |
29 | def _split_sampler(self, split):
30 | if split == 0.0:
31 | return None, None
32 |
33 | idx_full = np.arange(self.n_samples)
34 |
35 | np.random.seed(0)
36 | np.random.shuffle(idx_full)
37 |
38 | if isinstance(split, int):
39 | assert split > 0
40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 | len_valid = split
42 | else:
43 | len_valid = int(self.n_samples * split)
44 |
45 | valid_idx = idx_full[0:len_valid]
46 | train_idx = np.delete(idx_full, np.arange(0, len_valid))
47 |
48 | train_sampler = SubsetRandomSampler(train_idx)
49 | valid_sampler = SubsetRandomSampler(valid_idx)
50 |
51 | # turn off shuffle option which is mutually exclusive with sampler
52 | self.shuffle = False
53 | self.n_samples = len(train_idx)
54 |
55 | return train_sampler, valid_sampler
56 |
57 | def split_validation(self):
58 | if self.valid_sampler is None:
59 | return None
60 | else:
61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
62 |
--------------------------------------------------------------------------------
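A minimal sketch of how BaseDataLoader splits one dataset into train/validation samplers (the TensorDataset here is just a stand-in):

import torch
from torch.utils.data import TensorDataset
from base.base_data_loader import BaseDataLoader

dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,)))
train_loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                              validation_split=0.1, num_workers=0)
valid_loader = train_loader.split_validation()
print(len(train_loader.sampler), len(valid_loader.sampler))   # 90 10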
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 | @abstractmethod
11 | def forward(self, *inputs):
12 | """
13 | Forward pass logic
14 |
15 | :return: Model output
16 | """
17 | raise NotImplementedError
18 |
19 | def __str__(self):
20 | """
21 | Model prints with number of trainable parameters
22 | """
23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 | params = sum([np.prod(p.size()) for p in model_parameters])
25 | return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | from abc import abstractmethod
5 | from numpy import inf
6 | from utils import write_json
7 |
8 |
9 | class BaseTrainer:
10 | """
11 | Base class for all trainers
12 | """
13 | def __init__(self, model, criterion, metric_ftns, optimizer, config):
14 | self.config = config
15 | self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
16 |
17 | self.model = model
18 | self.criterion = criterion
19 | self.metric_ftns = metric_ftns
20 | self.optimizer = optimizer
21 |
22 | cfg_trainer = config['trainer']
23 | self.epochs = cfg_trainer['epochs']
24 | self.save_steps = cfg_trainer['save']['steps']
25 | self.save_limits = cfg_trainer['save']['limits']
26 | self.monitor = cfg_trainer.get('monitor', 'off')
27 |
28 | # configuration to monitor model performance and save best
29 | if self.monitor == 'off':
30 | self.mnt_mode = 'off'
31 | self.mnt_best = 0
32 | else:
33 | self.mnt_mode, self.mnt_metric = self.monitor.split()
34 | assert self.mnt_mode in ['min', 'max']
35 |
36 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf
37 | self.early_stop = cfg_trainer.get('early_stop', inf)
38 | if self.early_stop <= 0:
39 | self.early_stop = inf
40 |
41 | self.not_improved_count = 0
42 |
43 | self.checkpoint_dir = cfg_trainer['save']['dir']
44 |
45 | if config.resume is not None:
46 | self._resume_checkpoint(config.resume)
47 |
48 | @abstractmethod
49 | def train(self):
50 | """
51 | Full training logic.
52 | """
53 |
54 | raise NotImplementedError
55 |
56 | @abstractmethod
57 | def _validation(self, step):
58 | """
59 | Full validation logic
60 |
61 | :param step: Current step number
62 | """
63 |
64 | raise NotImplementedError
65 |
66 | def _evaluate_performance(self, log):
67 | # evaluate model performance according to configured metric, save best checkpoint as model_best
68 | is_best = False
69 | if self.mnt_mode != 'off':
70 | try:
71 | # check whether model performance improved or not, according to specified metric(mnt_metric)
72 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
73 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
74 | except KeyError:
75 | self.logger.warning("Warning: Metric '{}' is not found. "
76 | "Model performance monitoring is disabled.".format(self.mnt_metric))
77 | self.mnt_mode = 'off'
78 | improved = False
79 |
80 | if improved:
81 | self.mnt_best = log[self.mnt_metric]
82 | self.not_improved_count = 0
83 | is_best = True
84 | else:
85 | self.not_improved_count += 1
86 |
87 | return is_best
88 |
89 | def _save_checkpoint(self, log, is_best=False):
90 | """
91 | Saving checkpoints
92 |
93 |         :param log: logging information of the current step
94 |         :param is_best: if True, also save the checkpoint as
95 |                         'best_model.pt' under the best/ directory
96 | """
97 | save_path = f'{self.checkpoint_dir}models/{self.config["name"]}/'
98 | chk_pt_path = save_path + f"steps_{log['steps']}/"
99 |
100 | # make path if there isn't
101 | if not os.path.exists(chk_pt_path):
102 | os.makedirs(chk_pt_path)
103 | # delete the oldest checkpoint not to exceed save limits
104 | if len(os.listdir(save_path)) > self.save_limits:
105 | shutil.rmtree(os.path.join(
106 | save_path,
107 | sorted(os.listdir(save_path),key = lambda x : (len(x), x))[0]
108 | )
109 | )
110 |
111 | self.logger.info("Saving checkpoint: {} ...".format(chk_pt_path))
112 | torch.save(self.model, os.path.join(chk_pt_path, "model.pt"))
113 | torch.save(self.optimizer.state_dict(), os.path.join(chk_pt_path, "optimizer.pt"))
114 |
115 | # save updated config file to the checkpoint dir
116 | write_json(self.config._config, os.path.join(chk_pt_path, "config.json"))
117 | write_json(log, os.path.join(chk_pt_path, "log.json"))
118 |
119 | # save best model.
120 | if is_best:
121 | best_path = f'{self.checkpoint_dir}best/{self.config["name"]}/'
122 |
123 | # make path if there isn't
124 | if not os.path.exists(best_path):
125 | os.makedirs(best_path)
126 | # delete old best files
127 | for file_ in os.listdir(best_path):
128 | os.remove(best_path + file_)
129 |
130 | self.logger.info("Saving current best: model_best.pt ...")
131 | torch.save(self.model, os.path.join(best_path, "best_model.pt"))
132 | torch.save(self.optimizer.state_dict(), os.path.join(best_path, "optimizer.pt"))
133 |
134 | # save updated config file to the checkpoint dir
135 | write_json(self.config._config, os.path.join(best_path, "config.json"))
136 | write_json(log, os.path.join(best_path, "log.json"))
137 |
138 | def _resume_checkpoint(self, resume_path):
139 | """
140 | Resume from saved checkpoints
141 |
142 | :param resume_path: Checkpoint path to be resumed
143 | """
144 | resume_path = str(resume_path)
145 | self.logger.info("Loading checkpoint: {} ...".format(resume_path))
146 | checkpoint = torch.load(resume_path)
147 | self.start_epoch = checkpoint['epoch'] + 1
148 | self.mnt_best = checkpoint['monitor_best']
149 |
150 | # load architecture params from checkpoint.
151 | if checkpoint['config']['arch'] != self.config['arch']:
152 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
153 | "checkpoint. This may yield an exception while state_dict is being loaded.")
154 | self.model.load_state_dict(checkpoint['state_dict'])
155 |
156 | # load optimizer state from checkpoint only when optimizer type is not changed.
157 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
158 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
159 | "Optimizer parameters not being resumed.")
160 | else:
161 | self.optimizer.load_state_dict(checkpoint['optimizer'])
162 |
163 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
164 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "beep_kcELECTRA_base_hate",
3 | "n_gpu": 1,
4 |
5 | "model": {
6 | "type": "BeepKcElectraHateModel",
7 | "args": {
8 | "name": "beomi/beep-KcELECTRA-base-hate",
9 | "num_classes": 3
10 | }
11 | },
12 | "tokenizer": {
13 | "type": "tokenizer/"
14 | },
15 | "data_loader": {
16 | "type": "MnistDataLoader",
17 | "args":{
18 | "data_dir": "AI-it/korean-hate-speech",
19 | "batch_size": 16,
20 | "max_length": 64,
21 | "shuffle": true,
22 | "validation_split": 0.1,
23 | "num_workers": 2
24 | },
25 | "data_files": {
26 | "train": "train_hate.csv",
27 | "valid": "dev_hate.csv"
28 | },
29 | "test_data_file": {
30 | "test": "test_hate_no_label.csv"
31 | }
32 | },
33 | "optimizer": {
34 | "type": "AdamW",
35 | "args":{
36 | "lr": 5e-5,
37 | "eps": 1e-8
38 | },
39 | "weight_decay": 0.0
40 | },
41 | "loss": "softmax",
42 | "metrics": [
43 | "macro_f1"
44 | ],
45 | "lr_scheduler": {
46 | "type": "CosineAnnealingLR",
47 | "args": {
48 | "T_max": 300,
49 | "eta_min": 1e-5
50 | }
51 | },
52 | "trainer": {
53 | "epochs": 2,
54 |
55 | "save": {
56 | "dir": "saved/",
57 | "steps": 300,
58 | "limits": 3
59 | },
60 | "verbosity": 2,
61 |
62 | "monitor": "max val/macro_f1",
63 | "early_stop": 2,
64 |
65 | "fp16": false
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
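For reference, a minimal sketch of what the "optimizer" and "lr_scheduler" entries above resolve to once ConfigParser.init_obj dispatches them, following the pattern used in kd_train.py (the nn.Linear is a stand-in for the ELECTRA classifier):

import torch
import torch.nn as nn

model = nn.Linear(10, 3)   # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300, eta_min=1e-5)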
/config_automl_test.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "beep_kcELECTRA_base_hate",
3 | "n_gpu": 1,
4 | "model": {
5 | "type": "BeepKcElectraHateModel",
6 | "args": {
7 | "name": "beomi/beep-KcELECTRA-base-hate",
8 | "num_classes": 3
9 | }
10 | },
11 | "data_loader": {
12 | "type": "MnistDataLoader",
13 | "args": {
14 | "data_dir": "data/",
15 | "batch_size": 32,
16 | "max_length": 64,
17 | "shuffle": true,
18 | "validation_split": 0.1,
19 | "num_workers": 2
20 | }
21 | },
22 | "optimizer": {
23 | "type": "AdamW",
24 | "args": {
25 | "lr": 5e-5,
26 | "eps": 1e-8
27 | },
28 | "weight_decay": 0.0
29 | },
30 | "loss": "softmax",
31 | "metrics": [
32 | "macro_f1"
33 | ],
34 | "lr_scheduler": {
35 | "type": "StepLR",
36 | "args": {
37 | "step_size": 50,
38 | "gamma": 0.1
39 | }
40 | },
41 | "trainer": {
42 | "epochs": 2,
43 | "save": {
44 | "dir": "saved/",
45 | "steps": 300,
46 | "limits": 3
47 | },
48 | "verbosity": 2,
49 | "monitor": "max val/macro_f1",
50 | "early_stop": 2
51 | },
52 | "data_dir": "AI-it/korean-hate-speech",
53 | "data_files": {
54 | "train": "train_hate.csv",
55 | "valid": "dev_hate.csv"
56 | },
57 | "test_data_file": {
58 | "test": "test_hate_no_label.csv"
59 | },
60 | "saved_folder": {
61 | "path": "automl/save",
62 | "trial": 2,
63 | "model_config": "model_config.yaml",
64 | "model_weight": "result_model.pt"
65 | }
66 | }
--------------------------------------------------------------------------------
/data_loader/data_loaders.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import pandas as pd
4 |
5 | import torch
6 | from torchvision import datasets, transforms
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | from torchsampler import ImbalancedDatasetSampler
10 | from transformers import PreTrainedTokenizer
11 | from utils import Preprocess, preprocess
12 | from datasets import load_dataset
13 | from base import BaseDataLoader
14 |
15 | LABEL_2_IDX = {
16 | "none": 0,
17 | "offensive": 1,
18 | "hate": 2
19 | }
20 | IDX_2_LABEL = {
21 | 0: "none",
22 | 1: "offensive",
23 | 2: "hate"
24 | }
25 |
26 |
27 | class KhsDataLoader(DataLoader):
28 | def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int = None):
29 | self.tokenizer = tokenizer
30 | self.max_length = max_length if max_length else self.tokenizer.model_max_length
31 |
32 | def train_collate_fn(self, input_examples):
33 | input_texts, input_labels = [], []
34 | for input_example in input_examples:
35 | text, label = input_example
36 | input_texts.append(text)
37 | input_labels.append(label)
38 |
39 | encoded_texts = self.tokenizer.batch_encode_plus(
40 | input_texts,
41 | add_special_tokens=True,
42 | max_length=self.max_length,
43 | truncation=True,
44 | padding="max_length",
45 | return_tensors="pt",
46 | return_token_type_ids=True,
47 | return_attention_mask=True,
48 | ) # input_ids, token_type_ids, attention_mask
49 |
50 | input_ids = encoded_texts["input_ids"]
51 | token_type_ids = encoded_texts["token_type_ids"]
52 | attention_mask = encoded_texts["attention_mask"]
53 |
54 | return input_ids, token_type_ids, attention_mask, torch.tensor(input_labels)
55 |
56 | def test_collate_fn(self, input_examples):
57 | input_texts = []
58 | for input_example in input_examples:
59 | text = input_example
60 | input_texts.append(text)
61 |
62 | encoded_texts = self.tokenizer.batch_encode_plus(
63 | input_texts,
64 | add_special_tokens=True,
65 | max_length=self.max_length,
66 | truncation=True,
67 | padding="max_length",
68 | return_tensors="pt",
69 | return_token_type_ids=True,
70 | return_attention_mask=True,
71 | ) # input_ids, token_type_ids, attention_mask
72 |
73 | input_ids = encoded_texts["input_ids"]
74 | token_type_ids = encoded_texts["token_type_ids"]
75 | attention_mask = encoded_texts["attention_mask"]
76 |
77 | return input_ids, token_type_ids, attention_mask
78 |
79 | def get_dataloader(self, name, data_dir, data_files, batch_size, **kwargs):
80 | data_files = dict(data_files)
81 | datasets = load_dataset(data_dir, data_files=data_files, use_auth_token=True)
82 | dataset = get_preprocessed_data(datasets[name], name)
83 | dataset = KhsDataset(dataset, name)
84 |
85 | sampler = None
86 |
87 | if name == "test":
88 | collate_fn = self.test_collate_fn
89 | else:
90 | collate_fn = self.train_collate_fn
91 | sampler = ImbalancedDatasetSampler(dataset)
92 |
93 | g = torch.Generator()
94 | g.manual_seed(0)
95 |
96 | return DataLoader(
97 | dataset,
98 | batch_size=batch_size,
99 | shuffle=False,
100 | sampler=sampler,
101 | collate_fn=collate_fn,
102 | num_workers=4,
103 | pin_memory=True,
104 | worker_init_fn=seed_worker,
105 | generator=g,
106 | **kwargs
107 | )
108 |
109 |
110 | class KhsDataset(Dataset):
111 | def __init__(self, data, data_type="train"):
112 | self.data_type = data_type
113 | self.texts = list(data.texts)
114 | if self.data_type == "train" or self.data_type == "valid":
115 | self.labels = list(data.labels)
116 |
117 | def __len__(self):
118 | return len(self.texts)
119 |
120 | def __getitem__(self, index):
121 | text = self.texts[index]
122 |
123 | if self.data_type == "train" or self.data_type == "valid":
124 | label = self.labels[index]
125 | converted_label = LABEL_2_IDX[label]
126 |
127 | return text, converted_label
128 |
129 | return text
130 |
131 |
132 | def get_preprocessed_data(dataset, name):
133 | if name == "test":
134 | preprocessed_sents = preprocess(dataset["comments"])
135 | out_dataset = pd.DataFrame(
136 | {
137 | "texts": preprocessed_sents,
138 | }
139 | )
140 | else:
141 | preprocessed_sents = preprocess(dataset["comments"])
142 | out_dataset = pd.DataFrame(
143 | {"texts": preprocessed_sents, "labels": dataset["label"]}
144 | )
145 |
146 | return out_dataset
147 |
148 |
149 | # https://pytorch.org/docs/stable/notes/randomness.html
150 | def seed_worker(worker_id):
151 | worker_seed = torch.initial_seed() % 2**32
152 | np.random.seed(worker_seed)
153 | random.seed(worker_seed)
--------------------------------------------------------------------------------
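A sketch of how this loader is typically built from the values in config.json (the dataset lives on the Hugging Face Hub and needs an auth token, so this is illustrative rather than directly runnable):

from transformers import AutoTokenizer
from data_loader.data_loaders import KhsDataLoader

tokenizer = AutoTokenizer.from_pretrained("tokenizer/")   # tokenizer path from config.json
dataloader = KhsDataLoader(tokenizer, max_length=64)
train_dl = dataloader.get_dataloader(
    name="train",
    data_dir="AI-it/korean-hate-speech",
    data_files={"train": "train_hate.csv", "valid": "dev_hate.csv"},
    batch_size=16,
)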
/data_loader/kd_data_loaders.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import pandas as pd
4 |
5 | import torch
6 | from torchvision import datasets, transforms
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | from torchsampler import ImbalancedDatasetSampler
10 | from transformers import PreTrainedTokenizer
11 | from utils import Preprocess, preprocess
12 | from datasets import load_dataset
13 | from base import BaseDataLoader
14 |
15 | LABEL_2_IDX = {
16 | "none": 0,
17 | "offensive": 1,
18 | "hate": 2
19 | }
20 | IDX_2_LABEL = {
21 | 0: "none",
22 | 1: "offensive",
23 | 2: "hate"
24 | }
25 |
26 |
27 | class KhsDataLoader(DataLoader):
28 | def __init__(
29 | self,
30 | student_tokenizer: PreTrainedTokenizer,
31 | teacher_tokenizer: PreTrainedTokenizer,
32 | max_length: int = None
33 | ):
34 | self.student_tokenizer = student_tokenizer
35 | self.teacher_tokenizer = teacher_tokenizer
36 | self.max_length = max_length if max_length else self.student_tokenizer.model_max_length
37 |
38 | def train_collate_fn(self, input_examples):
39 | input_texts, input_labels = [], []
40 |
41 | for input_example in input_examples:
42 | text, label = input_example
43 | input_texts.append(text)
44 | input_labels.append(label)
45 |
46 | st_encoded_texts = self.student_tokenizer.batch_encode_plus(
47 | input_texts,
48 | add_special_tokens=True,
49 | max_length=self.max_length,
50 | truncation=True,
51 | padding="max_length",
52 | return_tensors="pt",
53 | return_token_type_ids=True,
54 | return_attention_mask=True,
55 | ) # input_ids, token_type_ids, attention_mask
56 |
57 | tc_encoded_texts = self.teacher_tokenizer.batch_encode_plus(
58 | input_texts,
59 | add_special_tokens=True,
60 | max_length=self.max_length,
61 | truncation=True,
62 | padding="max_length",
63 | return_tensors="pt",
64 | return_token_type_ids=True,
65 | return_attention_mask=True,
66 | ) # input_ids, token_type_ids, attention_mask
67 |
68 | st_input_ids = st_encoded_texts["input_ids"]
69 | st_token_type_ids = st_encoded_texts["token_type_ids"]
70 | st_attention_mask = st_encoded_texts["attention_mask"]
71 |
72 | tc_input_ids = tc_encoded_texts["input_ids"]
73 | tc_token_type_ids = tc_encoded_texts["token_type_ids"]
74 | tc_attention_mask = tc_encoded_texts["attention_mask"]
75 |
76 | return st_input_ids, st_token_type_ids, st_attention_mask, tc_input_ids, tc_token_type_ids, tc_attention_mask, torch.tensor(input_labels)
77 |
78 | def valid_collate_fn(self, input_examples):
79 | input_texts, input_labels = [], []
80 |
81 | for input_example in input_examples:
82 | text, label = input_example
83 | input_texts.append(text)
84 | input_labels.append(label)
85 |
86 | encoded_texts = self.student_tokenizer.batch_encode_plus(
87 | input_texts,
88 | add_special_tokens=True,
89 | max_length=self.max_length,
90 | truncation=True,
91 | padding="max_length",
92 | return_tensors="pt",
93 | return_token_type_ids=True,
94 | return_attention_mask=True,
95 | ) # input_ids, token_type_ids, attention_mask
96 |
97 | input_ids = encoded_texts["input_ids"]
98 | token_type_ids = encoded_texts["token_type_ids"]
99 | attention_mask = encoded_texts["attention_mask"]
100 |
101 | return input_ids, token_type_ids, attention_mask, torch.tensor(input_labels)
102 |
103 | def test_collate_fn(self, input_examples):
104 | input_texts = []
105 | for input_example in input_examples:
106 | text = input_example
107 | input_texts.append(text)
108 |
109 |         encoded_texts = self.student_tokenizer.batch_encode_plus(
110 | input_texts,
111 | add_special_tokens=True,
112 | max_length=self.max_length,
113 | truncation=True,
114 | padding="max_length",
115 | return_tensors="pt",
116 | return_token_type_ids=True,
117 | return_attention_mask=True,
118 | ) # input_ids, token_type_ids, attention_mask
119 |
120 | input_ids = encoded_texts["input_ids"]
121 | token_type_ids = encoded_texts["token_type_ids"]
122 | attention_mask = encoded_texts["attention_mask"]
123 |
124 | return input_ids, token_type_ids, attention_mask
125 |
126 | def get_dataloader(self, name, data_dir, data_files, batch_size, **kwargs):
127 | data_files = dict(data_files)
128 | datasets = load_dataset(data_dir, data_files=data_files, use_auth_token=True)
129 | dataset = get_preprocessed_data(datasets[name], name)
130 | dataset = KhsDataset(dataset, name)
131 |
132 | sampler = None
133 |
134 | if name == "test":
135 | collate_fn = self.test_collate_fn
136 | elif name == "valid":
137 | collate_fn = self.valid_collate_fn
138 | else:
139 | collate_fn = self.train_collate_fn
140 | sampler = ImbalancedDatasetSampler(dataset)
141 |
142 | g = torch.Generator()
143 | g.manual_seed(0)
144 |
145 | return DataLoader(
146 | dataset,
147 | batch_size=batch_size,
148 | shuffle=False,
149 | sampler=sampler,
150 | collate_fn=collate_fn,
151 | num_workers=4,
152 | pin_memory=True,
153 | worker_init_fn=seed_worker,
154 | generator=g,
155 | **kwargs
156 | )
157 |
158 |
159 | class KhsDataset(Dataset):
160 | def __init__(self, data, data_type="train"):
161 | self.data_type = data_type
162 | self.texts = list(data.texts)
163 | if self.data_type == "train" or self.data_type == "valid":
164 | self.labels = list(data.labels)
165 |
166 | def __len__(self):
167 | return len(self.texts)
168 |
169 | def __getitem__(self, index):
170 | text = self.texts[index]
171 |
172 | if self.data_type == "train" or self.data_type == "valid":
173 | label = self.labels[index]
174 | converted_label = LABEL_2_IDX[label]
175 |
176 | return text, converted_label
177 |
178 | return text
179 |
180 |
181 | def get_preprocessed_data(dataset, name):
182 | if name == "test":
183 | preprocessed_sents = preprocess(dataset["comments"])
184 | out_dataset = pd.DataFrame(
185 | {
186 | "texts": preprocessed_sents,
187 | }
188 | )
189 | else:
190 | preprocessed_sents = preprocess(dataset["comments"])
191 | out_dataset = pd.DataFrame(
192 | {"texts": preprocessed_sents, "labels": dataset["label"]}
193 | )
194 |
195 | return out_dataset
196 |
197 |
198 | # https://pytorch.org/docs/stable/notes/randomness.html
199 | def seed_worker(worker_id):
200 | worker_seed = torch.initial_seed() % 2**32
201 | np.random.seed(worker_seed)
202 | random.seed(worker_seed)
--------------------------------------------------------------------------------
/kd_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "koelectra_kd_model",
3 | "n_gpu": 1,
4 |
5 | "model": {
6 | "type": "BeepKcElectraHateModel",
7 | "args": {
8 | "name": "monologg/koelectra-small-v3-discriminator",
9 | "num_classes": 3
10 | }
11 | },
12 | "teacher_model": {
13 | "type": "BeepKcElectraHateModel",
14 | "args": {
15 | "name": "beomi/beep-KcELECTRA-base-hate",
16 | "num_classes": 3
17 | }
18 | },
19 | "tokenizer": {
20 | "student": {
21 | "type": "monologg/koelectra-small-v3-discriminator"
22 | },
23 | "teacher": {
24 | "type": "tokenizer/"
25 | }
26 | },
27 | "data_loader": {
28 | "type": "MnistDataLoader",
29 | "args":{
30 | "data_dir": "AI-it/korean-hate-speech",
31 | "batch_size": 64,
32 | "max_length": 64,
33 | "shuffle": true,
34 | "validation_split": 0.1,
35 | "num_workers": 2
36 | },
37 | "data_files": {
38 | "train": "train_hate.csv",
39 | "valid": "dev_hate.csv"
40 | },
41 | "test_data_file": {
42 | "test": "test_hate_no_label.csv"
43 | }
44 | },
45 | "optimizer": {
46 | "type": "AdamW",
47 | "args":{
48 | "lr": 5e-5,
49 | "eps": 1e-8
50 | },
51 | "weight_decay": 0.0
52 | },
53 | "loss": "knowledge_distillation_loss",
54 | "metrics": [
55 | "macro_f1"
56 | ],
57 | "lr_scheduler": {
58 | "type": "CosineAnnealingLR",
59 | "args": {
60 | "T_max": 300,
61 | "eta_min": 1e-5
62 | }
63 | },
64 | "trainer": {
65 | "epochs": 2,
66 |
67 | "save": {
68 | "dir": "saved/",
69 | "steps": 300,
70 | "limits": 3
71 | },
72 | "verbosity": 2,
73 |
74 | "monitor": "max val/macro_f1",
75 | "early_stop": 5
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/kd_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import wandb
3 | import random
4 | import argparse
5 | import numpy as np
6 | import model.loss as module_loss
7 | import model.metric as module_metric
8 | import model.model as module_arch
9 | import data_loader.kd_data_loaders as module_data
10 | from parse_config import ConfigParser
11 | from trainer import KnowDistTrainer
12 | from utils import prepare_device
13 | from transformers import AutoTokenizer
14 | from data_loader.kd_data_loaders import KhsDataLoader
15 |
16 |
17 | def seed_everything(seed):
18 | """
19 | fix random seeds for reproducibility.
20 | Args:
21 | seed (int):
22 | seed number
23 | """
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed) # if use multi-GPU
27 | torch.backends.cudnn.deterministic = True
28 | torch.backends.cudnn.benchmark = False
29 | np.random.seed(seed)
30 | random.seed(seed)
31 |
32 |
33 | def main(config):
34 | seed_everything(42)
35 | wandb.init(project='#TODO', entity='#TODO', config=config)
36 |
37 | # build model architecture, then print to console
38 | student_model = config.init_obj('model', module_arch)
39 | teacher_model = config.init_obj('teacher_model', module_arch)
40 |
41 | # build tokenizer
42 | student_tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['student']['type'])
43 | teacher_tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['teacher']['type'])
44 |
45 | # build train and valid dataloader
46 | dataloader = KhsDataLoader(
47 | student_tokenizer,
48 | teacher_tokenizer,
49 | max_length=config['data_loader']['args']['max_length']
50 | )
51 | train_data_loader = dataloader.get_dataloader(
52 | name='train',
53 | data_dir=config['data_loader']['args']['data_dir'],
54 | data_files=config['data_loader']['data_files'],
55 | batch_size=config['data_loader']['args']['batch_size']
56 | )
57 | valid_data_loader = dataloader.get_dataloader(
58 | name='valid',
59 | data_dir=config['data_loader']['args']['data_dir'],
60 | data_files=config['data_loader']['data_files'],
61 | batch_size=config['data_loader']['args']['batch_size']
62 | )
63 |
64 | # prepare for (multi-device) GPU training
65 | device, device_ids = prepare_device(config['n_gpu'])
66 | student_model = student_model.to(device)
67 | teacher_model = teacher_model.to(device)
68 |
69 | # get function handles of loss and metrics
70 | criterion = getattr(module_loss, config['loss'])
71 | metrics = [getattr(module_metric, met) for met in config['metrics']]
72 |
73 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
74 | no_decay = ['bias', 'LayerNorm.weight']
75 | trainable_params = [
76 | {
77 | 'params': [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
78 | 'weight_decay': config['optimizer']['weight_decay']
79 | },
80 | {
81 | 'params': [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
82 | 'weight_decay': 0.0
83 | }
84 | ]
85 |
86 | optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
87 | lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
88 |
89 | trainer = KnowDistTrainer(
90 | student_model,
91 | teacher_model,
92 | criterion,
93 | metrics,
94 | optimizer,
95 | config=config,
96 | device=device,
97 | data_loader=train_data_loader,
98 | valid_data_loader=valid_data_loader,
99 | lr_scheduler=lr_scheduler
100 | )
101 |
102 | trainer.train()
103 |
104 |
105 | if __name__ == '__main__':
106 | args = argparse.ArgumentParser(description='PyTorch Template')
107 | args.add_argument('-c', '--config', default=None, type=str,
108 | help='config file path (default: None)')
109 | args.add_argument('-r', '--resume', default=None, type=str,
110 | help='path to latest checkpoint (default: None)')
111 | args.add_argument('-d', '--device', default=None, type=str,
112 | help='indices of GPUs to enable (default: all)')
113 |
114 | # custom cli options to modify configuration from default values given in json file.
115 | config = ConfigParser.from_args(args)
116 | main(config)
117 |
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 |
--------------------------------------------------------------------------------
/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | from utils import read_json
5 |
6 |
7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
8 | """
9 | Setup logging configuration
10 | """
11 | log_config = Path(log_config)
12 | if log_config.is_file():
13 | config = read_json(log_config)
14 | # modify logging paths based on run config
15 | for _, handler in config['handlers'].items():
16 | if 'filename' in handler:
17 | handler['filename'] = str(save_dir / handler['filename'])
18 |
19 | logging.config.dictConfig(config)
20 | else:
21 | print("Warning: logging configuration file is not found in {}.".format(log_config))
22 | logging.basicConfig(level=default_level)
23 |
--------------------------------------------------------------------------------
/logger/logger_config.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "version": 1,
4 | "disable_existing_loggers": false,
5 | "formatters": {
6 | "simple": {"format": "%(message)s"},
7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
8 | },
9 | "handlers": {
10 | "console": {
11 | "class": "logging.StreamHandler",
12 | "level": "DEBUG",
13 | "formatter": "simple",
14 | "stream": "ext://sys.stdout"
15 | },
16 | "info_file_handler": {
17 | "class": "logging.handlers.RotatingFileHandler",
18 | "level": "INFO",
19 | "formatter": "datetime",
20 | "filename": "info.log",
21 | "maxBytes": 10485760,
22 | "backupCount": 20, "encoding": "utf8"
23 | }
24 | },
25 | "root": {
26 | "level": "INFO",
27 | "handlers": [
28 | "console",
29 | "info_file_handler"
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def nll_loss(output, target):
7 | return F.nll_loss(output, target)
8 |
9 | def softmax(output, target):
10 | loss = nn.CrossEntropyLoss()
11 | return loss(output, target)
12 |
13 | def knowledge_distillation_loss(logits, target, teacher_logits):
14 | alpha = 0.3
15 | T = 1
16 |
17 | student_loss = F.cross_entropy(input=logits, target=target)
18 | distillation_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T * T)
19 | total_loss = (1. - alpha)*student_loss + alpha*distillation_loss
20 |
21 | return total_loss
--------------------------------------------------------------------------------
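A small sketch of the distillation loss above: with alpha=0.3 the result is 0.7 times the cross-entropy against the labels plus 0.3 times the T^2-scaled KL divergence between the temperature-softened teacher and student distributions (the logits here are random stand-ins).

import torch
from model.loss import knowledge_distillation_loss

student_logits = torch.randn(8, 3, requires_grad=True)
teacher_logits = torch.randn(8, 3)     # produced by the frozen teacher in training
labels = torch.randint(0, 3, (8,))

loss = knowledge_distillation_loss(student_logits, labels, teacher_logits)
loss.backward()                        # gradients reach only the student logits
print(loss.item())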
/model/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 | class NoamLR(_LRScheduler):
4 | """
5 | Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
6 | linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
7 | to the inverse square root of the step number, scaled by the inverse square root of the
8 | dimensionality of the model. Time will tell if this is just madness or it's actually important.
9 | Parameters
10 | ----------
11 | warmup_steps: ``int``, required.
12 | The number of steps to linearly increase the learning rate.
13 |
14 | https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
15 | """
16 | def __init__(self, optimizer, warmup_steps):
17 | self.warmup_steps = warmup_steps
18 | super().__init__(optimizer)
19 |
20 | def get_lr(self):
21 | last_epoch = max(1, self.last_epoch)
22 | scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
23 | return [base_lr * scale for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
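A quick sanity check of the schedule above: the learning rate rises roughly linearly for warmup_steps steps, peaks at the base LR exactly when step == warmup_steps, then decays with the inverse square root of the step.

import torch
from model.lr_scheduler import NoamLR

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
scheduler = NoamLR(optimizer, warmup_steps=100)

lrs = []
for _ in range(300):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(max(lrs) == lrs[99])   # True: peak at the end of warmup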
/model/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | warnings.filterwarnings('ignore')
4 | import sklearn.metrics
5 |
6 | LABEL_LIST = [
7 | "hate",
8 | "offensive",
9 | "none"
10 | ]
11 |
12 | def accuracy(output, target):
13 | with torch.no_grad():
14 | pred = torch.argmax(output, dim=1)
15 | assert pred.shape[0] == len(target)
16 | correct = 0
17 | correct += torch.sum(pred == target).item()
18 | return correct / len(target)
19 |
20 |
21 | def top_k_acc(output, target, k=3):
22 | with torch.no_grad():
23 | pred = torch.topk(output, k, dim=1)[1]
24 | assert pred.shape[0] == len(target)
25 | correct = 0
26 | for i in range(k):
27 | correct += torch.sum(pred[:, i] == target).item()
28 | return correct / len(target)
29 |
30 |
31 | def macro_f1(output, target):
32 | label_indices = list(range(len(LABEL_LIST)))
33 | return sklearn.metrics.f1_score(target, output, average="macro", labels=label_indices) * 100.0
34 |
--------------------------------------------------------------------------------
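Usage sketch for the metrics above: macro_f1 takes integer class indices for predictions and targets and returns a percentage, while accuracy works directly on logits.

import torch
from model.metric import macro_f1, accuracy

target = [0, 1, 2, 2, 1, 0]
preds = [0, 1, 2, 1, 1, 0]
print(macro_f1(preds, target))                  # macro-averaged F1 over the 3 classes, in percent

logits = torch.randn(6, 3)
print(accuracy(logits, torch.tensor(target)))   # fraction of argmax predictions that match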
/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce, partial
5 | from operator import getitem
6 | from datetime import datetime
7 | from logger import setup_logging
8 | from utils import read_json
9 |
10 |
11 | class ConfigParser:
12 | def __init__(self, config, resume=None, run_id=None):
13 | """
14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
15 | and logging module.
16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
17 | :param resume: String, path to the checkpoint being loaded.
18 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
19 | """
20 | # load config file and apply modification
21 | self._config = config
22 | self.resume = resume
23 |
24 | # set save_dir where log will be saved.
25 | save_dir = Path(self.config['trainer']['save']['dir'])
26 |
27 | exper_name = self.config['name']
28 | if run_id is None: # use timestamp as default run-id
29 | run_id = datetime.now().strftime(r'%m%d_%H_%M_%S')
30 | self._log_dir = save_dir / 'log' / exper_name / run_id
31 |
32 | # make directory for saving checkpoints and log.
33 | exist_ok = run_id == ''
34 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
35 |
36 | # configure logging module
37 | setup_logging(self.log_dir)
38 | self.log_levels = {
39 | 0: logging.WARNING,
40 | 1: logging.INFO,
41 | 2: logging.DEBUG
42 | }
43 |
44 | @classmethod
45 | def from_args(cls, args):
46 | """
47 |         Initialize this class from CLI arguments. Used in train and test.
48 | """
49 | if not isinstance(args, tuple):
50 | args = args.parse_args()
51 |
52 | if args.device is not None:
53 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
54 |
55 | if args.resume is not None:
56 | resume = Path(args.resume)
57 | cfg_fname = resume.parent / 'config.json'
58 | else:
59 |             msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
60 | assert args.config is not None, msg_no_cfg
61 | resume = None
62 | cfg_fname = Path(args.config)
63 |
64 | config = read_json(cfg_fname)
65 | if args.config and resume:
66 | # update new config for fine-tuning
67 | config.update(read_json(args.config))
68 |
69 | return cls(config, resume)
70 |
71 | def init_obj(self, name, module, *args, **kwargs):
72 | """
73 | Finds a function handle with the name given as 'type' in config, and returns the
74 | instance initialized with corresponding arguments given.
75 |
76 | `object = config.init_obj('name', module, a, b=1)`
77 | is equivalent to
78 | `object = module.name(a, b=1)`
79 | """
80 | module_name = self[name]['type']
81 | module_args = dict(self[name]['args'])
82 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
83 | module_args.update(kwargs)
84 | return getattr(module, module_name)(*args, **module_args)
85 |
86 | def init_ftn(self, name, module, *args, **kwargs):
87 | """
88 | Finds a function handle with the name given as 'type' in config, and returns the
89 | function with given arguments fixed with functools.partial.
90 |
91 | `function = config.init_ftn('name', module, a, b=1)`
92 | is equivalent to
93 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
94 | """
95 | module_name = self[name]['type']
96 | module_args = dict(self[name]['args'])
97 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
98 | module_args.update(kwargs)
99 | return partial(getattr(module, module_name), *args, **module_args)
100 |
101 | def __getitem__(self, name):
102 | """Access items like ordinary dict."""
103 | return self.config[name]
104 |
105 | def get_logger(self, name, verbosity=2):
106 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
107 | assert verbosity in self.log_levels, msg_verbosity
108 | logger = logging.getLogger(name)
109 | logger.setLevel(self.log_levels[verbosity])
110 | return logger
111 |
112 | # setting read-only attributes
113 | @property
114 | def config(self):
115 | return self._config
116 |
117 | @property
118 | def log_dir(self):
119 | return self._log_dir
120 |
121 | def _get_opt_name(flags):
122 | for flg in flags:
123 | if flg.startswith('--'):
124 | return flg.replace('--', '')
125 | return flags[0].replace('--', '')
126 |
127 |
--------------------------------------------------------------------------------
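A short sketch of the init_obj semantics under an assumed in-memory config; the dict below is illustrative, not one of the repository's JSON files, and constructing the parser also creates the log directory and configures logging:

    import torch
    import torch.optim as optim

    model = torch.nn.Linear(10, 3)
    config = ConfigParser({
        "name": "example",
        "trainer": {"save": {"dir": "saved/"}},
        "optimizer": {"type": "AdamW", "args": {"lr": 5e-5}},
    })

    # equivalent to optim.AdamW(model.parameters(), lr=5e-5)
    optimizer = config.init_obj("optimizer", optim, model.parameters())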
/pkm_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "mem_implementation": "pq_fast",
3 | "mem_grouped_conv": 0,
4 | "mem_values_optimizer": "",
5 | "mem_sparse": 0,
6 | "mem_input2d": 0,
7 | "mem_k_dim": 256,
8 | "mem_v_dim": -1,
9 | "mem_heads": 4,
10 | "mem_knn": 32,
11 | "mem_share_values": 0,
12 | "mem_shuffle_indices": 0,
13 | "mem_shuffle_query": 0,
14 | "mem_modulo_size": -1,
15 | "mem_keys_type": "uniform",
16 | "mem_n_keys": 512,
17 | "mem_keys_normalized_init": 0,
18 | "mem_keys_learn": 1,
19 | "mem_use_different_keys": 1,
20 | "mem_query_detach_input": 0,
21 | "mem_query_layer_sizes": "0,0",
22 | "mem_query_kernel_sizes": "",
23 | "mem_query_bias": 1,
24 | "mem_query_batchnorm": 0,
25 | "mem_query_net_learn": 1,
26 | "mem_query_residual": 0,
27 | "mem_multi_query_net": 0,
28 | "mem_value_zero_init": 0,
29 | "mem_normalize_query": 1,
30 | "mem_temperature": 1,
31 | "mem_score_softmax": 1,
32 | "mem_score_subtract": "",
33 | "mem_score_normalize": 0,
34 | "mem_input_dropout": 0,
35 | "mem_query_dropout": 0,
36 | "mem_value_dropout": 0
37 | }
--------------------------------------------------------------------------------
/prototype/fullstack/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/prototype/fullstack/.DS_Store
--------------------------------------------------------------------------------
/prototype/fullstack/Makefile:
--------------------------------------------------------------------------------
1 | run_black:
2 | python3 -m black . -l 119
3 |
4 | run_server:
5 | python3 -m app
6 |
7 | run_client:
8 | python3 -m streamlit run app/frontend.py
9 |
10 | run_app: run_server run_client
--------------------------------------------------------------------------------
/prototype/fullstack/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/prototype/fullstack/__init__.py
--------------------------------------------------------------------------------
/prototype/fullstack/app/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/prototype/fullstack/app/.DS_Store
--------------------------------------------------------------------------------
/prototype/fullstack/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-12/09c6e84a3618050ab0593df6f75beacf0340f9a6/prototype/fullstack/app/__init__.py
--------------------------------------------------------------------------------
/prototype/fullstack/app/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/prototype/fullstack/app/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 | """
9 | Base class for all data loaders
10 | """
11 |
12 | def __init__(
13 | self,
14 | dataset,
15 | batch_size,
16 | shuffle,
17 | validation_split,
18 | num_workers,
19 | collate_fn=default_collate,
20 | ):
21 | self.validation_split = validation_split
22 | self.shuffle = shuffle
23 |
24 | self.batch_idx = 0
25 | self.n_samples = len(dataset)
26 |
27 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
28 |
29 | self.init_kwargs = {
30 | "dataset": dataset,
31 | "batch_size": batch_size,
32 | "shuffle": self.shuffle,
33 | "collate_fn": collate_fn,
34 | "num_workers": num_workers,
35 | }
36 | super().__init__(sampler=self.sampler, **self.init_kwargs)
37 |
38 | def _split_sampler(self, split):
39 | if split == 0.0:
40 | return None, None
41 |
42 | idx_full = np.arange(self.n_samples)
43 |
44 | np.random.seed(0)
45 | np.random.shuffle(idx_full)
46 |
47 | if isinstance(split, int):
48 | assert split > 0
49 | assert (
50 | split < self.n_samples
51 |             ), "validation set size is configured to be larger than the entire dataset."
52 | len_valid = split
53 | else:
54 | len_valid = int(self.n_samples * split)
55 |
56 | valid_idx = idx_full[0:len_valid]
57 | train_idx = np.delete(idx_full, np.arange(0, len_valid))
58 |
59 | train_sampler = SubsetRandomSampler(train_idx)
60 | valid_sampler = SubsetRandomSampler(valid_idx)
61 |
62 | # turn off shuffle option which is mutually exclusive with sampler
63 | self.shuffle = False
64 | self.n_samples = len(train_idx)
65 |
66 | return train_sampler, valid_sampler
67 |
68 | def split_validation(self):
69 | if self.valid_sampler is None:
70 | return None
71 | else:
72 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
73 |
--------------------------------------------------------------------------------
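A minimal sketch of how the split logic above is consumed; the TensorDataset is an illustrative stand-in for the project's real datasets:

    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))

    loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                            validation_split=0.1, num_workers=0)
    valid_loader = loader.split_validation()  # DataLoader over the held-out 10%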
/prototype/fullstack/app/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 |
11 | @abstractmethod
12 | def forward(self, *inputs):
13 | """
14 | Forward pass logic
15 |
16 | :return: Model output
17 | """
18 | raise NotImplementedError
19 |
20 | def __str__(self):
21 | """
22 | Model prints with number of trainable parameters
23 | """
24 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
25 | params = sum([np.prod(p.size()) for p in model_parameters])
26 | return super().__str__() + "\nTrainable parameters: {}".format(params)
27 |
--------------------------------------------------------------------------------
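For example, a subclass only has to implement forward; printing the module then appends the trainable parameter count (the tiny linear model is purely illustrative):

    import torch.nn as nn

    class TinyModel(BaseModel):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 3)

        def forward(self, x):
            return self.fc(x)

    print(TinyModel())  # module summary plus "Trainable parameters: 27"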
/prototype/fullstack/app/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | from abc import abstractmethod
5 | from numpy import inf
6 | from utils import write_json
7 |
8 |
9 | class BaseTrainer:
10 | """
11 | Base class for all trainers
12 | """
13 |
14 | def __init__(self, model, criterion, metric_ftns, optimizer, config):
15 | self.config = config
16 | self.logger = config.get_logger("trainer", config["trainer"]["verbosity"])
17 |
18 | self.model = model
19 | self.criterion = criterion
20 | self.metric_ftns = metric_ftns
21 | self.optimizer = optimizer
22 |
23 | cfg_trainer = config["trainer"]
24 | self.epochs = cfg_trainer["epochs"]
25 | self.save_steps = cfg_trainer["save"]["steps"]
26 | self.save_limits = cfg_trainer["save"]["limits"]
27 | self.monitor = cfg_trainer.get("monitor", "off")
28 |
29 | # configuration to monitor model performance and save best
30 | if self.monitor == "off":
31 | self.mnt_mode = "off"
32 | self.mnt_best = 0
33 | else:
34 | self.mnt_mode, self.mnt_metric = self.monitor.split()
35 | assert self.mnt_mode in ["min", "max"]
36 |
37 | self.mnt_best = inf if self.mnt_mode == "min" else -inf
38 | self.early_stop = cfg_trainer.get("early_stop", inf)
39 | if self.early_stop <= 0:
40 | self.early_stop = inf
41 |
42 | self.not_improved_count = 0
43 |
44 | self.checkpoint_dir = cfg_trainer["save"]["dir"]
45 |
46 | if config.resume is not None:
47 | self._resume_checkpoint(config.resume)
48 |
49 | @abstractmethod
50 | def train(self):
51 | """
52 | Full training logic.
53 | """
54 |
55 | raise NotImplementedError
56 |
57 | @abstractmethod
58 | def _validation(self, step):
59 | """
60 | Full validation logic
61 |
62 | :param step: Current step number
63 | """
64 |
65 | raise NotImplementedError
66 |
67 | def _evaluate_performance(self, log):
68 | # evaluate model performance according to configured metric, save best checkpoint as model_best
69 | is_best = False
70 | if self.mnt_mode != "off":
71 | try:
72 |                 # check whether model performance improved or not, according to the specified metric (mnt_metric)
73 | improved = (
74 | self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best
75 | ) or (self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best)
76 | except KeyError:
77 | self.logger.warning(
78 | "Warning: Metric '{}' is not found. "
79 | "Model performance monitoring is disabled.".format(self.mnt_metric)
80 | )
81 | self.mnt_mode = "off"
82 | improved = False
83 |
84 | if improved:
85 | self.mnt_best = log[self.mnt_metric]
86 | self.not_improved_count = 0
87 | is_best = True
88 | else:
89 | self.not_improved_count += 1
90 |
91 | return is_best
92 |
93 | def _save_checkpoint(self, log, is_best=False):
94 | """
95 | Saving checkpoints
96 |
97 |         :param log: logging information of the current step; must contain 'steps'
98 |         :param is_best: if True, additionally save the checkpoint as 'best_model.pt'
99 |                         under the 'best/' directory
100 | """
101 | save_path = f'{self.checkpoint_dir}models/{self.config["name"]}/'
102 | chk_pt_path = save_path + f"steps_{log['steps']}/"
103 |
104 |         # create the directory if it does not exist
105 | if not os.path.exists(chk_pt_path):
106 | os.makedirs(chk_pt_path)
107 |         # delete the oldest checkpoint so the number of saved checkpoints does not exceed save_limits
108 | if len(os.listdir(save_path)) > self.save_limits:
109 | shutil.rmtree(os.path.join(
110 | save_path,
111 |                 sorted(os.listdir(save_path), key=lambda x: (len(x), x))[0]
112 | )
113 | )
114 |
115 | self.logger.info("Saving checkpoint: {} ...".format(chk_pt_path))
116 | torch.save(self.model, os.path.join(chk_pt_path, "model.pt"))
117 | torch.save(
118 | self.optimizer.state_dict(), os.path.join(chk_pt_path, "optimizer.pt")
119 | )
120 |
121 | # save updated config file to the checkpoint dir
122 | write_json(self.config._config, os.path.join(chk_pt_path, "config.json"))
123 | write_json(log, os.path.join(chk_pt_path, "log.json"))
124 |
125 | # save best model.
126 | if is_best:
127 | best_path = f'{self.checkpoint_dir}best/{self.config["name"]}/'
128 |
129 |             # create the directory if it does not exist
130 | if not os.path.exists(best_path):
131 | os.makedirs(best_path)
132 | # delete old best files
133 | for file_ in os.listdir(best_path):
134 | os.remove(best_path + file_)
135 |
136 | self.logger.info("Saving current best: model_best.pt ...")
137 | torch.save(self.model, os.path.join(best_path, "best_model.pt"))
138 | torch.save(
139 | self.optimizer.state_dict(), os.path.join(best_path, "optimizer.pt")
140 | )
141 |
142 | # save updated config file to the checkpoint dir
143 | write_json(self.config._config, os.path.join(best_path, "config.json"))
144 | write_json(log, os.path.join(best_path, "log.json"))
145 |
146 | def _resume_checkpoint(self, resume_path):
147 | """
148 | Resume from saved checkpoints
149 |
150 | :param resume_path: Checkpoint path to be resumed
151 | """
152 | resume_path = str(resume_path)
153 | self.logger.info("Loading checkpoint: {} ...".format(resume_path))
154 | checkpoint = torch.load(resume_path)
155 | self.start_epoch = checkpoint["epoch"] + 1
156 | self.mnt_best = checkpoint["monitor_best"]
157 |
158 | # load architecture params from checkpoint.
159 | if checkpoint["config"]["arch"] != self.config["arch"]:
160 | self.logger.warning(
161 | "Warning: Architecture configuration given in config file is different from that of "
162 | "checkpoint. This may yield an exception while state_dict is being loaded."
163 | )
164 | self.model.load_state_dict(checkpoint["state_dict"])
165 |
166 | # load optimizer state from checkpoint only when optimizer type is not changed.
167 | if (
168 | checkpoint["config"]["optimizer"]["type"]
169 | != self.config["optimizer"]["type"]
170 | ):
171 | self.logger.warning(
172 | "Warning: Optimizer type given in config file is different from that of checkpoint. "
173 | "Optimizer parameters not being resumed."
174 | )
175 | else:
176 | self.optimizer.load_state_dict(checkpoint["optimizer"])
177 |
178 | self.logger.info(
179 | "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
180 | )
181 |
--------------------------------------------------------------------------------
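As a hedged sketch of the intended subclassing pattern, a concrete trainer implements train() and _validation() and reuses the bookkeeping helpers above; the class name and placeholder metric below are illustrative:

    class MinimalTrainer(BaseTrainer):
        def train(self):
            for step in range(1, self.epochs + 1):
                # forward / backward / optimizer.step() would go here
                if step % self.save_steps == 0:
                    self._validation(step)

        def _validation(self, step):
            log = {"steps": step, "val/macro_f1": 0.0}  # placeholder metrics
            is_best = self._evaluate_performance(log)
            self._save_checkpoint(log, is_best=is_best)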
/prototype/fullstack/app/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "beomi/beep-KcELECTRA-base-hate",
3 | "n_gpu": 1,
4 |
5 | "model": {
6 | "type": "BeomiModel",
7 | "args": {
8 | "name": "beomi/beep-KcELECTRA-base-hate",
9 | "num_classes": 3
10 | }
11 | },
12 | "tokenizer": "beomi/KcELECTRA-base",
13 | "data_loader": {
14 | "type": "MnistDataLoader",
15 | "args":{
16 | "data_dir": "data/",
17 | "batch_size": 64,
18 | "max_length": 64,
19 | "shuffle": true,
20 | "validation_split": 0.1,
21 | "num_workers": 2
22 | }
23 | },
24 | "optimizer": {
25 | "type": "AdamW",
26 | "args":{
27 | "lr": 5e-5,
28 | "eps": 1e-8
29 | },
30 | "weight_decay": 0.0
31 | },
32 | "loss": "softmax",
33 | "metrics": [
34 | "macro_f1"
35 | ],
36 | "lr_scheduler": {
37 | "type": "StepLR",
38 | "args": {
39 | "step_size": 50,
40 | "gamma": 0.1
41 | }
42 | },
43 | "trainer": {
44 | "epochs": 2,
45 |
46 | "save": {
47 | "dir": "saved/",
48 | "steps": 300,
49 | "limits": 3
50 | },
51 | "verbosity": 2,
52 |
53 | "monitor": "max val/macro_f1",
54 | "early_stop": 2
55 | },
56 | "data_dir": "AI-it/korean-hate-speech",
57 | "data_files": {
58 | "train": "train_hate.csv",
59 | "valid": "dev_hate.csv"
60 | },
61 | "test_data_file": {
62 | "test": "test_hate_no_label.csv"
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/prototype/fullstack/app/confirm_button_hack.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import streamlit as st
4 |
5 |
6 | def cache_on_button_press(label, **cache_kwargs):
7 | """Function decorator to memoize function executions.
8 | Parameters
9 | ----------
10 | label : str
11 |         The label for the button to display prior to running the cached function.
12 | cache_kwargs : Dict[Any, Any]
13 | Additional parameters (such as show_spinner) to pass into the underlying @st.cache decorator.
14 | Example
15 | -------
16 |     This shows how you could write a username/password tester:
17 | >>> @cache_on_button_press('Authenticate')
18 | ... def authenticate(username, password):
19 | ... return username == "buddha" and password == "s4msara"
20 | ...
21 | ... username = st.text_input('username')
22 | ... password = st.text_input('password')
23 | ...
24 | ... if authenticate(username, password):
25 | ... st.success('Logged in.')
26 | ... else:
27 | ... st.error('Incorrect username or password')
28 | """
29 | internal_cache_kwargs = dict(cache_kwargs)
30 | internal_cache_kwargs['allow_output_mutation'] = True
31 | internal_cache_kwargs['show_spinner'] = False
32 |
33 | def function_decorator(func):
34 | @functools.wraps(func)
35 | def wrapped_func(*args, **kwargs):
36 | @st.cache(**internal_cache_kwargs)
37 | def get_cache_entry(func, args, kwargs):
38 | class ButtonCacheEntry:
39 | def __init__(self):
40 | self.evaluated = False
41 | self.return_value = None
42 |
43 | def evaluate(self):
44 | self.evaluated = True
45 | self.return_value = func(*args, **kwargs)
46 |
47 | return ButtonCacheEntry()
48 |
49 | cache_entry = get_cache_entry(func, args, kwargs)
50 | if not cache_entry.evaluated:
51 | if st.button(label):
52 | cache_entry.evaluate()
53 | else:
54 | raise st.script_runner.StopException
55 | return cache_entry.return_value
56 |
57 | return wrapped_func
58 |
59 | return function_decorator
60 |
--------------------------------------------------------------------------------
/prototype/fullstack/app/database.py:
--------------------------------------------------------------------------------
1 | from pymongo import MongoClient
2 | import certifi
3 |
4 | def run_db():
5 | ca = certifi.where()
6 | client = MongoClient('mongodb+srv://jadon:aiit@cluster0.13mh6.mongodb.net/myFirstDatabase?retryWrites=true&w=majority', tlsCAFile=ca)
7 | db = client.aiit
8 | evidence = db.evidence
9 | return evidence
10 |
11 | def insert2db(keyword, results, collection):
12 | docs = []
13 | for res in results:
14 | docs.append({
15 | "keyword": keyword,
16 | 'user_id': res['user_id'],
17 | 'comment': res['comment'],
18 | "label": res['label'],
19 | 'site_name': res['site_name'],
20 | 'site_url': res['site_url'],
21 | 'commented_at': res['commented_at']
22 | })
23 | collection.insert_many(docs)
--------------------------------------------------------------------------------
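A hedged usage sketch for the two helpers above; the result dictionary is fabricated for illustration and must carry the keys that insert2db reads:

    collection = run_db()
    results = [{
        "user_id": "user123",
        "comment": "example comment",
        "label": "hate",
        "site_name": "example-site",
        "site_url": "https://example.com/post/1",
        "commented_at": "2021-12-01 12:00:00",
    }]
    insert2db("keyword", results, collection)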
/prototype/fullstack/app/frontend.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from confirm_button_hack import cache_on_button_press
3 | import requests
4 | import time
5 | import pandas as pd
6 |
7 |
8 | st.set_page_config(layout='wide')
9 | st.header('Hello AI-it!!')
10 |
11 | st.title('Malicious Comments Collecting Service')
12 |
13 |
14 | def main():
15 | keyword = st.text_input('Keyword you want to collect!!')
16 | if keyword:
17 | with st.spinner('Collecting Evidence...'):
18 | response = requests.get('http://49.50.174.246:2227/get_sample/' + keyword)
19 | st.success('Done!')
20 |
21 | st.markdown("