├── .gitignore
├── requirements.txt
├── vocab
│   └── chinese_vocab.model
├── README.md
├── database.py
├── api.py
└── asoulgenerate.py

/.gitignore:
--------------------------------------------------------------------------------
asoul_cpm/*
asoul_cpm*
*.db
__pycache__
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
transformers
numpy
tqdm
scikit-learn
flask
sentencepiece
jieba
--------------------------------------------------------------------------------
/vocab/chinese_vocab.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/infinityedge01/ASOUL-Generator-Backend/HEAD/vocab/chinese_vocab.model
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ASOUL-Generator-Backend

This project is the backend of https://asoul.infedg.xyz/.
The model was trained from [CPM-Generate-distill](https://huggingface.co/mymusise/CPM-Generate-distill/tree/main), the [transformers](https://github.com/huggingface/transformers) conversion of [CPM-Distill](https://github.com/TsinghuaAI/CPM-1-Distill).
Training datasets:

- [asoul.icu](http://asoul.icu/)
- [枝江作文展](https://asoulcnki.asia/rank)

## How to Run

#### Download the Model

[Download link](https://disk.pku.edu.cn/#/link/88F0D3C9839329210503C7E50634AAFE)

Put the two files in that folder (`pytorch_model.bin` and `config.json`) into the `asoul_cpm` directory.

The model is updated from time to time.

#### Install Dependencies
```bash
pip install -r requirements.txt
```

#### Run the Backend

```bash
python3 api.py
```

The backend then listens on port `5089`.

--------------------------------------------------------------------------------
/database.py:
--------------------------------------------------------------------------------
import sqlite3
import os
import time


class Database:
    def __init__(self):
        # Create the table on first run, when the database file does not exist yet.
        db_exists = os.path.exists("generated.db")
        db_conn = sqlite3.connect("generated.db")
        db = db_conn.cursor()
        if not db_exists:
            db.execute(
                '''CREATE TABLE Data(
                    qqid INTEGER PRIMARY KEY AUTOINCREMENT,
                    time INTEGER,
                    prefix BLOB,
                    generated BLOB)''')
            db_conn.commit()
        db_conn.close()

    def insert_data(self, prefix, generated):
        # Store one generation record together with a UNIX timestamp.
        db_conn = sqlite3.connect("generated.db")
        db = db_conn.cursor()
        current_time = int(time.time())
        db.execute("INSERT INTO Data (time, prefix, generated) VALUES(?,?,?)",
                   (current_time, prefix.encode('utf-8'), generated.encode('utf-8')))
        db_conn.commit()
        db_conn.close()

    def query_data(self, count):
        # Return the `count` most recent records plus the total row count.
        db_conn = sqlite3.connect("generated.db")
        db = db_conn.cursor()
        sql_info = list(db.execute(
            "SELECT time, prefix, generated FROM Data ORDER BY time DESC LIMIT 0,?", (count,)))
        total = list(db.execute("SELECT COUNT(time) FROM Data"))
        db_conn.close()
        return sql_info, total[0][0]
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
from flask import Flask, request, jsonify
import json
import datetime
import asoulgenerate
import database

app = Flask(__name__)
app.debug = False
servetime = {}
db = database.Database()

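# API overview (both endpoints are defined below):
#   POST /generate  with JSON body {"prefix": "..."}
#     -> {"code": 0, "state": "success", "reply": {"prefix": ..., "generated": ...}}
#     served at most once per client IP every 30 seconds
#   GET /query?count=5
#     -> {"code": 0, "state": "success", "reply": [...], "count": N}
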
@app.after_request
def cors(response):
    # Allow the frontend to call this API from any origin.
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = '*'
    response.headers['Access-Control-Allow-Headers'] = '*'
    return response


@app.route('/generate', methods=['POST'])
def generate():
    if not request.data:
        ret_dict = {"code": -1, "state": "No Request Data"}
        return jsonify(ret_dict)
    global servetime
    ip = request.headers.get("X-Real-Ip")
    print(ip)
    currtime = datetime.datetime.now()
    # Enforce the 30-second per-IP rate limit.
    if ip in servetime:
        lasttime = servetime[ip]
        if (currtime - lasttime) < datetime.timedelta(seconds=30):
            ret_dict = {"code": -4, "state": "Too Many Requests, Please Wait For {} Seconds.".format((lasttime - currtime + datetime.timedelta(seconds=30)).seconds)}
            return jsonify(ret_dict)
    data = request.data.decode('utf-8')
    data = json.loads(data)
    print(data)
    prefix = data['prefix']
    if len(prefix) > 1000:
        ret_dict = {"code": -2, "state": "Prefix is Too Long"}
        return jsonify(ret_dict)
    if asoulgenerate.is_processing:
        ret_dict = {"code": -3, "state": "Server is Busy, Please Try Again Later."}
        return jsonify(ret_dict)
    prefix, generated = asoulgenerate.process(prefix)
    db.insert_data(prefix, generated)
    ret_dict = {"code": 0, "state": "success", "reply": {"prefix": prefix, "generated": generated}}
    servetime[ip] = datetime.datetime.now()
    print(ret_dict)
    return jsonify(ret_dict)


@app.route('/query', methods=['GET'])
def query():
    # The frontend only ever requests the five most recent records.
    wd = request.args.get("count")
    if wd != '5':  # query-string values arrive as strings
        print('error')
        ret_dict = {"code": 0, "state": "success", "reply": [], "count": 0}
        return jsonify(ret_dict)
    ret, count = db.query_data(int(wd))
    lst = []
    for x in ret:
        lst.append({'time': x[0], 'prefix': x[1].decode('utf-8'), 'generated': x[2].decode('utf-8')})
    ret_dict = {"code": 0, "state": "success", "reply": lst, "count": count}
    # print(ret_dict)
    return jsonify(ret_dict)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5089)  # The listen address and port are set here.
--------------------------------------------------------------------------------
/asoulgenerate.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from tqdm import trange
from transformers import GPT2LMHeadModel, CpmTokenizer

eod_id = 7  # placeholder; reassigned from the tokenizer once it is loaded below


def is_word(word):
    # True if `word` consists only of lowercase English letters.
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether `char` is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    cp = ord(char)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or        # CJK Unified Ideographs
            (cp >= 0x3400 and cp <= 0x4DBF) or    # CJK Extension A
            (cp >= 0x20000 and cp <= 0x2A6DF) or  # CJK Extension B
            (cp >= 0x2A700 and cp <= 0x2B73F) or  # CJK Extension C
            (cp >= 0x2B740 and cp <= 0x2B81F) or  # CJK Extension D
            (cp >= 0x2B820 and cp <= 0x2CEAF) or  # CJK Extension E
            (cp >= 0xF900 and cp <= 0xFAFF) or    # CJK Compatibility Ideographs
            (cp >= 0x2F800 and cp <= 0x2FA1F)):   # CJK Compatibility Ideographs Supplement
        return True

    return False


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repetition_penalty=1.0, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :]
            # Penalize every token that has already been generated.
            for token_id in set(generated[0].tolist()):
                next_token_logits[token_id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[tokenizer.unk_token_id] = -float('Inf')  # never sample the unknown token
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]


def fast_sample_sequence(model, context, length, repetition_penalty=1.0, temperature=1.0, top_k=30, top_p=0.0,
                         device='cpu'):
    inputs = torch.LongTensor(context).view(1, -1).to(device)
    if len(context) > 1:
        # Prime the key/value cache with all but the last prefix token.
        _, past = model(inputs[:, :-1], None)[:2]
        prev = inputs[:, -1].view(1, -1)
    else:
        past = None
        prev = inputs
    generate = [] + context
    with torch.no_grad():
        for _ in trange(length):
            output = model(prev, past_key_values=past)
            output, past = output[:2]
            output = output[-1].squeeze(0) / temperature
            # Penalize every token that has already been generated.
            for token_id in set(generate):
                output[token_id] /= repetition_penalty
            filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            generate.append(next_token.item())
            prev = next_token.view(1, 1)
            if next_token.item() == eod_id:
                break
    return generate


# Chooses the sampling implementation (originally selected via the --fast_pattern command-line flag).
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
             device='cpu', is_fast_pattern=False):
    if is_fast_pattern:
        return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
                                    repetition_penalty=repetition_penalty, device=device)
    else:
        return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature,
                               top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device)


length = 100
batch_size = 1
nsamples = 1
temperature = 1
topk = 0
topp = 0.85
repetition_penalty = 1.1
device = "cpu"

tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")
eod_id = tokenizer.convert_tokens_to_ids('<eod>')  # CPM's end-of-document token ('<eod>' is the assumed literal)
mask_id = tokenizer.mask_token_id
model = GPT2LMHeadModel.from_pretrained("asoul_cpm")
model.to(device)
model.eval()
n_ctx = model.config.n_ctx
if length == -1:
    length = model.config.n_ctx

is_processing = False
print('model successfully loaded.')


def process(prefix):
    global is_processing
    if is_processing:
        return "", ""
    is_processing = True
    raw_text = prefix
    context_tokens = [mask_id] + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    generated = 0
    for _ in range(nsamples // batch_size):
        out = generate(
            n_ctx=n_ctx,
            model=model,
            context=context_tokens,
            length=length,
            is_fast_pattern=True, tokenizer=tokenizer,
            temperature=temperature, top_k=topk, top_p=topp, repetition_penalty=repetition_penalty, device=device
        )
        cnt = 1
        # Resample (at most four more times) if fewer than 15 new tokens came back.
        while cnt < 5 and len(out) - len(context_tokens) < 15:
            out = generate(
                n_ctx=n_ctx,
                model=model,
                context=context_tokens,
                length=length,
                is_fast_pattern=True, tokenizer=tokenizer,
                temperature=temperature, top_k=topk, top_p=topp, repetition_penalty=repetition_penalty, device=device
            )
            cnt += 1
        text = tokenizer.convert_ids_to_tokens(out)
        # for i, item in enumerate(text[:-1]):  # make sure English words stay space-separated
        #     if is_word(item) and is_word(text[i + 1]):
        #         text[i] = item + ' '
        # Map special tokens back to plain text ('<eod>', '<cls>' and '<sep>'
        # are the assumed special-token literals).
        for i, item in enumerate(text):
            if item == '<eod>':
                text[i] = ''
            elif item == '<cls>':
                text[i] = '\n\n'
            elif item == '<sep>':
                text[i] = '\n'
        info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
        print(info)
        l = len(context_tokens)
        text1 = ''.join(text[:l]).replace('▁', '').replace('▂', '').replace('▃', '').replace("“", "「").replace("”", "」").strip()
        text2 = ''.join(text[l:]).replace('▁', '').replace('▂', '').replace('▃', '').replace("“", "「").replace("”", "」").strip()
    is_processing = False
    return prefix, text2
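
# --- Example usage (a minimal sketch; not part of the public API) ---
# Running the module directly exercises process() once, assuming the model
# files are already in place under asoul_cpm/. The demo prefix is arbitrary.
if __name__ == '__main__':
    demo_prefix = "大家好,我是"  # hypothetical test prefix
    demo_prefix, demo_generated = process(demo_prefix)
    print("prefix:", demo_prefix)
    print("generated:", demo_generated)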
--------------------------------------------------------------------------------