├── .gitignore ├── README.md ├── get_similar_words.py ├── pyw2v ├── __init__.py ├── callback │ ├── __init__.py │ ├── lrscheduler.py │ ├── modelcheckpoint.py │ ├── progressbar.py │ └── trainingmonitor.py ├── common │ ├── __init__.py │ └── tools.py ├── config │ ├── __init__.py │ └── basic_config.py ├── dataset │ ├── __init__.py │ ├── processed │ │ └── __init__.py │ └── raw │ │ └── __init__.py ├── ensemble │ └── __init__.py ├── feature │ └── __init__.py ├── io │ ├── __init__.py │ ├── data_transformer.py │ └── dataset.py ├── model │ ├── __init__.py │ └── nn │ │ ├── __init__.py │ │ ├── gensim_word2vec.py │ │ └── skip_gram.py ├── output │ ├── checkpoints │ │ └── __init__.py │ ├── embedding │ │ └── __init__.py │ ├── feature │ │ └── __init__.py │ ├── figure │ │ └── __init__.py │ ├── log │ │ └── __init__.py │ └── result │ │ └── __init__.py ├── preprocessing │ ├── __init__.py │ └── preprocessor.py └── train │ ├── __init__.py │ └── trainer.py ├── train_gensim_word2vec.py └── train_word2vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### word2vec implementation for skip-gram in pytorch 2 | 3 | 本repo包含了使用pytorch实现skip-gram版本的word2vec词向量模型。 4 | 5 | 备注: 该版本以batch为1进行训练,速度较慢。 6 | 7 | ### 目录结构 8 | 9 | 主要的代码目录结果如下所示: 10 | 11 | ```text 12 | ├── pyword2vec 13 | | └── callback 14 | | | └── lrscheduler.py   15 | | └── config 16 | | | └── word2vec_config.py   17 | | └── dataset            18 | | └── io               19 | | └── model 20 | | └── output            21 | | └── preprocessing     22 | | └── train 23 | | └── utils 24 | ├── get_similar_words.py 25 | ├── train_gensim_word2vec.py 26 | ├── train_word2vec.py 27 
| ``` 28 | ### 案例 29 | 30 | 1. 首先下载数据集,可以从[百度网盘](https://pan.baidu.com/s/1FcrAc3w48dG8Gixv9E6EQw){提取码:7fyf},并放入`pyw2v/dataset/raw`文件夹中 31 | 32 | 2. 修改config文件夹中对应的数据路径配置 33 | 2. 运行`python train_word2vec.py`进行word2vec模型训练 34 | 35 | ### 实验结果 36 | 37 | 大概6次epochs之后,可得到以下结果: 38 | 39 | | 目标词 | Top10 | 目标词 | Top10 | 40 | | :--: | :--------: | :--: | :---------: | 41 | | 中国 | 中国 : 1.000 | 男人 | 男人 : 1.000 | 42 | | 中国 | 美国 : 0.651 | 男人 | 女人 : 0.764 | 43 | | 中国 | 日本 : 0.578 | 男人 | 女生 : 0.687 | 44 | | 中国 | 国家 : 0.560 | 男人 | 男生 : 0.670 | 45 | | 中国 | 发展 : 0.550 | 男人 | 喜欢 : 0.625 | 46 | | 中国 | 文化 : 0.529 | 男人 | 恋爱 : 0.601 | 47 | | 中国 | 朝鲜 : 0.512 | 男人 | 岁 : 0.590 | 48 | | 中国 | 经济 : 0.504 | 男人 | 女 : 0.588 | 49 | | 中国 | 世界 : 0.493 | 男人 | 感觉 : 0.586 | 50 | | 中国 | 社会 : 0.481 | 男人 | 男朋友 : 0.581 | 51 | 52 | -------------------------------------------------------------------------------- /get_similar_words.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import os 3 | import warnings 4 | from pyw2v.io.data_transformer import DataTransformer 5 | from pyw2v.config.basic_config import configs as config 6 | warnings.filterwarnings("ignore") 7 | 8 | def main(): 9 | 10 | data_transformer = DataTransformer(embedding_path = config['pytorch_embedding_path']) 11 | data_transformer.get_similar_words(word = '中国',w_num=10) 12 | data_transformer.get_similar_words(word='男人', w_num=10) 13 | 14 | del data_transformer 15 | 16 | if __name__ =="__main__": 17 | main() 18 | -------------------------------------------------------------------------------- /pyw2v/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/callback/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/callback/lrscheduler.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import math 3 | import logging 4 | import numpy as np 5 | import warnings 6 | from torch.optim.optimizer import Optimizer 7 | 8 | __all__ = ['StepLR', 9 | 'BertLR', 10 | 'CyclicLR', 11 | 'ReduceLROnPlateau', 12 | 'ReduceLRWDOnPlateau', 13 | 'CosineLRWithRestarts', 14 | 'NoamLR' 15 | ] 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class StepLR(object): 20 | ''' 21 | 自定义学习率变化机制 22 | Example: 23 | >>> scheduler =StepLR(optimizer) 24 | >>> for epoch in range(100): 25 | >>> scheduler.epoch_step() 26 | >>> train(...) 27 | >>> ... 28 | >>> optimizer.zero_grad() 29 | >>> loss.backward() 30 | >>> optimizer.step() 31 | >>> validate(...) 32 | ''' 33 | 34 | def __init__(self, optimizer, lr, epochs): 35 | self.optimizer = optimizer 36 | self.lr = lr 37 | self.epochs = epochs 38 | 39 | def epoch_step(self, epoch): 40 | new_lr = self.lr * (1.0 - 1.0 * epoch / self.epochs) 41 | for param_group in self.optimizer.param_groups: 42 | param_group['lr'] = new_lr 43 | 44 | 45 | class BertLR(object): 46 | ''' 47 | Bert模型内定的学习率变化机制 48 | Example: 49 | >>> scheduler = BertLR(optimizer) 50 | >>> for epoch in range(100): 51 | >>> scheduler.step() 52 | >>> train(...) 53 | >>> ... 54 | >>> optimizer.zero_grad() 55 | >>> loss.backward() 56 | >>> optimizer.step() 57 | >>> scheduler.batch_step() 58 | >>> validate(...) 
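The schedule increases the learning rate linearly from 0 up to lr over the first warmup fraction of t_total steps, then decays it linearly back towards 0 for the remainder of training (see warmup_linear below).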
59 | ''' 60 | 61 | def __init__(self, optimizer, lr, t_total, warmup): 62 | self.lr = lr 63 | self.optimizer = optimizer 64 | self.t_total = t_total 65 | self.warmup = warmup 66 | 67 | # 线性预热方式 68 | def warmup_linear(self, x, warmup=0.002): 69 | if x < warmup: 70 | return x / warmup 71 | return 1.0 - x 72 | 73 | def batch_step(self, training_step): 74 | lr_this_step = self.lr * self.warmup_linear(training_step / self.t_total, self.warmup) 75 | for param_group in self.optimizer.param_groups: 76 | param_group['lr'] = lr_this_step 77 | 78 | 79 | class CyclicLR(object): 80 | ''' 81 | Cyclical learning rates for training neural networks 82 | Example: 83 | >>> scheduler = CyclicLR(optimizer) 84 | >>> for epoch in range(100): 85 | >>> scheduler.step() 86 | >>> train(...) 87 | >>> ... 88 | >>> optimizer.zero_grad() 89 | >>> loss.backward() 90 | >>> optimizer.step() 91 | >>> scheduler.batch_step() 92 | >>> validate(...) 93 | ''' 94 | 95 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 96 | step_size=2000, mode='triangular', gamma=1., 97 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 98 | 99 | if not isinstance(optimizer, Optimizer): 100 | raise TypeError(f'{type(optimizer).__name__} is not an Optimizer') 101 | 102 | self.optimizer = optimizer 103 | 104 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 105 | if len(base_lr) != len(optimizer.param_groups): 106 | raise ValueError(f"expected {len(optimizer.param_groups)} base_lr, got {len(base_lr)}") 107 | self.base_lrs = list(base_lr) 108 | else: 109 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 110 | 111 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 112 | if len(max_lr) != len(optimizer.param_groups): 113 | raise ValueError(f"expected {len(optimizer.param_groups)} max_lr, got {len(max_lr)}") 114 | self.max_lrs = list(max_lr) 115 | else: 116 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 117 | 118 | self.step_size = step_size 119 | 120 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 121 | and scale_fn is None: 122 | raise ValueError('mode is invalid and scale_fn is None') 123 | 124 | self.mode = mode 125 | self.gamma = gamma 126 | 127 | if scale_fn is None: 128 | if self.mode == 'triangular': 129 | self.scale_fn = self._triangular_scale_fn 130 | self.scale_mode = 'cycle' 131 | elif self.mode == 'triangular2': 132 | self.scale_fn = self._triangular2_scale_fn 133 | self.scale_mode = 'cycle' 134 | elif self.mode == 'exp_range': 135 | self.scale_fn = self._exp_range_scale_fn 136 | self.scale_mode = 'iterations' 137 | else: 138 | self.scale_fn = scale_fn 139 | self.scale_mode = scale_mode 140 | 141 | self.batch_step(last_batch_iteration + 1) 142 | self.last_batch_iteration = last_batch_iteration 143 | 144 | def _triangular_scale_fn(self, x): 145 | return 1. 146 | 147 | def _triangular2_scale_fn(self, x): 148 | return 1 / (2. 
** (x - 1)) 149 | 150 | def _exp_range_scale_fn(self, x): 151 | return self.gamma ** (x) 152 | 153 | def get_lr(self): 154 | step_size = float(self.step_size) 155 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size)) 156 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1) 157 | 158 | lrs = [] 159 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs) 160 | for param_group, base_lr, max_lr in param_lrs: 161 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x)) 162 | if self.scale_mode == 'cycle': 163 | lr = base_lr + base_height * self.scale_fn(cycle) 164 | else: 165 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration) 166 | lrs.append(lr) 167 | return lrs 168 | 169 | def batch_step(self, batch_iteration=None): 170 | if batch_iteration is None: 171 | batch_iteration = self.last_batch_iteration + 1 172 | self.last_batch_iteration = batch_iteration 173 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 174 | param_group['lr'] = lr 175 | 176 | 177 | class ReduceLROnPlateau(object): 178 | """Reduce learning rate when a metrics has stopped improving. 179 | Models often benefit from reducing the learning rate by a factor 180 | of 2-10 once learning stagnates. This scheduler reads a metrics 181 | quantity and if no improvement is seen for a 'patience' number 182 | of epochs, the learning rate is reduced. 183 | 184 | Args: 185 | factor: factor by which the learning rate will 186 | be reduced. new_lr = lr * factor 187 | patience: number of epochs with no improvement 188 | after which learning rate will be reduced. 189 | verbose: int. 0: quiet, 1: update messages. 190 | mode: one of {min, max}. In `min` mode, 191 | lr will be reduced when the quantity 192 | monitored has stopped decreasing; in `max` 193 | mode it will be reduced when the quantity 194 | monitored has stopped increasing. 195 | epsilon: threshold for measuring the new optimum, 196 | to only focus on significant changes. 197 | cooldown: number of epochs to wait before resuming 198 | normal operation after lr has been reduced. 199 | min_lr: lower bound on the learning rate. 200 | 201 | 202 | Example: 203 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 204 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 205 | >>> for epoch in range(10): 206 | >>> train(...) 207 | >>> val_acc, val_loss = validate(...) 208 | >>> scheduler.epoch_step(val_loss, epoch) 209 | """ 210 | 211 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 212 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0, eps=1e-8): 213 | 214 | super(ReduceLROnPlateau, self).__init__() 215 | assert isinstance(optimizer, Optimizer) 216 | if factor >= 1.0: 217 | raise ValueError('ReduceLROnPlateau ' 218 | 'does not support a factor >= 1.0.') 219 | self.factor = factor 220 | self.min_lr = min_lr 221 | self.epsilon = epsilon 222 | self.patience = patience - 1 223 | self.verbose = verbose 224 | self.cooldown = cooldown 225 | self.cooldown_counter = 0 # Cooldown counter. 226 | self.monitor_op = None 227 | self.wait = 0 228 | self.best = 0 229 | self.mode = mode 230 | self.optimizer = optimizer 231 | self.eps = eps 232 | self._reset() 233 | 234 | def _reset(self): 235 | """Resets wait counter and cooldown counter. 
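In 'min' mode an improvement means the monitored value falls below best - epsilon; in 'max' mode it must rise above best + epsilon.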
236 | """ 237 | if self.mode not in ['min', 'max']: 238 | raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!') 239 | if self.mode == 'min': 240 | self.monitor_op = lambda a, b: np.less(a, b - self.epsilon) 241 | self.best = np.Inf 242 | else: 243 | self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon) 244 | self.best = -np.Inf 245 | self.cooldown_counter = 0 246 | self.wait = 0 247 | 248 | def reset(self): 249 | self._reset() 250 | 251 | def epoch_step(self, metrics, epoch): 252 | current = metrics 253 | if current is None: 254 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 255 | else: 256 | if self.in_cooldown(): 257 | self.cooldown_counter -= 1 258 | self.wait = 0 259 | 260 | if self.monitor_op(current, self.best): 261 | self.best = current 262 | self.wait = 0 263 | 264 | elif not self.in_cooldown(): 265 | if self.wait >= self.patience: 266 | for param_group in self.optimizer.param_groups: 267 | old_lr = float(param_group['lr']) 268 | if old_lr > self.min_lr + self.eps: 269 | new_lr = old_lr * self.factor 270 | new_lr = max(new_lr, self.min_lr) 271 | param_group['lr'] = new_lr 272 | if self.verbose > 0: 273 | logger.info(f'\nEpoch {epoch}: reducing learning rate to {new_lr}.') 274 | self.cooldown_counter = self.cooldown 275 | self.wait = 0 276 | self.wait += 1 277 | 278 | def in_cooldown(self): 279 | return self.cooldown_counter > 0 280 | 281 | 282 | class ReduceLRWDOnPlateau(ReduceLROnPlateau): 283 | """Reduce learning rate and weight decay when a metrics has stopped 284 | improving. Models often benefit from reducing the learning rate by 285 | a factor of 2-10 once learning stagnates. This scheduler reads a metrics 286 | quantity and if no improvement is seen for a 'patience' number 287 | of epochs, the learning rate and weight decay factor is reduced for 288 | optimizers that implement the the weight decay method from the paper 289 | `Fixing Weight Decay Regularization in Adam`_. 290 | 291 | .. _Fixing Weight Decay Regularization in Adam: 292 | https://arxiv.org/abs/1711.05101 293 | for AdamW or SGDW 294 | Example: 295 | >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3) 296 | >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min') 297 | >>> for epoch in range(10): 298 | >>> train(...) 299 | >>> val_loss = validate(...) 
300 | >>> # Note that step should be called after validate() 301 | >>> scheduler.epoch_step(val_loss) 302 | """ 303 | 304 | def epoch_step(self, metrics, epoch): 305 | current = metrics 306 | if current is None: 307 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 308 | else: 309 | if self.in_cooldown(): 310 | self.cooldown_counter -= 1 311 | self.wait = 0 312 | 313 | if self.monitor_op(current, self.best): 314 | self.best = current 315 | self.wait = 0 316 | elif not self.in_cooldown(): 317 | if self.wait >= self.patience: 318 | for param_group in self.optimizer.param_groups: 319 | old_lr = float(param_group['lr']) 320 | if old_lr > self.min_lr + self.eps: 321 | new_lr = old_lr * self.factor 322 | new_lr = max(new_lr, self.min_lr) 323 | param_group['lr'] = new_lr 324 | if self.verbose > 0: 325 | logger.info(f'Epoch {epoch}: reducing learning rate to {new_lr}.') 326 | if param_group['weight_decay'] != 0: 327 | old_weight_decay = float(param_group['weight_decay']) 328 | new_weight_decay = max(old_weight_decay * self.factor, self.min_lr) 329 | if old_weight_decay > new_weight_decay + self.eps: 330 | param_group['weight_decay'] = new_weight_decay 331 | if self.verbose: 332 | logger.info( 333 | f'\nEpoch {epoch}: reducing weight decay factor of group to {new_weight_decay:.4e}.') 334 | self.cooldown_counter = self.cooldown 335 | self.wait = 0 336 | self.wait += 1 337 | 338 | 339 | class CosineLRWithRestarts(object): 340 | """Decays learning rate with cosine annealing, normalizes weight decay 341 | hyperparameter value, implements restarts. 342 | https://arxiv.org/abs/1711.05101 343 | 344 | Args: 345 | optimizer (Optimizer): Wrapped optimizer. 346 | batch_size: minibatch size 347 | epoch_size: training samples per epoch 348 | restart_period: epoch count in the first restart period 349 | t_mult: multiplication factor by which the next restart period will extend/shrink 350 | 351 | Example: 352 | >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2) 353 | >>> for epoch in range(100): 354 | >>> scheduler.step() 355 | >>> train(...) 356 | >>> ... 357 | >>> optimizer.zero_grad() 358 | >>> loss.backward() 359 | >>> optimizer.step() 360 | >>> scheduler.batch_step() 361 | >>> validate(...) 
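Within a restart period the multiplier follows eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t_cur / restart_period)); weight decay is additionally rescaled by sqrt(batch_size / (epoch_size * restart_period)), and each restart multiplies restart_period by t_mult.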
362 | """ 363 | 364 | def __init__(self, optimizer, batch_size, epoch_size, restart_period=100, 365 | t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False): 366 | if not isinstance(optimizer, Optimizer): 367 | raise TypeError(f'{type(optimizer).__name__} is not an Optimizer') 368 | self.optimizer = optimizer 369 | if last_epoch == -1: 370 | for group in optimizer.param_groups: 371 | group.setdefault('initial_lr', group['lr']) 372 | else: 373 | for i, group in enumerate(optimizer.param_groups): 374 | if 'initial_lr' not in group: 375 | raise KeyError("param 'initial_lr' is not specified " 376 | f"in param_groups[{i}] when resuming an" 377 | " optimizer") 378 | self.base_lrs = list(map(lambda group: group['initial_lr'], 379 | optimizer.param_groups)) 380 | 381 | self.last_epoch = last_epoch 382 | self.batch_size = batch_size 383 | self.iteration = 0 384 | self.epoch_size = epoch_size 385 | self.eta_threshold = eta_threshold 386 | self.t_mult = t_mult 387 | self.verbose = verbose 388 | self.base_weight_decays = list(map(lambda group: group['weight_decay'], 389 | optimizer.param_groups)) 390 | self.restart_period = restart_period 391 | self.restarts = 0 392 | self.t_epoch = -1 393 | self.batch_increments = [] 394 | self._set_batch_increment() 395 | 396 | def _schedule_eta(self): 397 | """ 398 | Threshold value could be adjusted to shrink eta_min and eta_max values. 399 | """ 400 | eta_min = 0 401 | eta_max = 1 402 | if self.restarts <= self.eta_threshold: 403 | return eta_min, eta_max 404 | else: 405 | d = self.restarts - self.eta_threshold 406 | k = d * 0.09 407 | return (eta_min + k, eta_max - k) 408 | 409 | def get_lr(self, t_cur): 410 | eta_min, eta_max = self._schedule_eta() 411 | 412 | eta_t = (eta_min + 0.5 * (eta_max - eta_min) 413 | * (1. + math.cos(math.pi * 414 | (t_cur / self.restart_period)))) 415 | 416 | weight_decay_norm_multi = math.sqrt(self.batch_size / 417 | (self.epoch_size * 418 | self.restart_period)) 419 | lrs = [base_lr * eta_t for base_lr in self.base_lrs] 420 | weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi 421 | for base_weight_decay in self.base_weight_decays] 422 | 423 | if self.t_epoch % self.restart_period < self.t_epoch: 424 | if self.verbose: 425 | logger.info(f"Restart at epoch {self.last_epoch}") 426 | self.restart_period *= self.t_mult 427 | self.restarts += 1 428 | self.t_epoch = 0 429 | 430 | return zip(lrs, weight_decays) 431 | 432 | def _set_batch_increment(self): 433 | d, r = divmod(self.epoch_size, self.batch_size) 434 | batches_in_epoch = d + 2 if r > 0 else d + 1 435 | self.iteration = 0 436 | self.batch_increments = list(np.linspace(0, 1, batches_in_epoch)) 437 | 438 | def batch_step(self): 439 | self.last_epoch += 1 440 | self.t_epoch += 1 441 | self._set_batch_increment() 442 | try: 443 | t_cur = self.t_epoch + self.batch_increments[self.iteration] 444 | self.iteration += 1 445 | except (IndexError): 446 | raise RuntimeError("Epoch size and batch size used in the " 447 | "training loop and while initializing " 448 | "scheduler should be the same.") 449 | 450 | for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups, self.get_lr(t_cur)): 451 | param_group['lr'] = lr 452 | param_group['weight_decay'] = weight_decay 453 | 454 | 455 | class NoamLR(object): 456 | ''' 457 | 主要参考论文<< Attention Is All You Need>>中的学习更新方式 458 | Example: 459 | >>> scheduler = NoamLR(d_model,factor,warm_up,optimizer) 460 | >>> for epoch in range(100): 461 | >>> scheduler.step() 462 | >>> train(...) 463 | >>> ... 
464 | >>> glopab_step += 1 465 | >>> optimizer.zero_grad() 466 | >>> loss.backward() 467 | >>> optimizer.step() 468 | >>> scheduler.batch_step(global_step) 469 | >>> validate(...) 470 | ''' 471 | 472 | def __init__(self, d_model, factor, warm_up, optimizer): 473 | self.optimizer = optimizer 474 | self.warm_up = warm_up 475 | self.factor = factor 476 | self.d_model = d_model 477 | self._lr = 0 478 | 479 | def get_lr(self, step): 480 | lr = self.factor * (self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warm_up ** (-1.5))) 481 | return lr 482 | 483 | def batch_step(self, step): 484 | ''' 485 | update parameters and rate 486 | :return: 487 | ''' 488 | lr = self.get_lr(step) 489 | for p in self.optimizer.param_groups: 490 | p['lr'] = lr 491 | self._lr = lr 492 | -------------------------------------------------------------------------------- /pyw2v/callback/modelcheckpoint.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import os 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | from ..common.tools import logger 7 | 8 | class ModelCheckpoint(object): 9 | ''' 10 | 模型保存,两种模式: 11 | 1. 直接保存最好模型 12 | 2. 按照epoch频率保存模型 13 | ''' 14 | def __init__(self, checkpoint_dir, 15 | monitor, 16 | arch,mode='min', 17 | epoch_freq=1, 18 | best = None, 19 | save_best_only = True): 20 | if isinstance(checkpoint_dir,Path): 21 | checkpoint_dir = checkpoint_dir 22 | else: 23 | checkpoint_dir = Path(checkpoint_dir) 24 | assert checkpoint_dir.is_dir() 25 | checkpoint_dir.mkdir(exist_ok=True) 26 | self.base_path = checkpoint_dir 27 | self.arch = arch 28 | self.monitor = monitor 29 | self.epoch_freq = epoch_freq 30 | self.save_best_only = save_best_only 31 | 32 | # 计算模式 33 | if mode == 'min': 34 | self.monitor_op = np.less 35 | self.best = np.Inf 36 | 37 | elif mode == 'max': 38 | self.monitor_op = np.greater 39 | self.best = -np.Inf 40 | # 这里主要重新加载模型时候 41 | #对best重新赋值 42 | if best: 43 | self.best = best 44 | 45 | if save_best_only: 46 | self.model_name = f"BEST_{arch}_MODEL.pth" 47 | 48 | def epoch_step(self, state,current): 49 | ''' 50 | :param state: 需要保存的信息 51 | :param current: 当前判断指标 52 | :return: 53 | ''' 54 | # 是否保存最好模型 55 | if self.save_best_only: 56 | if self.monitor_op(current, self.best): 57 | logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}") 58 | self.best = current 59 | state['best'] = self.best 60 | best_path = self.base_path/ self.model_name 61 | torch.save(state, str(best_path)) 62 | # 每隔几个epoch保存下模型 63 | else: 64 | filename = self.base_path / f"EPOCH_{state['epoch']}_{state[self.monitor]}_{self.arch}_MODEL.pth" 65 | if state['epoch'] % self.epoch_freq == 0: 66 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.") 67 | torch.save(state, str(filename)) 68 | -------------------------------------------------------------------------------- /pyw2v/callback/progressbar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class ProgressBar(object): 4 | def __init__(self, n_batch,width=30): 5 | self.width = width 6 | self.n_batch = n_batch 7 | self.start_time = time.time() 8 | 9 | def batch_step(self, batch_idx, info, bar_type='Training'): 10 | now = time.time() 11 | current = batch_idx + 1 12 | recv_per = current / self.n_batch 13 | bar = f'[{bar_type}] {current}/{self.n_batch} [' 14 | if recv_per >= 1: 15 | recv_per = 1 16 | prog_width = int(self.width * recv_per) 17 | if prog_width > 0: 18 | bar += '=' 
* (prog_width - 1) 19 | if current< self.n_batch: 20 | bar += ">" 21 | else: 22 | bar += '=' 23 | bar += '.' * (self.width - prog_width) 24 | bar += ']' 25 | show_bar = f"\r{bar}" 26 | time_per_unit = (now - self.start_time) / current 27 | if current < self.n_batch: 28 | eta = time_per_unit * (self.n_batch - current) 29 | if eta > 3600: 30 | eta_format = ('%d:%02d:%02d' % 31 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 32 | elif eta > 60: 33 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 34 | else: 35 | eta_format = '%ds' % eta 36 | time_info = f' - ETA: {eta_format}' 37 | else: 38 | if time_per_unit >= 1: 39 | time_info = f' {time_per_unit:.1f}s/step' 40 | elif time_per_unit >= 1e-3: 41 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 42 | else: 43 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 44 | 45 | show_bar += time_info 46 | if len(info) != 0: 47 | show_info = f'{show_bar} ' + \ 48 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()]) 49 | print(show_info, end='') 50 | else: 51 | print(show_bar, end='') 52 | -------------------------------------------------------------------------------- /pyw2v/callback/trainingmonitor.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import numpy as np 3 | from pathlib import Path 4 | import matplotlib.pyplot as plt 5 | from ..common.tools import load_json 6 | from ..common.tools import save_json 7 | plt.switch_backend('agg') # 防止ssh上绘图问题 8 | 9 | 10 | class TrainingMonitor(): 11 | def __init__(self, file_dir, arch, add_test=False,add_valid = True): 12 | ''' 13 | :param startAt: 重新开始训练的epoch点 14 | ''' 15 | if isinstance(file_dir, Path): 16 | pass 17 | else: 18 | file_dir = Path(file_dir) 19 | file_dir.mkdir(parents=True, exist_ok=True) 20 | 21 | self.arch = arch 22 | self.file_dir = file_dir 23 | self.H = {} 24 | self.add_test = add_test 25 | self.add_valid = add_valid 26 | self.json_path = file_dir / (arch + "_training_monitor.json") 27 | 28 | def reset(self,start_at): 29 | if start_at > 0: 30 | if self.json_path is not None: 31 | if self.json_path.exists(): 32 | self.H = load_json(self.json_path) 33 | for k in self.H.keys(): 34 | self.H[k] = self.H[k][:start_at] 35 | 36 | def epoch_step(self, logs={}): 37 | for (k, v) in logs.items(): 38 | l = self.H.get(k, []) 39 | # np.float32会报错 40 | if not isinstance(v, np.float): 41 | v = round(float(v), 4) 42 | l.append(v) 43 | self.H[k] = l 44 | 45 | # 写入文件 46 | if self.json_path is not None: 47 | save_json(data = self.H,file_path=self.json_path) 48 | 49 | # 保存train图像 50 | if len(self.H["loss"]) == 1: 51 | self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()} 52 | 53 | if len(self.H["loss"]) > 1: 54 | # 指标变化 55 | # 曲线 56 | # 需要成对出现 57 | keys = [key for key, _ in self.H.items() if '_' not in key] 58 | for key in keys: 59 | N = np.arange(0, len(self.H[key])) 60 | plt.style.use("ggplot") 61 | plt.figure() 62 | plt.plot(N, self.H[key], label=f"train_{key}") 63 | if self.add_valid: 64 | plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}") 65 | if self.add_test: 66 | plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}") 67 | plt.legend() 68 | plt.xlabel("Epoch #") 69 | plt.ylabel(key) 70 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]") 71 | plt.savefig(str(self.paths[key])) 72 | plt.close() 73 | -------------------------------------------------------------------------------- /pyw2v/common/__init__.py: 
-------------------------------------------------------------------------------- 1 | #encoding:utf-8 -------------------------------------------------------------------------------- /pyw2v/common/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import json 6 | import pickle 7 | from pathlib import Path 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | from pathlib import Path 11 | import logging 12 | 13 | logger = logging.getLogger() 14 | 15 | def print_config(config): 16 | info = "Running with the following configs:\n" 17 | for k, v in config.items(): 18 | info += f"\t{k} : {str(v)}\n" 19 | print("\n" + info + "\n") 20 | return 21 | 22 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 23 | ''' 24 | 日志文件 25 | Example: 26 | >>> from logging import init_logger,logger 27 | >>> init_logger(log_file) 28 | >>> logger.info("abc'") 29 | ''' 30 | if isinstance(log_file,Path): 31 | log_file = str(log_file) 32 | # log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 33 | log_format = logging.Formatter("%(message)s") 34 | logger = logging.getLogger() 35 | logger.setLevel(logging.INFO) 36 | console_handler = logging.StreamHandler() 37 | console_handler.setFormatter(log_format) 38 | logger.handlers = [console_handler] 39 | if log_file and log_file != '': 40 | file_handler = logging.FileHandler(log_file) 41 | file_handler.setLevel(log_file_level) 42 | file_handler.setFormatter(log_format) 43 | logger.addHandler(file_handler) 44 | return logger 45 | 46 | def seed_everything(seed=1029): 47 | ''' 48 | 设置整个开发环境的seed 49 | :param seed: 50 | :param device: 51 | :return: 52 | ''' 53 | random.seed(seed) 54 | os.environ['PYTHONHASHSEED'] = str(seed) 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | # some cudnn methods can be random even after fixing the seed 60 | # unless you tell it to be deterministic 61 | torch.backends.cudnn.deterministic = True 62 | 63 | 64 | def prepare_device(n_gpu_use): 65 | """ 66 | setup GPU device if available, move model into configured device 67 | # 如果n_gpu_use为数字,则使用range生成list 68 | # 如果输入的是一个list,则默认使用list[0]作为controller 69 | """ 70 | if not n_gpu_use: 71 | device_type = 'cpu' 72 | else: 73 | n_gpu_use = n_gpu_use.split(",") 74 | device_type = f"cuda:{n_gpu_use[0]}" 75 | n_gpu = torch.cuda.device_count() 76 | if len(n_gpu_use) > 0 and n_gpu == 0: 77 | logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.") 78 | device_type = 'cpu' 79 | if len(n_gpu_use) > n_gpu: 80 | msg = f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are available on this machine." 
81 | logger.warning(msg) 82 | n_gpu_use = range(n_gpu) 83 | device = torch.device(device_type) 84 | list_ids = n_gpu_use 85 | return device, list_ids 86 | 87 | 88 | def model_device(n_gpu, model): 89 | ''' 90 | 判断环境 cpu还是gpu 91 | 支持单机多卡 92 | :param n_gpu: 93 | :param model: 94 | :return: 95 | ''' 96 | device, device_ids = prepare_device(n_gpu) 97 | if len(device_ids) > 1: 98 | logger.info(f"current {len(device_ids)} GPUs") 99 | model = torch.nn.DataParallel(model, device_ids=device_ids) 100 | if len(device_ids) == 1: 101 | os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0]) 102 | model = model.to(device) 103 | return model, device 104 | 105 | 106 | def restore_checkpoint(resume_path, model=None): 107 | ''' 108 | 加载模型 109 | :param resume_path: 110 | :param model: 111 | :param optimizer: 112 | :return: 113 | 注意: 如果是加载Bert模型的话,需要调整,不能使用该模式 114 | 可以使用模块自带的Bert_model.from_pretrained(state_dict = your save state_dict) 115 | ''' 116 | if isinstance(resume_path, Path): 117 | resume_path = str(resume_path) 118 | checkpoint = torch.load(resume_path) 119 | best = checkpoint['best'] 120 | start_epoch = checkpoint['epoch'] + 1 121 | states = checkpoint['state_dict'] 122 | if isinstance(model, nn.DataParallel): 123 | model.module.load_state_dict(states) 124 | else: 125 | model.load_state_dict(states) 126 | return [model,best,start_epoch] 127 | 128 | 129 | def save_pickle(data, file_path): 130 | ''' 131 | 保存成pickle文件 132 | :param data: 133 | :param file_name: 134 | :param pickle_path: 135 | :return: 136 | ''' 137 | if isinstance(file_path, Path): 138 | file_path = str(file_path) 139 | with open(file_path, 'wb') as f: 140 | pickle.dump(data, f) 141 | 142 | 143 | def load_pickle(input_file): 144 | ''' 145 | 读取pickle文件 146 | :param pickle_path: 147 | :param file_name: 148 | :return: 149 | ''' 150 | with open(str(input_file), 'rb') as f: 151 | data = pickle.load(f) 152 | return data 153 | 154 | 155 | def save_json(data, file_path): 156 | ''' 157 | 保存成json文件 158 | :param data: 159 | :param json_path: 160 | :param file_name: 161 | :return: 162 | ''' 163 | if not isinstance(file_path, Path): 164 | file_path = Path(file_path) 165 | # if isinstance(data,dict): 166 | # data = json.dumps(data) 167 | with open(str(file_path), 'w') as f: 168 | json.dump(data, f) 169 | 170 | 171 | def load_json(file_path): 172 | ''' 173 | 加载json文件 174 | :param json_path: 175 | :param file_name: 176 | :return: 177 | ''' 178 | if not isinstance(file_path, Path): 179 | file_path = Path(file_path) 180 | with open(str(file_path), 'r') as f: 181 | data = json.load(f) 182 | return data 183 | 184 | def save_model(model, model_path): 185 | """ 存储不含有显卡信息的state_dict或model 186 | :param model: 187 | :param model_name: 188 | :param only_param: 189 | :return: 190 | """ 191 | if isinstance(model_path, Path): 192 | model_path = str(model_path) 193 | if isinstance(model, nn.DataParallel): 194 | model = model.module 195 | state_dict = model.state_dict() 196 | for key in state_dict: 197 | state_dict[key] = state_dict[key].cpu() 198 | torch.save(state_dict, model_path) 199 | 200 | def load_model(model, model_path): 201 | ''' 202 | 加载模型 203 | :param model: 204 | :param model_name: 205 | :param model_path: 206 | :param only_param: 207 | :return: 208 | ''' 209 | if isinstance(model_path, Path): 210 | model_path = str(model_path) 211 | logging.info(f"loading model from {str(model_path)} .") 212 | states = torch.load(model_path) 213 | state = states['state_dict'] 214 | if isinstance(model, nn.DataParallel): 215 | model.module.load_state_dict(state) 216 | 
else: 217 | model.load_state_dict(state) 218 | return model 219 | 220 | 221 | class AverageMeter(object): 222 | ''' 223 | computes and stores the average and current value 224 | Example: 225 | >>> loss = AverageMeter() 226 | >>> for step,batch in enumerate(train_data): 227 | >>> pred = self.model(batch) 228 | >>> raw_loss = self.metrics(pred,target) 229 | >>> loss.update(raw_loss.item(),n = 1) 230 | >>> cur_loss = loss.avg 231 | ''' 232 | 233 | def __init__(self): 234 | self.reset() 235 | 236 | def reset(self): 237 | self.val = 0 238 | self.avg = 0 239 | self.sum = 0 240 | self.count = 0 241 | 242 | def update(self, val, n=1): 243 | self.val = val 244 | self.sum += val * n 245 | self.count += n 246 | self.avg = self.sum / self.count 247 | 248 | 249 | def summary(model, *inputs, batch_size=-1, show_input=True): 250 | ''' 251 | 打印模型结构信息 252 | :param model: 253 | :param inputs: 254 | :param batch_size: 255 | :param show_input: 256 | :return: 257 | Example: 258 | >>> print("model summary info: ") 259 | >>> for step,batch in enumerate(train_data): 260 | >>> summary(self.model,*batch,show_input=True) 261 | >>> break 262 | ''' 263 | 264 | def register_hook(module): 265 | def hook(module, input, output=None): 266 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 267 | module_idx = len(summary) 268 | 269 | m_key = f"{class_name}-{module_idx + 1}" 270 | summary[m_key] = OrderedDict() 271 | summary[m_key]["input_shape"] = list(input[0].size()) 272 | summary[m_key]["input_shape"][0] = batch_size 273 | 274 | if show_input is False and output is not None: 275 | if isinstance(output, (list, tuple)): 276 | for out in output: 277 | if isinstance(out, torch.Tensor): 278 | summary[m_key]["output_shape"] = [ 279 | [-1] + list(out.size())[1:] 280 | ][0] 281 | else: 282 | summary[m_key]["output_shape"] = [ 283 | [-1] + list(out[0].size())[1:] 284 | ][0] 285 | else: 286 | summary[m_key]["output_shape"] = list(output.size()) 287 | summary[m_key]["output_shape"][0] = batch_size 288 | 289 | params = 0 290 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 291 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 292 | summary[m_key]["trainable"] = module.weight.requires_grad 293 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 294 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 295 | summary[m_key]["nb_params"] = params 296 | 297 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model)): 298 | if show_input is True: 299 | hooks.append(module.register_forward_pre_hook(hook)) 300 | else: 301 | hooks.append(module.register_forward_hook(hook)) 302 | 303 | # create properties 304 | summary = OrderedDict() 305 | hooks = [] 306 | 307 | # register hook 308 | model.apply(register_hook) 309 | model(*inputs) 310 | 311 | # remove these hooks 312 | for h in hooks: 313 | h.remove() 314 | 315 | print("-----------------------------------------------------------------------") 316 | if show_input is True: 317 | line_new = f"{'Layer (type)':>25} {'Input Shape':>25} {'Param #':>15}" 318 | else: 319 | line_new = f"{'Layer (type)':>25} {'Output Shape':>25} {'Param #':>15}" 320 | print(line_new) 321 | print("=======================================================================") 322 | 323 | total_params = 0 324 | total_output = 0 325 | trainable_params = 0 326 | for layer in summary: 327 | # input_shape, output_shape, trainable, nb_params 328 | if show_input is True: 329 | line_new = "{:>25} {:>25} 
{:>15}".format( 330 | layer, 331 | str(summary[layer]["input_shape"]), 332 | "{0:,}".format(summary[layer]["nb_params"]), 333 | ) 334 | else: 335 | line_new = "{:>25} {:>25} {:>15}".format( 336 | layer, 337 | str(summary[layer]["output_shape"]), 338 | "{0:,}".format(summary[layer]["nb_params"]), 339 | ) 340 | 341 | total_params += summary[layer]["nb_params"] 342 | if show_input is True: 343 | total_output += np.prod(summary[layer]["input_shape"]) 344 | else: 345 | total_output += np.prod(summary[layer]["output_shape"]) 346 | if "trainable" in summary[layer]: 347 | if summary[layer]["trainable"] == True: 348 | trainable_params += summary[layer]["nb_params"] 349 | 350 | print(line_new) 351 | 352 | print("=======================================================================") 353 | print(f"Total params: {total_params:0,}") 354 | print(f"Trainable params: {trainable_params:0,}") 355 | print(f"Non-trainable params: {(total_params - trainable_params):0,}") 356 | print("-----------------------------------------------------------------------") 357 | -------------------------------------------------------------------------------- /pyw2v/config/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/config/basic_config.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | from pathlib import Path 3 | BASE_DIR = Path('pyw2v') 4 | 5 | configs = { 6 | 'data_path': BASE_DIR / 'dataset/raw/zhihu.txt', 7 | 'model_save_path': BASE_DIR / 'output/checkpoints/word2vec.pth', 8 | 9 | 'vocab_path': BASE_DIR / 'dataset/processed/vocab.pkl', # 语料数据 10 | 'pytorch_embedding_path': BASE_DIR / 'output/embedding/pytorch_word2vec2.bin', 11 | 'gensim_embedding_path':BASE_DIR / 'output/embedding/gensim_word2vec.bin', 12 | 13 | 'log_dir': BASE_DIR / 'output/log', 14 | 'figure_dir': BASE_DIR / 'output/figure', 15 | 'stopword_path': BASE_DIR / 'dataset/stopwords.txt' 16 | } 17 | -------------------------------------------------------------------------------- /pyw2v/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/dataset/processed/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/dataset/raw/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/feature/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/chinese-word2vec-pytorch/22c9b3f502824145871427cc247a06869645f6e7/pyw2v/io/__init__.py 
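Note: every entry script resolves its input and output locations through the configs dict in pyw2v/config/basic_config.py above, so retargeting the pipeline is a one-line change. A minimal sketch of pointing it at a different corpus before training (my_corpus.txt is a placeholder file name, not something shipped with the repo):

```python
from pathlib import Path

from pyw2v.config.basic_config import configs

# assumption: my_corpus.txt is your own whitespace-tokenised text file,
# placed under pyw2v/dataset/raw/ just like zhihu.txt in the README
configs['data_path'] = Path('pyw2v/dataset/raw/my_corpus.txt')

# the trained vectors are written to these paths by train_word2vec.py
# and train_gensim_word2vec.py respectively
print(configs['pytorch_embedding_path'])
print(configs['gensim_embedding_path'])
```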
-------------------------------------------------------------------------------- /pyw2v/io/data_transformer.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import numpy as np 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | 5 | class DataTransformer(object): 6 | def __init__(self, 7 | embedding_path): 8 | self.embedding_path = embedding_path 9 | self.reset() 10 | 11 | def reset(self): 12 | self.load_embedding() 13 | 14 | # 加载词向量矩阵 15 | def load_embedding(self, ): 16 | print(" load emebedding weights") 17 | self.embeddings_index = {} 18 | self.words = [] 19 | self.vectors = [] 20 | f = open(self.embedding_path, 'r',encoding = 'utf8') 21 | for line in f: 22 | values = line.split(' ') 23 | try: 24 | word = values[0] 25 | self.words.append(word) 26 | coefs = np.asarray(values[1:], dtype='float32') 27 | self.embeddings_index[word] = coefs 28 | self.vectors.append(coefs) 29 | except: 30 | print("Error on ", values[:2]) 31 | f.close() 32 | self.vectors = np.vstack(self.vectors) 33 | print('Total %s word vectors.' % len(self.embeddings_index)) 34 | 35 | # 计算相似度 36 | def get_similar_words(self, word, w_num=10): 37 | if word not in self.embeddings_index: 38 | raise ValueError('%d not in vocab') 39 | current_vector = self.embeddings_index[word] 40 | result = cosine_similarity(current_vector.reshape(1, -1), self.vectors) 41 | result = np.array(result).reshape(len(self.words), ) 42 | idxs = np.argsort(result)[::-1][:w_num] 43 | print("<<<" * 7) 44 | print(word) 45 | for i in idxs: 46 | print("%s : %.3f\n" % (self.words[i], result[i])) 47 | print(">>>" * 7) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /pyw2v/io/dataset.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import math 3 | import random 4 | import torch 5 | import numpy as np 6 | from collections import Counter 7 | from ..common.tools import save_pickle 8 | import operator 9 | 10 | class DataLoader(object): 11 | def __init__(self, 12 | min_freq, 13 | data_path, 14 | window_size, 15 | skip_header, 16 | negative_num, 17 | vocab_size, 18 | vocab_path, 19 | shuffle, 20 | seed, 21 | sample 22 | ): 23 | 24 | self.window_size = window_size 25 | self.negative_num = negative_num 26 | self.min_freq = min_freq 27 | self.shuffle = shuffle 28 | self.seed = seed 29 | self.sample = sample 30 | self.data_path = data_path 31 | self.vocab_path = vocab_path 32 | self.skip_header = skip_header 33 | self.vocab_size = vocab_size 34 | self.random_s = np.random.RandomState(seed) 35 | self.build_examples() 36 | self.build_vocab() 37 | self.build_negative_sample_table() 38 | self.subsampling() 39 | 40 | # 分割数据 41 | def split_sent(self,line): 42 | res = line.split() 43 | return res 44 | 45 | # 将词转化为id 46 | def word_to_id(self,word, vocab): 47 | return vocab[word][0] if word in vocab else vocab[''][0] 48 | 49 | # 读取数据,并进行预处理 50 | def build_examples(self): 51 | self.examples = [] 52 | print('read data and processing') 53 | with open(self.data_path, 'r') as fr: 54 | for i, line in enumerate(fr): 55 | # 数据首行为列名 56 | if i == 0 and self.skip_header: 57 | continue 58 | line = line.strip("\n") 59 | if line: 60 | self.examples.append(self.split_sent(line)) 61 | 62 | # 建立语料库 63 | def build_vocab(self): 64 | count = Counter() 65 | print("build vocab") 66 | for words in self.examples: 67 | count.update(words) 68 | count = {k: v for k, v in count.items()} 69 | count = sorted(count.items(), 
key=operator.itemgetter(1),reverse=True) 70 | all_words = [(w[0],w[1]) for w in count if w[1] >= self.min_freq] 71 | if self.vocab_size: 72 | all_words = all_words[:self.vocab_size] 73 | all_words = all_words+[('',0)] 74 | word2id = {k: (i,v) for i,(k, v) in zip(range(0, len(all_words)),all_words)} 75 | self.word_frequency = {tu[0]: tu[1] for word, tu in word2id.items()} 76 | self.vocab = {word: tu[0] for word, tu in word2id.items()} 77 | print(f"vocab size: {len(self.vocab)}") 78 | save_pickle(data = word2id,file_path=self.vocab_path) 79 | 80 | # 构建负样本 81 | def build_negative_sample_table(self): 82 | self.negative_sample_table = [] 83 | sample_table_size = 1e8 84 | pow_frequency = np.array(list(self.word_frequency.values())) ** 0.75 85 | words_pow = sum(pow_frequency) 86 | ratio = pow_frequency / words_pow 87 | count = np.round(ratio * sample_table_size) 88 | for wid, c in enumerate(count): 89 | self.negative_sample_table += [wid] * int(c) 90 | self.negative_sample_table = np.array(self.negative_sample_table) 91 | 92 | def reserve_ratio(self,p,total): 93 | tmp_p = (math.sqrt( p / self.sample) + 1 ) * self.sample / p 94 | if tmp_p >1: 95 | tmp_p = 1 96 | return tmp_p * total 97 | 98 | # 数据采样,降低高频词的出现 99 | def subsampling(self,total = 2 ** 32): 100 | pow_frequency = np.array(list(self.word_frequency.values())) 101 | words_pow = sum(pow_frequency) 102 | ratio = pow_frequency / words_pow 103 | delete_int = [self.reserve_ratio(p,total = total) for p in ratio] 104 | 105 | self.train_examples = [] 106 | for example in self.examples: 107 | words = [self.vocab[word] for word in example if 108 | word in self.vocab and delete_int[self.vocab[word]] >= random.random() * total] 109 | if len(words) > 0: 110 | self.train_examples.append(words) 111 | del self.examples 112 | 113 | # 负样本 114 | def get_neg_word(self,u): 115 | neg_v = [] 116 | while len(neg_v) < self.negative_num: 117 | n_w = np.random.choice(self.negative_sample_table,size = self.negative_num).tolist()[0] 118 | if n_w != u: 119 | neg_v.append(n_w) 120 | return neg_v 121 | 122 | # 构建skip gram模型样本 123 | def make_iter(self): 124 | for example in self.train_examples: 125 | if len(example) < 2: 126 | continue 127 | reduced_window = self.random_s.randint(self.window_size) 128 | for i,w in enumerate(example): 129 | words_num = len(example) 130 | window_start = max(0, i - self.window_size + reduced_window) 131 | window_end = min(words_num, i + self.window_size + 1 - reduced_window) 132 | pos_v = [example[j] for j in range(window_start, window_end) if j != i] 133 | pos_u = [w] * len(pos_v) 134 | neg_u = [c for c in pos_v for _ in range(self.negative_num)] 135 | neg_v = [v for u in pos_u for v in self.get_neg_word(u)] 136 | yield (torch.tensor(pos_u,dtype=torch.long), 137 | torch.tensor(pos_v,dtype=torch.long), 138 | torch.tensor(neg_u,dtype=torch.long), 139 | torch.tensor(neg_v,dtype=torch.long)) 140 | 141 | def __len__(self): 142 | return len([w for ex in self.train_examples for w in ex if len(ex) >=2]) 143 | 144 | -------------------------------------------------------------------------------- /pyw2v/model/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/model/nn/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/model/nn/gensim_word2vec.py: 
-------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | from gensim.models import word2vec 3 | class Word2Vec(): 4 | def __init__(self,size, 5 | sg, 6 | iter, 7 | seed, 8 | save_path, 9 | num_workers, 10 | window, 11 | min_count): 12 | 13 | self.size=size 14 | self.sg = sg 15 | self.seed = seed 16 | self.iter = iter 17 | self.window = window 18 | self.min_count = min_count 19 | self.workers = num_workers 20 | self.save_path = save_path 21 | 22 | def train_w2v(self, data): 23 | model = word2vec.Word2Vec(data, 24 | size=self.size, 25 | window=self.window, 26 | sg=self.sg, 27 | min_count=self.min_count, 28 | workers=self.workers, 29 | seed=self.seed, 30 | compute_loss=True, 31 | iter= self.iter) 32 | print(model.get_latest_training_loss()) 33 | with open(self.save_path,'w') as fw: 34 | for word in model.wv.vocab: 35 | vector = model[word] 36 | fw.write(str(word) + ' ' + ' '.join(map(str, vector)) + '\n') 37 | -------------------------------------------------------------------------------- /pyw2v/model/nn/skip_gram.py: -------------------------------------------------------------------------------- 1 | #encoding:Utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class SkipGram(torch.nn.Module): 7 | def __init__(self, embedding_dim, vocab_size): 8 | super(SkipGram, self).__init__() 9 | initrange = 0.5 / embedding_dim 10 | self.u_embedding_matrix = nn.Embedding(vocab_size,embedding_dim) 11 | self.u_embedding_matrix.weight.data.uniform_(-initrange,initrange) 12 | self.v_embedding_matrix = nn.Embedding(vocab_size,embedding_dim) 13 | self.v_embedding_matrix.weight.data.uniform_(-0, 0) 14 | 15 | def forward(self, pos_u, pos_v,neg_u, neg_v): 16 | embed_pos_u = self.v_embedding_matrix(pos_u) 17 | embed_pos_v = self.u_embedding_matrix(pos_v) 18 | score = torch.mul(embed_pos_u, embed_pos_v) 19 | score = torch.sum(score,dim = 1) 20 | log_target = F.logsigmoid(score).squeeze() 21 | 22 | embed_neg_u = self.u_embedding_matrix(neg_u) 23 | embed_neg_v = self.v_embedding_matrix(neg_v) 24 | 25 | neg_score = torch.mul(embed_neg_u,embed_neg_v) 26 | neg_score = torch.sum(neg_score, dim=1) 27 | sum_log_sampled = F.logsigmoid(-1 * neg_score).squeeze() 28 | 29 | loss = log_target.sum() + sum_log_sampled.sum() 30 | loss = -1 * loss 31 | return loss 32 | -------------------------------------------------------------------------------- /pyw2v/output/checkpoints/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/output/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/output/feature/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/output/figure/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/output/log/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- 
/pyw2v/output/result/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | -------------------------------------------------------------------------------- /pyw2v/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/chinese-word2vec-pytorch/22c9b3f502824145871427cc247a06869645f6e7/pyw2v/preprocessing/__init__.py -------------------------------------------------------------------------------- /pyw2v/preprocessing/preprocessor.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import re 3 | class Preprocessor(object): 4 | def __init__(self,min_len = 2,stopwords_path = None): 5 | self.min_len = min_len 6 | self.stopwords_path = stopwords_path 7 | self.reset() 8 | 9 | def reset(self): 10 | if self.stopwords_path: 11 | with open(self.stopwords_path,'r') as fr: 12 | self.stopwords = {} 13 | for line in fr: 14 | word = line.strip(' ').strip('\n') 15 | self.stopwords[word] = 1 16 | 17 | # 去除长度小于min_len的文本 18 | def clean_length(self,x): 19 | if len(x.split(" ")) >= self.min_len: 20 | return x 21 | 22 | #去除停用词 23 | def remove_stopword(self,sentence): 24 | words = sentence.split() 25 | x = [word for word in words if word not in self.stopwords] 26 | return " ".join(x) 27 | 28 | # 删除数字 29 | def remove_numbers(self,sentence): 30 | words = sentence.split() 31 | x = [re.sub('\d+','',word) for word in words] 32 | return ' '.join([w for w in x if w !='']) 33 | 34 | # 移除中文 35 | def get_china(self,sentence): 36 | zhmodel = re.compile("[\u4e00-\u9fa5]") 37 | words = sentence.split() 38 | china_list = [] 39 | for word in words: 40 | match = zhmodel.search(word) 41 | if match: 42 | china_list.append(word) 43 | return ' '.join(china_list) 44 | 45 | def __call__(self, sentence): 46 | # TorchText returns a list of words instead of a normal sentence. 47 | # First, create the sentence again. Then, do preprocess. 
Finally, return the preprocessed sentence as list 48 | # of words 49 | x = sentence 50 | if self.stopwords_path: 51 | x = self.remove_stopword(x) 52 | x = self.remove_numbers(x) 53 | x = self.get_china(x) 54 | x = self.clean_length(x) 55 | return x 56 | -------------------------------------------------------------------------------- /pyw2v/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lonePatient/chinese-word2vec-pytorch/22c9b3f502824145871427cc247a06869645f6e7/pyw2v/train/__init__.py -------------------------------------------------------------------------------- /pyw2v/train/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..common.tools import AverageMeter 3 | from ..callback.progressbar import ProgressBar 4 | from ..common.tools import model_device 5 | 6 | # 训练包装器 7 | class Trainer(object): 8 | def __init__(self,model, 9 | epochs, 10 | logger, 11 | n_gpu, 12 | vocab, 13 | model_save_path, 14 | vector_save_path, 15 | optimizer, 16 | lr_scheduler, 17 | training_monitor, 18 | verbose = 1): 19 | self.model = model 20 | self.epochs = epochs 21 | self.optimizer = optimizer 22 | self.logger = logger 23 | self.verbose = verbose 24 | self.training_monitor = training_monitor 25 | self.lr_scheduler = lr_scheduler 26 | self.n_gpu = n_gpu 27 | self.vocab = vocab 28 | self.vector_save_path = vector_save_path 29 | self.model_save_path = model_save_path 30 | 31 | self.model, self.device = model_device(n_gpu, model=self.model) 32 | self.start_epoch = 1 33 | 34 | 35 | def _save_info(self): 36 | state = { 37 | 'epoch': self.epochs, 38 | 'state_dict': self.model.state_dict(), 39 | 'optimizer': self.optimizer.state_dict(), 40 | } 41 | return state 42 | 43 | def save(self): 44 | id_word = {value:key for key ,value in self.vocab.items()} 45 | state = self._save_info() 46 | torch.save(state, self.model_save_path) 47 | self.logger.info('saving word2vec vector') 48 | metrix = self.model.v_embedding_matrix.weight.data 49 | with open(self.vector_save_path, "w", encoding="utf-8") as f: 50 | if self.device=='cpu': 51 | vector = metrix.numpy() 52 | else: 53 | vector = metrix.cpu().numpy() 54 | for i in range(len(vector)): 55 | if i % 1000 == 0: 56 | print(f'saving {i} word vector') 57 | word = id_word[i] 58 | s_vec = vector[i] 59 | s_vec = [str(s) for s in s_vec.tolist()] 60 | write_line = word + " " + " ".join(s_vec)+"\n" 61 | f.write(write_line) 62 | 63 | # epoch训练 64 | def train_epoch(self,train_data): 65 | pbar = ProgressBar(n_batch=len(train_data)) 66 | train_loss = AverageMeter() 67 | self.model.train() 68 | assert self.model.training 69 | train_examples = train_data.make_iter() 70 | for step,batch in enumerate(train_examples): 71 | batch = tuple(t.to(self.device) for t in batch) 72 | pos_u, pos_v, neg_u, neg_v = batch 73 | self.optimizer.zero_grad() 74 | loss = self.model(pos_u, pos_v, neg_u, neg_v) 75 | loss.backward() 76 | self.optimizer.step() 77 | pbar.batch_step(batch_idx=step, info={'loss': loss.item()}) 78 | train_loss.update(loss.item(),n = 1) 79 | print(" ") 80 | result = {'loss':train_loss.avg} 81 | if 'cuda' in str(self.device): 82 | torch.cuda.empty_cache() 83 | return result 84 | 85 | def train(self,train_data): 86 | for epoch in range(self.start_epoch,self.start_epoch+self.epochs): 87 | print(f"Epoch {epoch}/{self.start_epoch + self.epochs - 1}") 88 | train_log = self.train_epoch(train_data) 89 | 90 | show_info = f'\nEpoch: {epoch} - ' + 
"-".join([f' {key}: {value:.4f} ' for key, value in train_log.items()]) 91 | self.logger.info(show_info) 92 | 93 | if hasattr(self.lr_scheduler, 'epoch_step'): 94 | self.lr_scheduler.epoch_step(epoch) 95 | 96 | if self.training_monitor: 97 | self.training_monitor.epoch_step(train_log) 98 | self.save() 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /train_gensim_word2vec.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import argparse 3 | from pyw2v.common.tools import logger,init_logger,seed_everything 4 | from pyw2v.config.basic_config import configs as config 5 | from pyw2v.model.nn import gensim_word2vec 6 | from pyw2v.preprocessing.preprocessor import Preprocessor 7 | 8 | 9 | def run(args): 10 | logger.info('load data from disk' ) 11 | processing = Preprocessor(min_len=2,stopwords_path=config['stopword_path']) 12 | examples = [] 13 | with open(config['data_path'], 'r') as fr: 14 | for i, line in enumerate(fr): 15 | # 数据首行为列名 16 | if i == 0 and False: 17 | continue 18 | line = line.strip("\n") 19 | line = processing(line) 20 | if line: 21 | examples.append(line.split()) 22 | logger.info("initializing emnedding model") 23 | word2vec_model = gensim_word2vec.Word2Vec(sg = 1, 24 | iter = 10, 25 | size=args.embedd_dim, 26 | window=args.window_size, 27 | min_count=args.min_freq, 28 | save_path=config['gensim_embedding_path'], 29 | num_workers=args.num_workers, 30 | seed = args.seed) 31 | word2vec_model.train_w2v([[word for word in document] for document in examples]) 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='Gensim Word2Vec model training') 35 | parser.add_argument("--model", type=str, default='gensim_word2vec') 36 | parser.add_argument("--task", type=str, default='training word vector') 37 | parser.add_argument('--seed', default=2018, type=int, 38 | help='Seed for initializing training.') 39 | parser.add_argument('--resume', default=False, type=bool, 40 | help='Choose whether resume checkpoint model') 41 | parser.add_argument('--embedd_dim', default=300, type=int) 42 | parser.add_argument('--spochs', default=6, type=int) 43 | parser.add_argument('--window_size', default=5, str=int) 44 | parser.add_argument('--n_gpu', default='0', type=str) 45 | parser.add_argument('--min_freq', default=5, type=int) 46 | parser.add_argument('--sample', default=1e-3, type=float) 47 | parser.add_argument('--negative_sample_num', default=5, type=int) 48 | parser.add_argument('--learning_rate', default=0.025, type=float) 49 | parser.add_argument('--weight_decay', default=5e-4, type=float) 50 | parser.add_argument('--vocab_size', default=30000000, type=int) 51 | parser.add_argument('--num_workers',default=10) 52 | args = parser.parse_args() 53 | init_logger(log_file=config['log_dir'] / (args.model + ".log")) 54 | logger.info("seed is %d" % args['seed']) 55 | seed_everything(seed=args['seed']) 56 | run(args) 57 | if __name__ == "__main__": 58 | 59 | main() 60 | 61 | -------------------------------------------------------------------------------- /train_word2vec.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import argparse 3 | import torch 4 | import warnings 5 | from torch import optim 6 | from pyw2v.train.trainer import Trainer 7 | from pyw2v.io.dataset import DataLoader 8 | from pyw2v.model.nn.skip_gram import SkipGram 9 | from pyw2v.common.tools import init_logger, logger 10 | 
/train_word2vec.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import argparse 3 | import torch 4 | import warnings 5 | from torch import optim 6 | from pyw2v.train.trainer import Trainer 7 | from pyw2v.io.dataset import DataLoader 8 | from pyw2v.model.nn.skip_gram import SkipGram 9 | from pyw2v.common.tools import init_logger, logger 10 | from pyw2v.common.tools import seed_everything 11 | from pyw2v.config.basic_config import configs as config 12 | from pyw2v.callback.lrscheduler import StepLR 13 | from pyw2v.callback.trainingmonitor import TrainingMonitor 14 | 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def run(args): 19 | # **************************** Load the dataset **************************** 20 | logger.info('loading train data from disk') 21 | train_dataset = DataLoader(skip_header=False, 22 | negative_num=args.negative_sample_num, 23 | window_size=args.window_size, 24 | data_path=config['data_path'], 25 | vocab_path=config['vocab_path'], 26 | vocab_size=args.vocab_size, 27 | min_freq=args.min_freq, 28 | shuffle=True, 29 | seed=args.seed, 30 | sample=args.sample) 31 | 32 | # **************************** Model and optimizer *********************** 33 | logger.info("initializing model") 34 | model = SkipGram(embedding_dim=args.embedd_dim, vocab_size=len(train_dataset.vocab)) 35 | optimizer = optim.SGD(params=model.parameters(), lr=args.learning_rate) 36 | 37 | # **************************** Callbacks *********************** 38 | logger.info("initializing callbacks") 39 | train_monitor = TrainingMonitor(file_dir=config['figure_dir'], arch=args.model) 40 | lr_scheduler = StepLR(optimizer=optimizer,lr=args.learning_rate, epochs=args.epochs) 41 | 42 | # **************************** Training *********************** 43 | logger.info('training model....') 44 | trainer = Trainer(model=model, 45 | vocab=train_dataset.vocab, 46 | optimizer=optimizer, 47 | epochs=args.epochs, 48 | logger=logger, 49 | training_monitor=train_monitor, 50 | lr_scheduler=lr_scheduler, 51 | n_gpu=args.n_gpus, 52 | model_save_path=config['model_save_path'], 53 | vector_save_path=config['pytorch_embedding_path'] 54 | ) 55 | trainer.train(train_data=train_dataset) 56 | 57 | 58 | def main(): 59 | parser = argparse.ArgumentParser(description='PyTorch Word2Vec model training') 60 | parser.add_argument("--model", type=str, default='skip_gram') 61 | parser.add_argument("--task", type=str, default='training word vector') 62 | parser.add_argument('--seed', default=2018, type=int, 63 | help='Seed for initializing training.') 64 | parser.add_argument('--resume', default=False, type=bool, 65 | help='Choose whether to resume from a checkpointed model') 66 | parser.add_argument('--embedd_dim', default=300, type=int) 67 | parser.add_argument('--epochs', default=6, type=int) 68 | parser.add_argument('--window_size', default=5, type=int) 69 | parser.add_argument('--n_gpus', default='0', type=str) 70 | parser.add_argument('--min_freq', default=5, type=int) 71 | parser.add_argument('--sample', default=1e-3, type=float) 72 | parser.add_argument('--negative_sample_num', default=5, type=int) 73 | parser.add_argument('--learning_rate', default=0.025, type=float) 74 | parser.add_argument('--weight_decay', default=5e-4, type=float) 75 | parser.add_argument('--vocab_size', default=30000000, type=int) 76 | args = parser.parse_args() 77 | init_logger(log_file=config['log_dir'] / (args.model + ".log")) 78 | logger.info(f"seed is {args.seed}") 79 | seed_everything(seed=args.seed) 80 | run(args) 81 | 82 | if __name__ == '__main__': 83 | main() 84 | --------------------------------------------------------------------------------
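The repository's `get_similar_words.py` (not shown in this listing) is the intended entry point for querying the trained embeddings. As a rough, self-contained sketch, the vector file written by `Trainer.save` above (one `word v1 v2 ... vn` line per word) can be loaded and queried by cosine similarity as follows; this is an illustrative example, not the repository's script, and the file name in the usage comment is hypothetical (in practice the path comes from `config['pytorch_embedding_path']`).

```python
import numpy as np


def load_vectors(path):
    """Read the plain-text format written by Trainer.save: one word per line,
    followed by its vector components separated by single spaces."""
    words, vectors = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(" ")
            if len(parts) < 2:
                continue
            words.append(parts[0])
            vectors.append(np.asarray(parts[1:], dtype=np.float32))
    matrix = np.vstack(vectors)
    # L2-normalise the rows so that a dot product equals cosine similarity.
    matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
    return words, matrix


def most_similar(query, words, matrix, topn=10):
    """Return the topn words closest to `query` by cosine similarity."""
    idx = words.index(query)  # raises ValueError if the word is out of vocabulary
    sims = matrix @ matrix[idx]
    order = np.argsort(-sims)
    return [(words[i], float(sims[i])) for i in order if i != idx][:topn]


# Hypothetical usage; the actual output path is configured in pyw2v/config/basic_config.py:
# words, matrix = load_vectors("pyw2v/output/embedding/skip_gram_vectors.txt")
# print(most_similar("中国", words, matrix))
```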