├── .gitignore
├── README.md
├── capsule_layer.py
├── data
│   └── .gitignore
├── m.py
├── main.py
├── output
│   └── .gitignore
├── rnn_revised.py
└── save
    └── .gitignore

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | .idea
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet PyTorch implementation (multi-class text classification)
2 | 
3 | CapsNet based on Geoffrey Hinton's original paper
4 | [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829)
5 | 
6 | [Understanding the CapsNet architecture and implementing it in TensorFlow: a full analysis of Hinton's capsules (in Chinese)](https://www.jiqizhixin.com/articles/2017-11-05)
7 | 
8 | ## Requirements
9 | 
10 | - python 3.6+
11 | - pytorch 0.4.1+
12 | - gensim
13 | - tqdm
14 | 
15 | ## Run
16 | 
17 | ```bash
18 | python main.py
19 | ```
20 | Training and test datasets should be placed in the data folder.
21 | 
22 | ## DIY
23 | 
24 | If you need hard_sigmoid for the GRU gates, just uncomment
25 | ```python
26 | from rnn_revised import *
27 | ```
28 | in capsule_layer.py. You can also switch to whatever activation function
29 | or dropout/recurrent_dropout ratio you want by editing rnn_revised.py.
30 | One more thing: the revised version is non-CUDA (CPU only); if you find a way
31 | to make a CUDA version work, please let me know.
32 | 
33 | Notes:
34 | 1. In PrimaryCapsLayer, squash operates on vectors of size [batch_size, 1152, 8] and squashes along the last dimension (the capsule dimension, 8).
35 | The squash factor |s_j|^2 / (1 + |s_j|^2) / |s_j| has size [batch_size, 1152]; it is simply multiplied back onto the original input vectors.
36 | 
37 | 2. If reconstruction is True, the loss is the sum of two parts, margin_loss and reconstruction_loss:
38 | ```python 39 | output, probs = model(data, target) 40 | reconstruction_loss = F.mse_loss(output, data.view(-1, 784)) 41 | margin_loss = loss_fn(probs, target) 42 | # 如果reconstruction为True,则loss由两部分组成margin_loss和reconstruction_loss 43 | loss = reconstruction_alpha * reconstruction_loss + margin_loss 44 | ``` 45 | -------------------------------------------------------------------------------- /capsule_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | author = 'BinZhou' 5 | nick_name = '发送小信号' 6 | mtime = '2018/10/19' 7 | 8 | import torch as t 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | # from rnn_revised import * 13 | 14 | USE_CUDA = True 15 | embedding_dim = 300 16 | embedding_path = '../save/embedding_matrix.npy' # or False, not use pre-trained-matrix 17 | use_pretrained_embedding = True 18 | BATCH_SIZE = 128 19 | gru_len = 128 20 | Routings = 5 21 | Num_capsule = 10 22 | Dim_capsule = 16 23 | dropout_p = 0.25 24 | rate_drop_dense = 0.28 25 | LR = 0.001 26 | T_epsilon = 1e-7 27 | num_classes = 30 28 | 29 | 30 | class Embed_Layer(nn.Module): 31 | def __init__(self, embedding_matrix=None, vocab_size=None, embedding_dim=300): 32 | super(Embed_Layer, self).__init__() 33 | self.encoder = nn.Embedding(vocab_size + 1, embedding_dim) 34 | if use_pretrained_embedding: 35 | # self.encoder.weight.data.copy_(t.from_numpy(np.load(embedding_path))) # 方法一,加载np.save的npy文件 36 | self.encoder.weight.data.copy_(t.from_numpy(embedding_matrix)) # 方法二 37 | 38 | def forward(self, x, dropout_p=0.25): 39 | return nn.Dropout(p=dropout_p)(self.encoder(x)) 40 | 41 | 42 | class GRU_Layer(nn.Module): 43 | def __init__(self): 44 | super(GRU_Layer, self).__init__() 45 | self.gru = nn.GRU(input_size=300, 46 | hidden_size=gru_len, 47 | bidirectional=True) 48 | ''' 49 | 自己修改GRU里面的激活函数及加dropout和recurrent_dropout 50 | 如果要使用,把rnn_revised import进来,但好像是使用cpu跑的,比较慢 51 | ''' 52 | # # if you uncomment /*from rnn_revised import * */, uncomment following code aswell 53 | # self.gru = RNNHardSigmoid('GRU', input_size=300, 54 | # hidden_size=gru_len, 55 | # bidirectional=True) 56 | 57 | # 这步很关键,需要像keras一样用glorot_uniform和orthogonal_uniform初始化参数 58 | def init_weights(self): 59 | ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name) 60 | hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name) 61 | b = (param.data for name, param in self.named_parameters() if 'bias' in name) 62 | for k in ih: 63 | nn.init.xavier_uniform_(k) 64 | for k in hh: 65 | nn.init.orthogonal_(k) 66 | for k in b: 67 | nn.init.constant_(k, 0) 68 | 69 | def forward(self, x): 70 | return self.gru(x) 71 | 72 | 73 | # core caps_layer with squash func 74 | class Caps_Layer(nn.Module): 75 | def __init__(self, input_dim_capsule=gru_len * 2, num_capsule=Num_capsule, dim_capsule=Dim_capsule, \ 76 | routings=Routings, kernel_size=(9, 1), share_weights=True, 77 | activation='default', **kwargs): 78 | super(Caps_Layer, self).__init__(**kwargs) 79 | 80 | self.num_capsule = num_capsule 81 | self.dim_capsule = dim_capsule 82 | self.routings = routings 83 | self.kernel_size = kernel_size # 暂时没用到 84 | self.share_weights = share_weights 85 | if activation == 'default': 86 | self.activation = self.squash 87 | else: 88 | self.activation = nn.ReLU(inplace=True) 89 | 90 | if self.share_weights: 91 | self.W = nn.Parameter( 92 | nn.init.xavier_normal_(t.empty(1, input_dim_capsule, self.num_capsule * 
self.dim_capsule))) 93 | else: 94 | self.W = nn.Parameter( 95 | t.randn(BATCH_SIZE, input_dim_capsule, self.num_capsule * self.dim_capsule)) # 64即batch_size 96 | 97 | def forward(self, x): 98 | 99 | if self.share_weights: 100 | u_hat_vecs = t.matmul(x, self.W) 101 | else: 102 | print('add later') 103 | 104 | batch_size = x.size(0) 105 | input_num_capsule = x.size(1) 106 | u_hat_vecs = u_hat_vecs.view((batch_size, input_num_capsule, 107 | self.num_capsule, self.dim_capsule)) 108 | u_hat_vecs = u_hat_vecs.permute(0, 2, 1, 3) # 转成(batch_size,num_capsule,input_num_capsule,dim_capsule) 109 | b = t.zeros_like(u_hat_vecs[:, :, :, 0]) # (batch_size,num_capsule,input_num_capsule) 110 | 111 | for i in range(self.routings): 112 | b = b.permute(0, 2, 1) 113 | c = F.softmax(b, dim=2) 114 | c = c.permute(0, 2, 1) 115 | b = b.permute(0, 2, 1) 116 | outputs = self.activation(t.einsum('bij,bijk->bik', (c, u_hat_vecs))) # batch matrix multiplication 117 | # outputs shape (batch_size, num_capsule, dim_capsule) 118 | if i < self.routings - 1: 119 | b = t.einsum('bik,bijk->bij', (outputs, u_hat_vecs)) # batch matrix multiplication 120 | return outputs # (batch_size, num_capsule, dim_capsule) 121 | 122 | # text version of squash, slight different from original one 123 | def squash(self, x, axis=-1): 124 | s_squared_norm = (x ** 2).sum(axis, keepdim=True) 125 | scale = t.sqrt(s_squared_norm + T_epsilon) 126 | return x / scale 127 | 128 | 129 | class Dense_Layer(nn.Module): 130 | def __init__(self): 131 | super(Dense_Layer, self).__init__() 132 | self.fc = nn.Sequential( 133 | nn.Dropout(p=dropout_p, inplace=True), 134 | nn.Linear(Num_capsule * Dim_capsule, num_classes), # num_capsule*dim_capsule -> num_classes 135 | nn.Sigmoid() 136 | ) 137 | 138 | def forward(self, x): 139 | batch_size = x.size(0) 140 | x = x.view(batch_size, -1) 141 | return self.fc(x) 142 | 143 | 144 | # capsule如果单纯做分类则不需要重构(reconstruction) 145 | # 如果就用在分类里面,decoder用不到,不需要reconstruction 146 | 147 | class Capsule_Main(nn.Module): 148 | def __init__(self, embedding_matrix=None, vocab_size=None): 149 | super(Capsule_Main, self).__init__() 150 | self.embed_layer = Embed_Layer(embedding_matrix, vocab_size) 151 | self.gru_layer = GRU_Layer() 152 | # 【重要】初始化GRU权重操作,这一步非常关键,acc上升到0.98,如果用默认的uniform初始化则acc一直在0.5左右 153 | self.gru_layer.init_weights() 154 | self.caps_layer = Caps_Layer() 155 | self.dense_layer = Dense_Layer() 156 | 157 | def forward(self, content): 158 | content1 = self.embed_layer(content) 159 | content2, _ = self.gru_layer( 160 | content1) # 这个输出是个tuple,一个output(batch_size, seq_len, num_directions * hidden_size),一个hn 161 | content3 = self.caps_layer(content2) 162 | output = self.dense_layer(content3) 163 | return output -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /m.py: -------------------------------------------------------------------------------- 1 | author = 'binzhou' 2 | mdtime = '2018/10/15' 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch as t 8 | import time 9 | 10 | def f1_for_car(df_true:pd.DataFrame, df_pred:pd.DataFrame): 11 | 12 | ''' 13 | f1评分标准 for 汽车行业用户观点主题及情感识别(DF竞赛) 14 | ''' 15 | 16 | 17 | Tp, Fp, Fn = 0, 0, 0 18 | for cnt_id in set(df_true.content_id): 19 | y_true = df_true[df_true.content_id==cnt_id].copy(deep=True) 20 | y_pred = df_pred[df_pred.content_id==cnt_id].copy(deep=True) 21 | if len(y_true) > len(y_pred): 22 | Fn += len(y_true) - len(y_pred) 23 | tp = y_pred.merge(y_true,on=['subject','sentiment_value'],how='left').content_id_y.notnull().sum() 24 | Tp += tp 25 | Fp += len(y_pred) - tp 26 | elif len(y_true) < len(y_pred): 27 | Fp += len(y_pred) - len(y_true) 28 | tp = y_true.merge(y_pred,on=['subject','sentiment_value'],how='left').content_id_y.notnull().sum() 29 | Tp += tp 30 | Fp += len(y_true) - tp 31 | else: 32 | tp = y_true.merge(y_pred,on=['subject','sentiment_value'],how='left').content_id_y.notnull().sum() 33 | Tp += tp 34 | Fp += len(y_true) - tp 35 | P = Tp*1.0/(Tp+Fp) 36 | R = Tp*1.0/(Tp+Fn) 37 | return 2*P*R/(P+R) 38 | 39 | 40 | 41 | class BOW(object): 42 | def __init__(self, X, min_count=10, maxlen=100): 43 | """ 44 | X: [[w1, w2],]] 45 | """ 46 | self.X = X 47 | self.min_count = min_count 48 | self.maxlen = maxlen 49 | self.__word_count() 50 | self.__idx() 51 | self.__doc2num() 52 | 53 | def __word_count(self): 54 | wc = {} 55 | for ws in tqdm(self.X, desc=' Word Count'): 56 | for w in ws: 57 | if w in wc: 58 | wc[w] += 1 59 | else: 60 | wc[w] = 1 61 | self.word_count = {i: j for i, j in wc.items() if j >= self.min_count} 62 | 63 | def __idx(self): 64 | self.idx2word = {i + 1: j for i, j in enumerate(self.word_count)} 65 | self.word2idx = {j: i for i, j in self.idx2word.items()} 66 | 67 | def __doc2num(self): 68 | doc2num = [] 69 | for text in tqdm(self.X, desc='Doc To Number'): 70 | s = [self.word2idx.get(i, 0) for i in text[:self.maxlen]] 71 | doc2num.append(s + [0]*(self.maxlen-len(s))) # 未登录词全部用0表示 72 | self.doc2num = np.asarray(doc2num) 73 | 74 | class BasicModule(t.nn.Module): 75 | ''' 76 | 封装了nn.Module,主要是提供了save和load两个方法 77 | ''' 78 | 79 | def 
__init__(self): 80 | super(BasicModule,self).__init__() 81 | self.model_name=str(type(self))# 默认名字 82 | 83 | def load(self, path,change_opt=True): 84 | print(path) 85 | data = t.load(path) 86 | if 'opt' in data: 87 | # old_opt_stats = self.opt.state_dict() 88 | if change_opt: 89 | 90 | self.opt.parse(data['opt'],print_=False) 91 | self.opt.embedding_path=None 92 | self.__init__(self.opt) 93 | # self.opt.parse(old_opt_stats,print_=False) 94 | self.load_state_dict(data['d']) 95 | else: 96 | self.load_state_dict(data) 97 | return self.cuda() 98 | 99 | def save(self, name=None,new=False): 100 | prefix = 'checkpoints/' + self.model_name + '_' +self.opt.type_+'_' 101 | if name is None: 102 | name = time.strftime('%m%d_%H:%M:%S.pth') 103 | path = prefix+name 104 | 105 | if new: 106 | data = {'opt':self.opt.state_dict(),'d':self.state_dict()} 107 | else: 108 | data=self.state_dict() 109 | 110 | t.save(data, path) 111 | return path 112 | 113 | def get_optimizer(self,lr1,lr2=0,weight_decay = 0): 114 | ignored_params = list(map(id, self.encoder.parameters())) 115 | base_params = filter(lambda p: id(p) not in ignored_params, 116 | self.parameters()) 117 | if lr2 is None: lr2 = lr1*0.5 118 | optimizer = t.optim.Adam([ 119 | dict(params=base_params,weight_decay = weight_decay,lr=lr1), 120 | {'params': self.encoder.parameters(), 'lr': lr2} 121 | ]) 122 | return optimizer -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | author = 'BinZhou' 5 | nick_name = '发送小信号' 6 | mtime = '2018/10/19' 7 | 8 | import torch.utils.data as Data 9 | import torch as t 10 | from torch import nn 11 | from torch.optim import Adam 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import pandas as pd 15 | import jieba 16 | import gensim 17 | from gensim.models import Word2Vec, FastText 18 | from collections import Counter 19 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 20 | # tfidf or countvec for lr or svm 21 | from sklearn.linear_model import LogisticRegression 22 | from sklearn.svm import SVC 23 | from sklearn.model_selection import StratifiedKFold 24 | from tqdm import tqdm, tqdm_notebook 25 | from sklearn.preprocessing import MultiLabelBinarizer 26 | from sklearn.metrics import accuracy_score 27 | import copy 28 | 29 | from m import f1_for_car, BOW, BasicModule 30 | from capsule_layer import * 31 | 32 | # 以训练数据为例 33 | data = pd.read_csv('data/train.csv') 34 | data['content'] = data.content.map(lambda x: ''.join(x.strip().split())) 35 | 36 | # 把主题和情感拼接起来,一共10*3类 37 | data['label'] = data['subject'] + data['sentiment_value'].astype(str) 38 | subj_lst = list(filter(lambda x : x is not np.nan, list(set(data.label)))) 39 | subj_lst_dic = {value:key for key, value in enumerate(subj_lst)} 40 | data['label'] = data['label'].apply(lambda x : subj_lst_dic.get(x)) 41 | 42 | # 处理同一个句子对应对标签的情况,然后进行MLB处理 43 | data_tmp = data.groupby('content').agg({'label':lambda x : set(x)}).reset_index() 44 | # [[1,0,0],[0,1,0],[0,0,1]] 45 | # 可能有多标签则[[1,1,0],[0,1,0],[0,0,1]] 46 | mlb = MultiLabelBinarizer() 47 | data_tmp['hh'] = mlb.fit_transform(data_tmp.label).tolist() 48 | y_train = np.array(data_tmp.hh.tolist()) 49 | 50 | # 构造embedding字典 51 | 52 | bow = BOW(data_tmp.content.apply(jieba.lcut).tolist(), min_count=1, maxlen=100) # 长度补齐或截断固定长度100 53 | 54 | # word2vec = 
Word2Vec(data_tmp.content.apply(jieba.lcut).tolist(),size=300,min_count=1) 55 | word2vec = gensim.models.KeyedVectors.load_word2vec_format('data/ft_wv.txt') # 读取txt文件的预训练词向量 56 | 57 | vocab_size = len(bow.word2idx) 58 | embedding_matrix = np.zeros((vocab_size+1, 300)) 59 | for key, value in bow.word2idx.items(): 60 | if key in word2vec.vocab: # Word2Vec训练得到的的实例需要word2vec.wv.vocab 61 | embedding_matrix[value] = word2vec.get_vector(key) 62 | else: 63 | embedding_matrix[value] = [0] * embedding_dim 64 | 65 | X_train = copy.deepcopy(bow.doc2num) 66 | y_train = copy.deepcopy(y_train) 67 | #-------------------------------------------- 68 | 69 | # 数据处理成tensor 70 | BATCH_SIZE = 64 71 | label_tensor = t.from_numpy(np.array(y_train)).float() 72 | content_tensor = t.from_numpy(np.array(X_train)).long() 73 | 74 | torch_dataset = Data.TensorDataset(content_tensor, label_tensor) 75 | train_loader = Data.DataLoader( 76 | dataset=torch_dataset, # torch TensorDataset format 77 | batch_size=BATCH_SIZE, # mini batch size 78 | shuffle=True, # random shuffle for training 79 | num_workers=8, # subprocesses for loading data 80 | ) 81 | 82 | # 网络结构、损失函数、优化器初始化 83 | capnet = Capsule_Main(embedding_matrix,vocab_size) # 加载预训练embedding matrix 84 | loss_func = nn.BCELoss() # 用二分类方法预测是否属于该类,而非多分类 85 | if USE_CUDA: 86 | capnet = capnet.cuda() # 把搭建的网络载入GPU 87 | loss_func.cuda() # 把损失函数载入GPU 88 | optimizer = Adam(capnet.parameters(),lr=LR) # 默认lr 89 | 90 | # 开始跑模型 91 | it = 1 92 | EPOCH = 30 93 | for epoch in tqdm_notebook(range(EPOCH)): 94 | for batch_id, (data, target) in enumerate(train_loader): 95 | if USE_CUDA: 96 | data, target = data.cuda(), target.cuda() # 数据载入GPU 97 | output = capnet(data) 98 | loss = loss_func(output, target) 99 | if it % 50 == 0: 100 | print('training loss: ', loss.cpu().data.numpy().tolist()) 101 | print('training acc: ', accuracy_score(np.argmax(target.cpu().data.numpy(),axis=1), np.argmax(output.cpu().data.numpy(),axis=1))) 102 | optimizer.zero_grad() # clear gradients for this training step 103 | loss.backward() # backpropagation, compute gradients 104 | optimizer.step() # apply gradients 105 | it += 1 106 | 107 | -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /rnn_revised.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | """ Implement a pyTorch LSTM or GRU with hard sigmoid reccurent activation functions. 4 | Adapted from the non-cuda variant of pyTorch LSTM or GRU at 5 | https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py 6 | """ 7 | 8 | author = 'BinZhou' 9 | nick_name = '发送小信号' 10 | mtime = '2018/10/26' 11 | 12 | from __future__ import print_function, division 13 | import math 14 | import warnings 15 | import itertools 16 | import numbers 17 | import torch 18 | 19 | from torch.nn import Module 20 | from torch.nn.parameter import Parameter 21 | from torch.nn.utils.rnn import PackedSequence 22 | import torch.nn.functional as F 23 | import torch.nn._functions.thnn.rnnFusedPointwise as fusedBackend 24 | import torch.nn as nn 25 | 26 | class RNNHardSigmoid(Module): 27 | 28 | def __init__(self, mode, input_size, hidden_size, 29 | num_layers=1, bias=True, batch_first=False, 30 | dropout=0, bidirectional=False): 31 | super(RNNHardSigmoid, self).__init__() 32 | self.mode = mode 33 | self.input_size = input_size 34 | self.hidden_size = hidden_size 35 | self.num_layers = num_layers 36 | self.bias = bias 37 | self.batch_first = batch_first 38 | self.dropout = dropout 39 | self.dropout_state = {} 40 | self.bidirectional = bidirectional 41 | num_directions = 2 if bidirectional else 1 42 | 43 | if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ 44 | isinstance(dropout, bool): 45 | raise ValueError("dropout should be a number in range [0, 1] " 46 | "representing the probability of an element being " 47 | "zeroed") 48 | if dropout > 0 and num_layers == 1: 49 | warnings.warn("dropout option adds dropout after all but last " 50 | "recurrent layer, so non-zero dropout expects " 51 | "num_layers greater than 1, but got dropout={} and " 52 | "num_layers={}".format(dropout, num_layers)) 53 | 54 | if mode == 'LSTM': 55 | gate_size = 4 * hidden_size 56 | elif mode == 'GRU': 57 | gate_size = 3 * hidden_size 58 | else: 59 | gate_size = hidden_size 60 | 61 | self._all_weights = [] 62 | for layer in range(num_layers): 63 | for direction in range(num_directions): 64 | layer_input_size = input_size if layer == 0 else hidden_size * num_directions 
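# The four tensors created below exist once per layer and per direction: w_ih maps the
# layer input (the embedding for layer 0, otherwise the concatenated outputs of both
# directions, hence hidden_size * num_directions above) onto the stacked gates
# (gate_size = 3 * hidden_size for GRU, 4 * hidden_size for LSTM), and w_hh maps the
# previous hidden state. They are registered under the same names torch.nn.GRU/LSTM use
# (weight_ih_l0, weight_ih_l0_reverse, ...), so GRU_Layer.init_weights() in
# capsule_layer.py can initialise an RNNHardSigmoid instance unchanged.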
65 | 66 | w_ih = Parameter(torch.Tensor(gate_size, layer_input_size)) 67 | w_hh = Parameter(torch.Tensor(gate_size, hidden_size)) 68 | b_ih = Parameter(torch.Tensor(gate_size)) 69 | b_hh = Parameter(torch.Tensor(gate_size)) 70 | layer_params = (w_ih, w_hh, b_ih, b_hh) 71 | 72 | suffix = '_reverse' if direction == 1 else '' 73 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] 74 | if bias: 75 | param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] 76 | param_names = [x.format(layer, suffix) for x in param_names] 77 | 78 | for name, param in zip(param_names, layer_params): 79 | setattr(self, name, param) 80 | self._all_weights.append(param_names) 81 | 82 | self.flatten_parameters() 83 | self.reset_parameters() 84 | 85 | def flatten_parameters(self): 86 | """Resets parameter data pointer so that they can use faster code paths. 87 | 88 | Right now, this works only if the module is on the GPU and cuDNN is enabled. 89 | Otherwise, it's a no-op. 90 | """ 91 | any_param = next(self.parameters()).data 92 | if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param): 93 | self._data_ptrs = [] 94 | return 95 | 96 | # If any parameters alias, we fall back to the slower, copying code path. This is 97 | # a sufficient check, because overlapping parameter buffers that don't completely 98 | # alias would break the assumptions of the uniqueness check in 99 | # Module.named_parameters(). 100 | unique_data_ptrs = set(p.data_ptr() for l in self.all_weights for p in l) 101 | if len(unique_data_ptrs) != sum(len(l) for l in self.all_weights): 102 | self._data_ptrs = [] 103 | return 104 | 105 | with torch.cuda.device_of(any_param): 106 | import torch.backends.cudnn.rnn as rnn 107 | 108 | weight_arr = list(itertools.chain.from_iterable(self.all_weights)) 109 | weight_stride0 = len(self.all_weights[0]) 110 | 111 | # NB: This is a temporary hack while we still don't have Tensor 112 | # bindings for ATen functions 113 | with torch.no_grad(): 114 | # NB: this is an INPLACE function on weight_arr, that's why the 115 | # no_grad() is necessary. 116 | weight_buf = torch._cudnn_rnn_flatten_weight( 117 | weight_arr, weight_stride0, 118 | self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers, 119 | self.batch_first, bool(self.bidirectional)) 120 | 121 | self._param_buf_size = weight_buf.size(0) 122 | self._data_ptrs = list(p.data.data_ptr() for p in self.parameters()) 123 | 124 | def _apply(self, fn): 125 | ret = super(RNNHardSigmoid, self)._apply(fn) 126 | self.flatten_parameters() 127 | return ret 128 | 129 | def reset_parameters(self): 130 | stdv = 1.0 / math.sqrt(self.hidden_size) 131 | for weight in self.parameters(): 132 | weight.data.uniform_(-stdv, stdv) 133 | 134 | def check_forward_args(self, input, hidden, batch_sizes): 135 | is_input_packed = batch_sizes is not None 136 | expected_input_dim = 2 if is_input_packed else 3 137 | if input.dim() != expected_input_dim: 138 | raise RuntimeError( 139 | 'input must have {} dimensions, got {}'.format( 140 | expected_input_dim, input.dim())) 141 | if self.input_size != input.size(-1): 142 | raise RuntimeError( 143 | 'input.size(-1) must be equal to input_size. 
Expected {}, got {}'.format( 144 | self.input_size, input.size(-1))) 145 | 146 | if is_input_packed: 147 | mini_batch = int(batch_sizes[0]) 148 | else: 149 | mini_batch = input.size(0) if self.batch_first else input.size(1) 150 | 151 | num_directions = 2 if self.bidirectional else 1 152 | expected_hidden_size = (self.num_layers * num_directions, 153 | mini_batch, self.hidden_size) 154 | 155 | def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): 156 | if tuple(hx.size()) != expected_hidden_size: 157 | raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) 158 | 159 | if self.mode == 'LSTM': 160 | check_hidden_size(hidden[0], expected_hidden_size, 161 | 'Expected hidden[0] size {}, got {}') 162 | check_hidden_size(hidden[1], expected_hidden_size, 163 | 'Expected hidden[1] size {}, got {}') 164 | else: 165 | check_hidden_size(hidden, expected_hidden_size) 166 | 167 | def forward(self, input, hx=None): 168 | is_packed = isinstance(input, PackedSequence) 169 | if is_packed: 170 | input, batch_sizes = input 171 | max_batch_size = int(batch_sizes[0]) 172 | else: 173 | batch_sizes = None 174 | max_batch_size = input.size(0) if self.batch_first else input.size(1) 175 | 176 | if hx is None: 177 | num_directions = 2 if self.bidirectional else 1 178 | hx = input.new_zeros(self.num_layers * num_directions, 179 | max_batch_size, self.hidden_size, 180 | requires_grad=False) 181 | if self.mode == 'LSTM': 182 | hx = (hx, hx) 183 | 184 | has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs 185 | if has_flat_weights: 186 | first_data = next(self.parameters()).data 187 | assert first_data.storage().size() == self._param_buf_size 188 | flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size])) 189 | else: 190 | flat_weight = None 191 | 192 | self.check_forward_args(input, hx, batch_sizes) 193 | func = AutogradRNN( 194 | self.mode, 195 | self.input_size, 196 | self.hidden_size, 197 | num_layers=self.num_layers, 198 | batch_first=self.batch_first, 199 | dropout=self.dropout, 200 | train=self.training, 201 | bidirectional=self.bidirectional, 202 | dropout_state=self.dropout_state, 203 | variable_length=is_packed, 204 | flat_weight=flat_weight 205 | ) 206 | output, hidden = func(input, self.all_weights, hx, batch_sizes) 207 | if is_packed: 208 | output = PackedSequence(output, batch_sizes) 209 | return output, hidden 210 | 211 | def extra_repr(self): 212 | s = '{input_size}, {hidden_size}' 213 | if self.num_layers != 1: 214 | s += ', num_layers={num_layers}' 215 | if self.bias is not True: 216 | s += ', bias={bias}' 217 | if self.batch_first is not False: 218 | s += ', batch_first={batch_first}' 219 | if self.dropout != 0: 220 | s += ', dropout={dropout}' 221 | if self.bidirectional is not False: 222 | s += ', bidirectional={bidirectional}' 223 | return s.format(**self.__dict__) 224 | 225 | def __setstate__(self, d): 226 | super(RNNHardSigmoid, self).__setstate__(d) 227 | self.__dict__.setdefault('_data_ptrs', []) 228 | if 'all_weights' in d: 229 | self._all_weights = d['all_weights'] 230 | if isinstance(self._all_weights[0][0], str): 231 | return 232 | num_layers = self.num_layers 233 | num_directions = 2 if self.bidirectional else 1 234 | self._all_weights = [] 235 | for layer in range(num_layers): 236 | for direction in range(num_directions): 237 | suffix = '_reverse' if direction == 1 else '' 238 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 239 | 
weights = [x.format(layer, suffix) for x in weights] 240 | if self.bias: 241 | self._all_weights += [weights] 242 | else: 243 | self._all_weights += [weights[:2]] 244 | 245 | @property 246 | def all_weights(self): 247 | return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] 248 | 249 | # if cudnn.is_acceptable(input.data)为True的时候用CudnnRNN跑的,而不是AutogradRNN 250 | # 但是CudnnRNN里面修改不了GRU或LSTMcell,不是用python写的,好像是动态链接.so文件 251 | 252 | def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False, 253 | dropout=0, train=True, bidirectional=False, variable_length=False, 254 | dropout_state=None, flat_weight=None): 255 | 256 | if mode == 'LSTM': 257 | cell = LSTMCell 258 | elif mode == 'GRU': 259 | cell = GRUCell 260 | else: 261 | raise Exception('Unknown mode: {}'.format(mode)) 262 | 263 | rec_factory = variable_recurrent_factory if variable_length else Recurrent 264 | 265 | if bidirectional: 266 | layer = (rec_factory(cell), rec_factory(cell, reverse=True)) 267 | else: 268 | layer = (rec_factory(cell),) 269 | 270 | func = StackedRNN(layer, 271 | num_layers, 272 | (mode == 'LSTM'), 273 | dropout=dropout, 274 | train=train) 275 | 276 | def forward(input, weight, hidden, batch_sizes): 277 | if batch_first and not variable_length: 278 | input = input.transpose(0, 1) 279 | 280 | nexth, output = func(input, hidden, weight, batch_sizes) 281 | 282 | if batch_first and not variable_length: 283 | output = output.transpose(0, 1) 284 | 285 | return output, nexth 286 | 287 | return forward 288 | 289 | ###func-------------------------------------------------------------------------------------### 290 | def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): 291 | 292 | num_directions = len(inners) 293 | total_layers = num_layers * num_directions 294 | 295 | def forward(input, hidden, weight, batch_sizes): 296 | #input = nn.Dropout(p=0.25)(input) 297 | #hidden = nn.Dropout(p=0.28)(hidden) 298 | assert(len(weight) == total_layers) 299 | next_hidden = [] 300 | 301 | if lstm: 302 | hidden = list(zip(*hidden)) 303 | 304 | for i in range(num_layers): 305 | all_output = [] 306 | for j, inner in enumerate(inners): 307 | l = i * num_directions + j 308 | 309 | hy, output = inner(input, hidden[l], weight[l], batch_sizes) 310 | next_hidden.append(hy) 311 | all_output.append(output) 312 | 313 | input = torch.cat(all_output, input.dim() - 1) 314 | 315 | if dropout != 0 and i < num_layers - 1: 316 | input = F.dropout(input, p=dropout, training=train, inplace=False) 317 | 318 | if lstm: 319 | next_h, next_c = zip(*next_hidden) 320 | next_hidden = ( 321 | torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), 322 | torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) 323 | ) 324 | else: 325 | next_hidden = torch.cat(next_hidden, 0).view( 326 | total_layers, *next_hidden[0].size()) 327 | 328 | return next_hidden, input 329 | 330 | return forward 331 | 332 | 333 | def Recurrent(inner, reverse=False): 334 | def forward(input, hidden, weight, batch_sizes): 335 | output = [] 336 | steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) 337 | for i in steps: 338 | hidden = inner(input[i], hidden, *weight) 339 | # hack to handle LSTM 340 | output.append(hidden[0] if isinstance(hidden, tuple) else hidden) 341 | 342 | if reverse: 343 | output.reverse() 344 | output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 345 | 346 | return hidden, output 347 | 348 | return forward 349 | 350 | 351 | def variable_recurrent_factory(inner, 
reverse=False): 352 | if reverse: 353 | return VariableRecurrentReverse(inner) 354 | else: 355 | return VariableRecurrent(inner) 356 | 357 | 358 | def VariableRecurrent(inner): 359 | def forward(input, hidden, weight, batch_sizes): 360 | 361 | output = [] 362 | input_offset = 0 363 | last_batch_size = batch_sizes[0] 364 | hiddens = [] 365 | flat_hidden = not isinstance(hidden, tuple) 366 | if flat_hidden: 367 | hidden = (hidden,) 368 | for batch_size in batch_sizes: 369 | step_input = input[input_offset:input_offset + batch_size] 370 | input_offset += batch_size 371 | 372 | dec = last_batch_size - batch_size 373 | if dec > 0: 374 | hiddens.append(tuple(h[-dec:] for h in hidden)) 375 | hidden = tuple(h[:-dec] for h in hidden) 376 | last_batch_size = batch_size 377 | 378 | if flat_hidden: 379 | hidden = (inner(step_input, hidden[0], *weight),) 380 | else: 381 | hidden = inner(step_input, hidden, *weight) 382 | 383 | output.append(hidden[0]) 384 | hiddens.append(hidden) 385 | hiddens.reverse() 386 | 387 | hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) 388 | assert hidden[0].size(0) == batch_sizes[0] 389 | if flat_hidden: 390 | hidden = hidden[0] 391 | output = torch.cat(output, 0) 392 | 393 | return hidden, output 394 | 395 | return forward 396 | 397 | 398 | def VariableRecurrentReverse(inner): 399 | def forward(input, hidden, weight, batch_sizes): 400 | output = [] 401 | input_offset = input.size(0) 402 | last_batch_size = batch_sizes[-1] 403 | initial_hidden = hidden 404 | flat_hidden = not isinstance(hidden, tuple) 405 | if flat_hidden: 406 | hidden = (hidden,) 407 | initial_hidden = (initial_hidden,) 408 | hidden = tuple(h[:batch_sizes[-1]] for h in hidden) 409 | for i in reversed(range(len(batch_sizes))): 410 | batch_size = batch_sizes[i] 411 | inc = batch_size - last_batch_size 412 | if inc > 0: 413 | hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) 414 | for h, ih in zip(hidden, initial_hidden)) 415 | last_batch_size = batch_size 416 | step_input = input[input_offset - batch_size:input_offset] 417 | input_offset -= batch_size 418 | 419 | if flat_hidden: 420 | hidden = (inner(step_input, hidden[0], *weight),) 421 | else: 422 | hidden = inner(step_input, hidden, *weight) 423 | output.append(hidden[0]) 424 | 425 | output.reverse() 426 | output = torch.cat(output, 0) 427 | if flat_hidden: 428 | hidden = hidden[0] 429 | return hidden, output 430 | 431 | return forward 432 | ###func end---------------------------------------------------------------------------------### 433 | 434 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 435 | if input.is_cuda: 436 | igates = F.linear(input, w_ih) 437 | hgates = F.linear(hidden[0], w_hh) 438 | state = fusedBackend.LSTMFused.apply 439 | return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) 440 | 441 | hx, cx = hidden 442 | gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) 443 | 444 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 445 | 446 | ingate = hard_sigmoid(ingate) 447 | forgetgate = hard_sigmoid(forgetgate) 448 | cellgate = torch.tanh(cellgate) 449 | outgate = hard_sigmoid(outgate) 450 | 451 | cy = (forgetgate * cx) + (ingate * cellgate) 452 | hy = outgate * torch.tanh(cy) 453 | 454 | return hy, cy 455 | 456 | def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 457 | 458 | if input.is_cuda: 459 | gi = F.linear(input, w_ih) 460 | gh = F.linear(hidden, w_hh) 461 | state = fusedBackend.GRUFused.apply 462 | return state(gi, 
gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh) 463 | 464 | gi = F.linear(input, w_ih, b_ih) 465 | gh = F.linear(hidden, w_hh, b_hh) 466 | i_r, i_i, i_n = gi.chunk(3, 1) 467 | h_r, h_i, h_n = gh.chunk(3, 1) 468 | 469 | resetgate = hard_sigmoid(i_r + h_r) 470 | inputgate = hard_sigmoid(i_i + h_i) 471 | # 可以用relu或其他激活函数 instead of tanh 472 | newgate = torch.tanh(i_n + resetgate * h_n) 473 | hy = newgate + inputgate * (hidden - newgate) 474 | 475 | return hy 476 | 477 | # 这里用了hard_sigmoid激活函数 478 | def hard_sigmoid(x): 479 | """ 480 | Computes element-wise hard sigmoid of x. 481 | See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279 482 | """ 483 | x = (0.2 * x) + 0.5 484 | x = F.threshold(-x, -1, -1) 485 | x = F.threshold(-x, 0, 0) 486 | return x -------------------------------------------------------------------------------- /save/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | --------------------------------------------------------------------------------
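A closing illustration for README note 1 and Caps_Layer.squash in capsule_layer.py: a minimal, self-contained sketch (assuming only torch is installed; the helper names squash_paper and squash_text and the example shapes are illustrative, not part of the repository) contrasting the squash non-linearity from the paper with the plain length normalisation used in this repo.

```python
import torch

T_epsilon = 1e-7  # same constant as in capsule_layer.py

def squash_paper(x, axis=-1):
    # ||s||^2 / (1 + ||s||^2) * s / ||s||, the non-linearity from "Dynamic Routing Between Capsules"
    s_squared_norm = (x ** 2).sum(axis, keepdim=True)
    scale = s_squared_norm / (1.0 + s_squared_norm) / torch.sqrt(s_squared_norm + T_epsilon)
    return x * scale

def squash_text(x, axis=-1):
    # the variant implemented in Caps_Layer.squash: plain L2 normalisation along the capsule dim
    s_squared_norm = (x ** 2).sum(axis, keepdim=True)
    return x / torch.sqrt(s_squared_norm + T_epsilon)

u = torch.randn(2, 1152, 8)                # [batch_size, num_capsules, capsule_dim] as in README note 1
print(squash_paper(u).norm(dim=-1).max())  # strictly below 1
print(squash_text(u).norm(dim=-1).max())   # approximately 1 for every capsule
```

The paper's factor keeps every output capsule's length strictly below 1 so it can be read as a probability; the variant in this repo only normalises the length, which is consistent with the separate Sigmoid + BCELoss head used in Dense_Layer and main.py.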