├── model
│   └── sen_model.pkl
├── static_files
│   ├── images
│   │   ├── IMG0
│   │   ├── logo.png
│   │   ├── logo1.png
│   │   ├── logo2.png
│   │   ├── logo3.png
│   │   ├── bot-avatar.png
│   │   ├── user-avatar.png
│   │   ├── send.svg
│   │   ├── clear.svg
│   │   ├── export.svg
│   │   ├── fold.svg
│   │   ├── user-avatar.svg
│   │   └── bot-avatar.svg
│   ├── index.css
│   └── index.js
├── Judgment_sensitivity
│   ├── data
│   │   └── info
│   │       ├── test.txt
│   │       └── train.txt
│   ├── predict.py
│   ├── dataset.py
│   ├── model.py
│   └── train.py
├── FirewaLLM_server.png
├── FirewaLLM_fronted.png
├── cloud3.py
├── cloud1.py
├── app.py
├── model.py
├── cosSIM.py
├── README.md
├── cloud2.py
├── feature.py
├── templates_files
│   └── index.html
└── mark.py

/model/sen_model.pkl:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/static_files/images/IMG0:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/Judgment_sensitivity/data/info/test.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Judgment_sensitivity/data/info/train.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/FirewaLLM_server.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/FirewaLLM_server.png
--------------------------------------------------------------------------------
/FirewaLLM_fronted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/FirewaLLM_fronted.png
--------------------------------------------------------------------------------
/static_files/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo.png
--------------------------------------------------------------------------------
/static_files/images/logo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo1.png
--------------------------------------------------------------------------------
/static_files/images/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo2.png
--------------------------------------------------------------------------------
/static_files/images/logo3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo3.png
--------------------------------------------------------------------------------
/static_files/images/bot-avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/bot-avatar.png
--------------------------------------------------------------------------------
/static_files/images/user-avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/user-avatar.png
--------------------------------------------------------------------------------
/cloud3.py:
--------------------------------------------------------------------------------
1 | # cloud3: OpenAI gpt-3.5 (uses the legacy openai<1.0 ChatCompletion interface)
2 | import openai
3 | 
4 | def cloud_model3(content):
5 |     openai.api_key = "EnterYourKey"
6 |     response = openai.ChatCompletion.create(
7 |         model="gpt-3.5-turbo-0301",  # gpt-3.5-turbo-0301 or text-davinci-003
8 |         messages=[
9 |             {"role": "user", "content": content}
10 |         ],
11 |         temperature=0.7,
12 |         max_tokens=1000,
13 |         top_p=1,
14 |         frequency_penalty=0,
15 |         presence_penalty=0.5,
16 |     )
17 |     return response.choices[0].message.content
18 | 
--------------------------------------------------------------------------------
/static_files/images/send.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/static_files/images/clear.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/static_files/images/export.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/static_files/images/fold.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Judgment_sensitivity/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model import BertTextClassifier
3 | from transformers import BertTokenizer
4 | # "Y" means sensitive, "N" means insensitive
5 | labels = ["Y", "N"]
6 | # Define the model (the constructor loads the matching BertConfig by name)
7 | model = BertTextClassifier('bert-base-chinese', len(labels))
8 | # Load the trained weights (train.py saves the best checkpoint as models/sen_model.pkl)
9 | model.load_state_dict(torch.load('models/sen_model.pkl', map_location=torch.device('cpu')))
10 | model.eval()
11 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
12 | 
13 | print('classification: ')
14 | while True:
15 |     text = input('Input: ')
16 |     if not text:
17 |         print('Please enter valid text content!')
18 |         continue
19 | 
20 |     token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
21 |     input_ids = token['input_ids']
22 |     attention_mask = token['attention_mask']
23 |     token_type_ids = token['token_type_ids']
24 | 
25 |     input_ids = torch.tensor([input_ids], dtype=torch.long)
26 |     attention_mask = torch.tensor([attention_mask], dtype=torch.long)
27 |     token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
28 | 
29 |     model.to('cpu')  # Run inference on the CPU
30 |     predicted = model(
31 |         input_ids,
32 |         attention_mask,
33 |         token_type_ids,
34 |     )
35 |     pred_label = torch.argmax(predicted, dim=1)
36 | 
37 |     print('Label:', labels[pred_label])
--------------------------------------------------------------------------------
/Judgment_sensitivity/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from transformers import BertTokenizer
4 | 
5 | 
6 | class InfoDataset(Dataset):
7 |     def __init__(self, filename):
8 |         self.labels = ["Y", "N"]
9 |         self.labels_id = list(range(len(self.labels)))
10 |         self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
11 |         self.data = []
12 |         self.load_data(filename)
13 | 
14 |     def load_data(self, filename):  # each line: "<text>_<label>", label in {"Y", "N"}
15 |         print('Loading data from:', filename)
16 |         with open(filename, 'r', encoding='utf-8') as rf:
17 |             for line in rf:
18 |                 text, label = line.strip().split('_')
19 |                 label_id = self.labels.index(label)
20 |                 token = self.tokenizer(
21 |                     text,
22 |                     add_special_tokens=True,
23 |                     padding='max_length',
24 |                     truncation=True,
25 |                     max_length=512
26 |                 )
27 |                 input_ids = torch.tensor(token['input_ids'])
28 |                 token_type_ids = torch.tensor(token['token_type_ids'])
29 |                 attention_mask = torch.tensor(token['attention_mask'])
30 | 
31 |                 self.data.append((input_ids, token_type_ids, attention_mask, label_id))
32 | 
33 |     def __getitem__(self, index):
34 |         input_ids, token_type_ids, attention_mask, label_id = self.data[index]
35 |         return input_ids, token_type_ids, attention_mask, label_id
36 | 
37 |     def __len__(self):
38 |         return len(self.data)
39 | 
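# Expected format of data/info/train.txt and data/info/test.txt (see load_data
# above and test_1 in train.py): one example per line, text and label joined by
# an underscore, with label "Y" (sensitive) or "N" (insensitive).
# Hypothetical sample lines:
#   张三的手机号是13800000000_Y
#   今天天气真不错_N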
--------------------------------------------------------------------------------
/cloud1.py:
--------------------------------------------------------------------------------
1 | # cloud1: Baidu UNIT model
2 | import requests
3 | import json
4 | from cosSIM import similarity
5 | 
6 | def cloud_model1(question):
7 |     url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=EnterYourToken"
8 |     payload = json.dumps({
9 |         "log_id": "1234567890",
10 |         "version": "3.0",
11 |         "service_id": "S99567",
12 |         "session_id": "",
13 |         "request": {
14 |             "query": question,
15 |             "terminal_id": "1234567890"
16 |         },
17 |         "dialog_state": {
18 |             "contexts": {
19 |                 "SYS_REMEMBERED_SKILLS": [
20 |                     ""
21 |                 ]
22 |             }
23 |         }
24 |     })
25 |     headers = {
26 |         'Content-Type': 'application/json'
27 |     }
28 |     response = requests.request("POST", url, headers=headers, data=payload)
29 |     data = response.json()
30 |     # Candidate answers live under result.responses[0].actions[0].options
31 |     options = data.get("result", {}).get("responses", [{}])[0].get("actions", [{}])[0].get("options", [])
32 |     answer_list = []
33 |     for obj in options[:5]:  # keep at most the first five candidates
34 |         res = obj.get("info", {}).get("full_answer")
35 |         if res:
36 |             answer_list.append(res)
37 |     print("找到的相关信息:", answer_list)  # candidate answers found
38 |     # Return the candidate most similar to the question
39 |     max_score = 0.0
40 |     best_answer = None
41 |     for answer in answer_list:
42 |         cos_sim = similarity(question, answer)
43 |         if cos_sim > max_score:
44 |             max_score = cos_sim
45 |             best_answer = answer
46 |     if best_answer is None and answer_list:
47 |         best_answer = answer_list[0]
48 |     print("最终回复:", best_answer)  # final reply
49 |     return best_answer
50 | 
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, request
2 | from mark import fun_1
3 | from cloud1 import cloud_model1
4 | from cloud2 import cloud_model2
5 | from cloud3 import cloud_model3
6 | import os
7 | # Note: Flask's defaults are templates/ and static/; this repository ships
8 | # templates_files/ and static_files/, so either rename the folders or pass
9 | # template_folder/static_folder explicitly when deploying.
10 | app = Flask(__name__, static_url_path='/static')
11 | # Route traffic through a local proxy (needed for the OpenAI call)
12 | os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
13 | os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
14 | 
15 | @app.route('/')
16 | def index():
17 |     return render_template('index.html')
18 | 
19 | 
20 | # Submit the question and get the masked text back
21 | @app.route("/get_mask", methods=['POST', 'GET'])
22 | def get_mask():
23 |     user_input = request.form['user_input']
24 |     print("脱敏前", user_input)  # before masking
25 |     if user_input == "bye":
26 |         return "Goodbye!"
27 |     selected_sen_level = int(request.form["sen_level"])
28 |     selected_tag = request.form["ask_tag"]
29 |     masked_text = fun_1(user_input, selected_sen_level, selected_tag)
30 |     print("脱敏后", masked_text)  # after masking
31 |     return masked_text
32 | 
33 | 
34 | # Get the cloud model's response
35 | @app.route("/get_response", methods=['POST', 'GET'])
36 | def get_response():
37 |     selected_cloud_model = request.form['selected_model']
38 |     masked_text = request.form['mask_info']
39 |     if selected_cloud_model == 'model1':
40 |         response = cloud_model1(masked_text)
41 |     elif selected_cloud_model == 'model2':
42 |         response = cloud_model2(masked_text)
43 |     elif selected_cloud_model == 'model3':
44 |         response = cloud_model3(masked_text)
45 |     else:
46 |         response = "Unknown model selection."
47 |     return response
48 | 
49 | if __name__ == '__main__':
50 |     app.run(debug=True)
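# A minimal client-side sketch of the two endpoints above (host/port are
# Flask's debug-server defaults; the field values are illustrative):
#
#   import requests
#   base = "http://127.0.0.1:5000"
#   masked = requests.post(base + "/get_mask",
#                          data={"user_input": "...", "sen_level": 1, "ask_tag": "false"}).text
#   answer = requests.post(base + "/get_response",
#                          data={"selected_model": "model3", "mask_info": masked}).text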
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import BertModel
4 | 
5 | # BERT
6 | class BertTextClassifier(nn.Module):
7 |     def __init__(self, bert_config, num_labels):
8 |         super().__init__()
9 |         # Define the BERT model
10 |         self.bert = BertModel(config=bert_config)
11 |         # Define the classifier
12 |         self.classifier = nn.Linear(bert_config.hidden_size, num_labels)
13 | 
14 |     def forward(self, input_ids, attention_mask, token_type_ids):
15 |         # BERT's output
16 |         bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
17 |         # Pooled output at the [CLS] position
18 |         pooled = bert_output[1]
19 |         # Classification
20 |         logits = self.classifier(pooled)
21 |         # Return probabilities (softmax is already applied here)
22 |         return torch.softmax(logits, dim=1)
23 | 
24 | 
25 | # BERT + BiLSTM
26 | class BertLstmClassifier(nn.Module):
27 |     def __init__(self, bert_config, num_labels):
28 |         super().__init__()
29 |         self.bert = BertModel(config=bert_config)
30 |         self.lstm = nn.LSTM(input_size=bert_config.hidden_size, hidden_size=bert_config.hidden_size, num_layers=2,
31 |                             batch_first=True, bidirectional=True)
32 |         self.classifier = nn.Linear(bert_config.hidden_size * 2, num_labels)
33 |         self.softmax = nn.Softmax(dim=1)
34 | 
35 |     def forward(self, input_ids, attention_mask, token_type_ids):
36 |         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
37 |         last_hidden_state = outputs.last_hidden_state
38 |         out, _ = self.lstm(last_hidden_state)
39 |         logits = self.classifier(out[:, -1, :])  # Take the output at the last time step
40 |         return self.softmax(logits)  # Probabilities; callers must not apply softmax again
--------------------------------------------------------------------------------
/cosSIM.py:
--------------------------------------------------------------------------------
1 | import jieba
2 | import math
3 | 
4 | def tokenize(text):
5 |     # Segment the text and drop empty tokens
6 |     return [word for word in jieba.cut(text, cut_all=True) if word != '']
7 | 
8 | def build_word_dict(tokenized_texts):
9 |     # Build a vocabulary that assigns a unique index to each word
10 |     word_set = set()
11 |     for tokens in tokenized_texts:
12 |         word_set.update(tokens)
13 | 
14 |     word_dict = {word: idx for idx, word in enumerate(word_set)}
15 |     return word_dict
16 | 
17 | def text_to_code(text, word_dict):
18 |     # Convert a token list to a bag-of-words count vector
19 |     code = [0] * len(word_dict)
20 |     for word in text:
21 |         code[word_dict[word]] += 1
22 |     return code
23 | 
24 | def cosine_similarity(vec1, vec2):
25 |     # Compute cosine similarity
26 |     dot_product = sum(x * y for x, y in zip(vec1, vec2))
27 |     norm1 = math.sqrt(sum(x * x for x in vec1))
28 |     norm2 = math.sqrt(sum(x * x for x in vec2))
29 |     try:
30 |         result = round(dot_product / (norm1 * norm2), 2)
31 |     except ZeroDivisionError:
32 |         result = 0.0
33 |     return result
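# Cosine similarity: cos(theta) = (A · B) / (||A|| * ||B||).
# Worked example with toy count vectors (illustrative):
#   A = [1, 1, 0], B = [1, 0, 0]
#   dot = 1, ||A|| = sqrt(2) ≈ 1.41, ||B|| = 1  ->  cos ≈ 1 / 1.41 ≈ 0.71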
34 | 
35 | def similarity(sentence1, sentence2):
36 |     s1_tokens = tokenize(sentence1)
37 |     s2_tokens = tokenize(sentence2)
38 |     word_dict = build_word_dict([s1_tokens, s2_tokens])
39 |     # Convert the token lists to count vectors
40 |     s1_code = text_to_code(s1_tokens, word_dict)
41 |     s2_code = text_to_code(s2_tokens, word_dict)
42 |     # Compute cosine similarity
43 |     score = cosine_similarity(s1_code, s2_code)
44 |     print("与问题的余弦相似度:", score)  # cosine similarity to the question
45 |     return score
46 | 
47 | def main():
48 |     s1 = "what is Adams's phone number"
49 |     s2 = "Adams's phone number is 15925526729"
50 |     s3 = "Mitchell's phone number is 18764284516"
51 |     similarity(s1, s2)
52 |     similarity(s1, s3)
53 | 
54 | if __name__ == '__main__':
55 |     main()
--------------------------------------------------------------------------------
/static_files/images/user-avatar.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Judgment_sensitivity/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import BertModel, BertConfig
3 | 
4 | # BERT
5 | class BertTextClassifier(nn.Module):
6 |     def __init__(self, bert_model_name, num_labels):
7 |         """
8 |         Args:
9 |             bert_model_name (str): name or path of the pretrained BERT model
10 |             num_labels (int): number of target classes
11 |         """
12 |         super(BertTextClassifier, self).__init__()
13 |         # Define the BERT model
14 |         bert_config = BertConfig.from_pretrained(bert_model_name)
15 |         self.bert = BertModel.from_pretrained(bert_model_name, config=bert_config)
16 |         # Define the classifier
17 |         self.classifier = nn.Linear(bert_config.hidden_size, num_labels)
18 | 
19 |     def forward(self, input_ids, attention_mask, token_type_ids):
20 |         # BERT's output
21 |         bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
22 |         pooled = bert_output[1]
23 |         # Classification: return raw logits (CrossEntropyLoss in train.py expects them)
24 |         logits = self.classifier(pooled)
25 |         return logits
26 | 
27 | # BERT + BiLSTM
28 | class BertLstmClassifier(nn.Module):
29 |     def __init__(self, bert_model_name, num_labels):
30 |         super(BertLstmClassifier, self).__init__()
31 |         bert_config = BertConfig.from_pretrained(bert_model_name)
32 |         self.bert = BertModel.from_pretrained(bert_model_name, config=bert_config)
33 |         # BiLSTM over BERT's token-level hidden states (mirrors model.py at the repository root)
34 |         self.lstm = nn.LSTM(input_size=bert_config.hidden_size, hidden_size=bert_config.hidden_size, num_layers=2,
35 |                             batch_first=True, bidirectional=True)
36 |         self.classifier = nn.Linear(bert_config.hidden_size * 2, num_labels)
37 | 
38 |     def forward(self, input_ids, attention_mask, token_type_ids):
39 |         # BERT feature extraction
40 |         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
41 |         last_hidden_state = outputs.last_hidden_state
42 |         # Bidirectional LSTM over the sequence
43 |         lstm_output, _ = self.lstm(last_hidden_state)
44 |         # Classify from the output at the last time step; return raw logits
45 |         logits = self.classifier(lstm_output[:, -1, :])
46 |         return logits
47 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FirewaLLM
2 | By calling FirewaLLM, users can keep the accuracy of a large model while greatly reducing the risk of privacy leakage when interacting with it.
3 | 
4 | As large models develop, more and more users interact with them for Q&A, but with growing data volumes, privacy has become a widespread concern. We therefore built a small privacy-protection model. It first judges whether a user's question contains sensitive statements, then masks those statements accordingly, so that privacy exposure during the interaction is greatly reduced. Finally, it restores the sensitive information so that the accuracy of the large model's answers is preserved. In our experiments, users interact with the large models entirely through this small model.
5 | 
6 | The FirewaLLM framework thus protects user data privacy while preserving the accuracy of the interaction.
7 | 
8 | 
9 | # Usage
10 | 
11 | 1. Train FirewaLLM so that it can recognize sensitive information.
12 | 2. Start the FirewaLLM interface; users can then choose among different large models to interact with:
13 | ```python
14 | python FirewaLLM/app.py
15 | ```
16 | 3. The effect is as follows:
17 |    First, while interacting with the large models, FirewaLLM filters the input and masks sensitive information.
18 | ![image](https://github.com/ysy1216/FirewaLLM/blob/main/FirewaLLM_server.png)
19 |    Second, FirewaLLM exchanges the data with the user-specified large model. Finally, the answer is post-processed to recover sensitive information, and the answer most similar to the question is returned.
20 | ![image](https://github.com/ysy1216/FirewaLLM/blob/main/FirewaLLM_fronted.png)
21 | 4. When a user's input contains sensitive information, FirewaLLM applies masking at the selected sensitivity level.
22 | 5. In the end, FirewaLLM restores the sensitive content and returns the large-model answer with the highest similarity to the question, protecting privacy while keeping the overall process accurate.
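
A minimal sketch of invoking the masking step directly (`fun_1` is defined in `mark.py`; the sample sentence is illustrative, and the trained `model/sen_model.pkl` must be present):

```python
from mark import fun_1

# level-1 masking with Teen Mode off ("false"); the phone number is made up
masked = fun_1("张三的手机号是13812345678", 1, "false")
print(masked)
```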
--------------------------------------------------------------------------------
/cloud2.py:
--------------------------------------------------------------------------------
1 | # cloud2: Tencent TBP model
2 | import json
3 | from tencentcloud.common import credential
4 | from tencentcloud.common.profile.client_profile import ClientProfile
5 | from tencentcloud.common.profile.http_profile import HttpProfile
6 | from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
7 | from tencentcloud.tbp.v20190627 import tbp_client, models
8 | 
9 | def cloud_model2(user_input):
10 |     try:
11 |         # Instantiating a credential object requires the Tencent Cloud account's
12 |         # SecretId and SecretKey; keep this key pair confidential. Leaking the code
13 |         # may leak the SecretId/SecretKey and threaten all resources under the account.
14 |         # The hard-coded keys below are for reference only; prefer a more secure way
15 |         # of supplying them.
16 |         SecretId = "EnterYourSecretId"
17 |         SecretKey = "EnterYourSecretKey"
18 |         cred = credential.Credential(SecretId, SecretKey)
19 |         # Instantiate an HTTP option
20 |         httpProfile = HttpProfile()
21 |         httpProfile.endpoint = "tbp.ap-guangzhou.tencentcloudapi.com"
22 | 
23 |         # Instantiate a client option
24 |         clientProfile = ClientProfile()
25 |         clientProfile.httpProfile = httpProfile
26 |         # Instantiate the client object for the requested product
27 |         client = tbp_client.TbpClient(cred, "", clientProfile)
28 | 
29 |         # Instantiate a request object; each interface has a corresponding request object
30 |         req = models.TextProcessRequest()
31 |         params = {
32 |             "BotId": "58d826e3-73af-4ddb-887d-350475e88a9a",
33 |             "BotEnv": "release",
34 |             "TerminalId": "user1",
35 |             "SessionAttributes": "True",
36 |             "InputText": user_input,
37 |             "PlatformType": "True",
38 |             "PlatformId": "True"
39 |         }
40 |         req.from_json_string(json.dumps(params))
41 | 
42 |         # resp is a TextProcessResponse instance corresponding to the request
43 |         resp = client.TextProcess(req)
44 |         # Return the content of ResponseText
45 |         response_text = json.loads(resp.to_json_string())["ResponseText"]
46 |         return response_text
47 | 
48 |     except TencentCloudSDKException as err:
49 |         return str(err)
--------------------------------------------------------------------------------
/feature.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import jieba
3 | from sklearn.feature_extraction.text import TfidfVectorizer
4 | import pandas as pd
5 | 
6 | # Import the corpus
7 | def importdata(tag, flag):
8 |     text_list = []
9 |     if tag == "true":
10 |         file_path = "data/Info/info2.txt"
11 |         with open(file_path, encoding="UTF-8") as f:
12 |             text = f.read()
13 |         if flag:
14 |             text_jieba = jieba.cut_for_search(text)
15 |         else:
16 |             text_jieba = jieba.lcut(text)
17 |         text_list.append(" ".join(text_jieba))
18 |     else:
19 |         for i in range(3, 14):
20 |             file_path = f"data/Info/info{i}.txt"
21 |             with open(file_path, encoding="UTF-8") as f:
22 |                 text = f.read()
23 |             text_jieba = jieba.lcut(text)
24 |             text_list.append(" ".join(text_jieba))
25 |     return text_list
26 | 
27 | # Import the stop-word lists used to strip uninformative tokens
28 | def stop_word():
29 |     stopword_list = []
30 |     for i in range(1, 5):
31 |         file_path = f"data/StopWord/StopWord{i}.txt"
32 |         with open(file_path, encoding="UTF-8") as f:
33 |             stopwords = f.read().split("\n")
34 |         stopword_list.extend(stopwords)
35 |     return stopword_list
36 | 
37 | # Segment the sentence with jieba
38 | def process_sentence(flag, sentence):
39 |     if flag:
40 |         sentence_jieba = jieba.lcut_for_search(sentence)
41 |         sentence_jieba = " ".join(sentence_jieba)
42 |     else:
43 |         sentence_jieba = sentence
44 |     return [sentence_jieba]
45 | 
46 | # Use tf-idf to obtain corpus features
47 | def vectorize_data(data, sentence_jieba):
48 |     vectorizer = TfidfVectorizer(stop_words=stop_word())
49 |     X = vectorizer.fit_transform(data).toarray()
50 |     X_fea = vectorizer.get_feature_names_out()
51 |     result = vectorizer.transform(sentence_jieba).toarray()
52 |     X_pd = pd.DataFrame(result, columns=X_fea)
53 |     return X_pd
54 | 
55 | # Compare the input sentence against the corpus features
56 | def getFeature(tag, sentence, flag=True):
57 |     data = importdata(tag, flag)
58 |     st = process_sentence(flag, sentence)
59 |     X_pd = vectorize_data(data, st)
60 |     word_list = []
61 |     X_pd_sort = X_pd.sort_values(by=0, axis=1, ascending=False)
62 |     for i in range(0, min(50, X_pd_sort.shape[1])):  # guard: there may be fewer than 50 features
63 |         if X_pd_sort.iat[0, i] >= 0.20:
64 |             tmp = X_pd_sort.iloc[:, i]
65 |             tmp_frame = tmp.to_frame()
66 |             word = "".join(tmp_frame.columns.tolist())
67 |             if word not in word_list:
68 |                 word_list.append(word)
69 |         else:
70 |             break
71 |     if len(word_list) == 0 and flag:
72 |         word_list = getFeature(tag, sentence, False)
73 |     return word_list
74 | 
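# Usage sketch (the sentence is illustrative; the corpus files under data/Info/
# and the stop-word lists under data/StopWord/ must exist):
#   words = getFeature("false", "张三的手机号是13800000000")
# returns the corpus feature words whose tf-idf weight for this sentence is
# >= 0.20; if none qualify, it retries once with plain (non-search) segmentation.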
--------------------------------------------------------------------------------
/templates_files/index.html:
--------------------------------------------------------------------------------
(The markup of this file did not survive extraction; only bare line numbers remained.
The page is titled "FirewaLLM" and provides the elements referenced by index.css and
index.js: a 青少年模式 ("Teen Mode") toggle, the model and sensitivity-level selects,
the chat display, and the user-input controls.)
--------------------------------------------------------------------------------
/static_files/images/bot-avatar.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mark.py:
--------------------------------------------------------------------------------
1 | # mark.py
2 | 
3 | import re
4 | import jieba
5 | import torch
6 | from model import BertTextClassifier, BertLstmClassifier
7 | from transformers import BertTokenizer, BertConfig, AutoTokenizer
8 | import random
9 | from feature import getFeature, stop_word
10 | 
11 | class MaskHandler:
12 |     def __init__(self, model_path):
13 |         # Initialize the local BERT model
14 |         self.labels = ["Y", "N"]
15 |         self.bert_config = BertConfig.from_pretrained('bert-base-chinese')
16 |         self.model = BertLstmClassifier(self.bert_config, len(self.labels))
17 |         self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
18 |         self.model.eval()
19 |         self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
20 | 
21 |         # Initialize the cloud RoBERTa tokenizer (loaded but not used in this file)
22 |         self.cloud_tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
23 | 
24 |     def classic(self, query):
25 |         # Classify each jieba token and collect the ones predicted sensitive
26 |         sensitive_words = []
27 |         for tmp_text in jieba.lcut(query):
28 |             token = self.tokenizer(tmp_text, add_special_tokens=True, padding='max_length', truncation=True,
29 |                                    max_length=512)
30 |             input_ids = torch.tensor([token['input_ids']], dtype=torch.long)
31 |             attention_mask = torch.tensor([token['attention_mask']], dtype=torch.long)
32 |             token_type_ids = torch.tensor([token['token_type_ids']], dtype=torch.long)
33 |             predicted = self.model(input_ids, attention_mask, token_type_ids)
34 |             pred_label = torch.argmax(predicted, dim=1)
35 |             if self.labels[pred_label] == "Y":  # "Y" means sensitive
36 |                 sensitive_words.append(tmp_text)
37 |         return sensitive_words
38 | 
39 |     # Desensitization algorithm
40 |     def mask_sensitive_info(self, text, sensitive, level, tag):
41 |         # print("Masking level:", level + 1)
42 |         for word in sensitive:
43 |             len_word = len(word)
44 |             if len_word < 8:
45 |                 tmp_level = 0
46 |             else:
47 |                 tmp_level = level
48 |             if tag == "true":
49 |                 # Teen Mode masks the entire word
50 |                 length = len(word)
51 |             else:
52 |                 length = int(len_word / 10 + tmp_level + 1)
53 |             positions = random.sample(range(0, len(word)), length)
54 |             masked_sensitive = word
55 |             for pos in positions:
56 |                 masked_sensitive = masked_sensitive[:pos] + '*' + masked_sensitive[pos + 1:]
57 |             # Escape the word so regex metacharacters are matched literally
58 |             text = re.sub(re.escape(word), masked_sensitive, text, flags=re.IGNORECASE)
59 |         return text
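# Worked example of the masking amount (illustrative): for a matched word of
# length 12 at level 2 (len_word >= 8, so tmp_level = level = 2) with tag "false":
#     length = int(12 / 10 + 2 + 1) = int(4.2) = 4
# so four randomly chosen characters are replaced with '*'. With tag "true"
# (Teen Mode) every character of the word is masked.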
60 | 
61 | # Split the sentence on punctuation marks
62 | def fun_splite(text):
63 |     pattern = r',|\.|/|;|\'|`|\[|\]|<|>|\?|:|"|\{|\}|\~|!|@|#|\$|%|\^|&|\(|\)|-|=|\_|\+|,|。|、|;|‘|’|【|】|:|!| |…|(|)|·'
64 |     result_list = re.split(pattern, text.strip())
65 |     return result_list
66 | 
67 | # Remove stop words from the sentence
68 | def fun_splitein(text):
69 |     sentence_depart = jieba.lcut(text.strip())
70 |     stopwords = stop_word()
71 |     outstr = ""
72 |     for word in sentence_depart:
73 |         if word not in stopwords:
74 |             if word != "\t":
75 |                 outstr += word
76 |     return outstr
77 | 
78 | # Use the model to decide whether a sentence is sensitive
79 | def fun_isSen(maskHandler, text, tag):
80 |     flag = False
81 |     token = maskHandler.tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
82 |     input_ids = torch.tensor([token['input_ids']], dtype=torch.long)
83 |     attention_mask = torch.tensor([token['attention_mask']], dtype=torch.long)
84 |     token_type_ids = torch.tensor([token['token_type_ids']], dtype=torch.long)
85 |     predicted = maskHandler.model(input_ids, attention_mask, token_type_ids)
86 |     # The model output is already softmax-normalized, so no second softmax is applied
87 |     output = predicted
88 |     print(output)
89 |     # output[:, 0] is the sensitive probability, output[:, 1] the insensitive one
90 |     if tag == "false":
91 |         if output[:, 0].item() > 0.5:
92 |             flag = True
93 |     return flag
94 | 
95 | # Invert the insensitive phrases returned by tf-idf and return the sensitive phrases
96 | def getSen(nosen, text):
97 |     sen = []
98 |     text_jieba = jieba.lcut(text)
99 |     for word in text_jieba:
100 |         if word not in nosen and len(word) > 1:
101 |             if word not in sen:
102 |                 sen.append(word)
103 |     return sen
104 | 
105 | # For multi-sentence input, judge each sentence and desensitize it if needed
106 | def fun_1(text, selected_sen_level, tag):
107 |     maskHandler = MaskHandler("model/sen_model.pkl")  # sensitivity model
108 |     text_splite = fun_splite(text)
109 |     tmp = text
110 |     for tmp_text in text_splite:
111 |         text_stop = fun_splitein(tmp_text)
112 |         sen_fea = []
113 |         if fun_isSen(maskHandler, tmp_text, tag):
114 |             if tag == "false":
115 |                 sen_fea = getSen(getFeature(tag, text_stop), text_stop)
116 |             else:
117 |                 sen_fea = getFeature(tag, tmp_text, True)
118 |         res = maskHandler.mask_sensitive_info(tmp, sen_fea, selected_sen_level, tag)
119 |         tmp = res
120 |     print(tmp)
121 |     return tmp
122 | 
123 | if __name__ == '__main__':
124 |     text = "..."
125 |     fun_1(text, 1, "false")
126 | 
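# fun_1 pipeline sketch (the sentence is illustrative): fun_splite first cuts
# the input on punctuation, e.g. fun_splite("你好,我是张三。") returns
# ["你好", "我是张三", ""]; each piece is then checked with fun_isSen and, when
# sensitive, its feature words are masked in the running text by
# mask_sensitive_info.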
--------------------------------------------------------------------------------
/Judgment_sensitivity/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | from transformers import AdamW
6 | from torch.utils.data import DataLoader
7 | from model import BertTextClassifier, BertLstmClassifier
8 | from dataset import InfoDataset
9 | from tqdm import tqdm
10 | from sklearn import metrics
11 | 
12 | def main():
13 |     # Parameter settings
14 |     batch_size = 64
15 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |     epochs = 50
17 |     learning_rate = 1e-4
18 |     # Obtain the datasets
19 |     train_dataset = InfoDataset("data/info/train.txt")
20 |     valid_dataset = InfoDataset("data/info/test.txt")
21 |     # Generate batches
22 |     valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
23 |     num_labels = len(train_dataset.labels)
24 |     # Initialize the model (the constructor loads config and pretrained weights by name)
25 |     model = BertLstmClassifier('bert-base-chinese', num_labels).to(device)
26 |     # Optimizer
27 |     optimizer = AdamW(model.parameters(), lr=learning_rate)
28 |     # Loss function (expects raw logits)
29 |     criterion = nn.CrossEntropyLoss()
30 |     best_f1 = 0
31 |     for epoch in range(1, epochs + 1):
32 |         total_loss = 0     # Total loss for the epoch
33 |         total_correct = 0  # Total correct predictions for the epoch
34 |         total_samples = 0  # Total processed samples for the epoch
35 | 
36 |         model.train()
37 |         train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
38 |         train_bar = tqdm(train_dataloader, ncols=100, desc='Epoch {} train'.format(epoch))
39 | 
40 |         for input_ids, token_type_ids, attention_mask, label_id in train_bar:
41 |             # Move input tensors to the appropriate device (CPU or GPU)
42 |             input_ids = input_ids.to(device)
43 |             token_type_ids = token_type_ids.to(device)
44 |             attention_mask = attention_mask.to(device)
45 |             label_id = label_id.to(device)
46 | 
47 |             # Zero the gradients
48 |             optimizer.zero_grad()
49 | 
50 |             # Forward pass (keyword arguments: the signature is input_ids, attention_mask, token_type_ids)
51 |             output = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
52 | 
53 |             # Compute the loss
54 |             loss = criterion(output, label_id)
55 |             total_loss += loss.item()
56 | 
57 |             # Backward pass and optimization
58 |             loss.backward()
59 |             optimizer.step()
60 | 
61 |             # Count correct predictions and processed samples
62 |             predicted_labels = torch.argmax(output, dim=1)
63 |             total_correct += torch.sum(predicted_labels == label_id).item()
64 |             total_samples += input_ids.size(0)  # input_ids.size(0) is the batch size
65 | 
66 |             # Update the progress bar description
67 |             train_bar.set_postfix(loss=loss.item())
68 | 
69 |         average_loss = total_loss / len(train_dataloader)
70 |         average_accuracy = total_correct / total_samples
71 | 
72 |         print('Train ACC: {:.4f}\tLoss: {:.4f}'.format(average_accuracy, average_loss))
73 | 
74 |         # Validation
75 |         model.eval()
76 |         total_loss = 0
77 |         pred_labels = []
78 |         true_labels = []
79 | 
80 |         with torch.no_grad():
81 |             valid_bar = tqdm(valid_dataloader, ncols=100, desc='Epoch {} valid'.format(epoch))
82 |             for input_ids, token_type_ids, attention_mask, label_id in valid_bar:
83 |                 input_ids = input_ids.to(device)
84 |                 token_type_ids = token_type_ids.to(device)
85 |                 attention_mask = attention_mask.to(device)
86 |                 label_id = label_id.to(device)
87 | 
88 |                 output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
89 | 
90 |                 loss = criterion(output, label_id)
91 |                 total_loss += loss.item()
92 | 
93 |                 pred_label = torch.argmax(output, dim=1)
94 |                 acc = torch.sum(pred_label == label_id).item() / len(pred_label)
95 |                 valid_bar.set_postfix(loss=loss.item(), acc=acc)
96 | 
97 |                 pred_labels.extend(pred_label.cpu().numpy().tolist())
98 |                 true_labels.extend(label_id.cpu().numpy().tolist())
99 | 
100 |         average_loss = total_loss / len(valid_dataloader)
101 |         print('Validation Loss: {:.4f}'.format(average_loss))
102 | 
103 |         # Classification report
104 |         report = metrics.classification_report(true_labels, pred_labels, labels=valid_dataset.labels_id,
105 |                                                target_names=valid_dataset.labels)
106 |         print('* Classification Report:')
107 |         print(report)
108 | 
109 |         f1 = metrics.f1_score(true_labels, pred_labels, labels=valid_dataset.labels_id, average='micro')
110 | 
111 |         if not os.path.exists('models'):
112 |             os.makedirs('models')
113 | 
114 |         # Keep the checkpoint with the best validation micro-F1
115 |         if f1 > best_f1:
116 |             best_f1 = f1
117 |             torch.save(model.state_dict(), 'models/sen_model.pkl')
118 | 
119 | # Check whether the dataset matches the expected "<text>_<label>" format
120 | def test_1():
121 |     paths = "data/info/train.txt"
122 |     with open(paths, encoding="utf8") as f:
123 |         lines = f.readlines()
124 |         for line in tqdm(lines, ncols=100):
125 |             print(line.strip().split('_'))
126 |             text, label = line.strip().split('_')
127 |             print(label)
128 | 
129 | if __name__ == '__main__':
130 |     # test_1()  # validate the data files first
131 |     main()
--------------------------------------------------------------------------------
/static_files/index.css:
--------------------------------------------------------------------------------
1 | body {
2 |     padding: 0;
3 |     margin: 0;
4 |     font-family: Arial, sans-serif;
5 | }
6 | 
7 | .container {
8 |     display: flex;
9 |     position: absolute;
10 |     width: 100vw;
11 |     height: 100vh;
12 |     top: 0;
13 |     left: 0;
14 |     overflow: hidden;
15 | }
16 | 
17 | /* Chat history management. */
18 | .chat-history {
19 |     width: 250px;
20 |     height: 100vh;
21 |     background-color: #444;
22 |     left: 0;
23 |     position: relative;
24 |     transition: width 0.3s;
25 | }
26 | 
27 | .chat-panel {
28 |     flex: 1;
29 |     width: calc(100% - 250px);
30 |     height: 100vh;
31 |     background-color: #dae7e6;
32 |     box-sizing: border-box;
33 |     overflow: hidden;
34 | }
35 | 
36 | .separator {
37 |     height: 1px;
38 |     background-color: rgba(255, 255, 255, 0.08);
39 |     margin: 0 20px;
40 | }
41 | 
42 | 
43 | .chat-controls {
44 |     display: flex;
45 |     width: 100%;
46 |     height: 45px;
47 |     padding: 5px;
48 |     box-sizing: border-box;
49 |     margin-top: 8px;
50 |     margin-bottom: 8px;
51 |     justify-content: center;
52 |     align-items: center;
53 | }
54 | 
55 | .function-module {
56 |     width: 100%;
57 |     height: 550px;
58 |     padding: 5px;
59 |     box-sizing: border-box;
60 |     margin-top: 8px;
61 |     justify-content: center;
62 |     align-items: center;
63 |     overflow: auto;
64 | }
65 | 
66 | .settings {
67 |     display: flex;
68 |     width: 100%;
69 |     padding: 5px;
70 |     height: calc(100% - 45px - 550px);
71 |     box-sizing: border-box;
72 |     margin-top: 8px;
73 | }
74 | 
75 | .mode {
76 |     display: flex;
77 |     width: 100%;
78 |     height: 45px;
79 |     margin-top: 5px;
80 |     justify-content: center;
81 |     align-items: center;
82 | }
83 | 
84 | .mode-label {
85 |     position: relative;
86 |     padding: 5px;
87 |     box-sizing: border-box;
88 |     color: white;
89 | }
90 | 
91 | .switch {
92 |     position: relative;
93 |     display: inline-block;
94 |     width: 80px;
95 |     height: 34px;
96 | }
97 | 
98 | .switch input {
99 |     display: none;
100 | }
101 | 
102 | .slider {
103 |     position: absolute;
104 |     cursor: pointer;
105 |     top: 0;
106 |     left: 0;
107 |     right: 0;
108 |     bottom: 0;
109 |     background-color: #ccc;
110 |     -webkit-transition: .4s;
111 |     transition: .4s;
112 |     border-radius: 34px;
113 | }
114 | 
115 | .slider:before {
116 |     position: absolute;
117 |     content: "";
118 |     height: 26px;
119 |     width: 26px;
120 |     left: 4px;
121 |     bottom: 4px;
122 |     background-color: white;
123 |     -webkit-transition: .4s;
124 |     transition: .4s;
125 |     border-radius: 50%;
126 | }
127 | 
128 | input:checked + .slider {
129 |     background-color: #2196F3;
130 | }
131 | 
132 | input:focus + .slider {
133 |     box-shadow: 0 0 1px #2196F3;
134 | }
135 | 
136 | input:checked + .slider:before {
137 |     transform: translateX(46px);
138 | }
139 | 
140 | .control-panel { /* superseded by the second .control-panel rule further below */
141 |     width: 100%;
142 |     height: 40px;
143 |     display: inline-block;
144 | }
145 | 
146 | .collapse-chat-btn {
147 |     display: flex;
148 |     width: 40px;
149 |     height: 40px;
150 |     margin: 8px 8px;
151 |     padding: 10px;
152 |     box-sizing: border-box;
153 |     border-radius: 5px;
154 |     justify-content: center;
155 |     align-items: center;
156 |     cursor: pointer;
157 | }
158 | 
159 | .chat-panel.expanded {
160 |     margin-left: 0;
161 |     width: calc(100vw - 250px);
162 | }
163 | 
164 | .chat-history.collapsed {
165 |     width: 0;
166 |     overflow: hidden;
167 | }
168 | 
169 | .chat-history.collapsed + .container .chat-panel {
170 |     width: 100%;
171 | }
172 | 
173 | .control-panel {
174 |     display: flex;
175 |     width: 100%;
176 |     height: 45px;
177 |     margin: 8px 8px;
178 |     justify-content: center;
179 |     align-items: center;
180 | }
181 | 
182 | .extras {
183 |     display: flex;
184 |     width: calc(100% - 40px);
185 |     height: 40px;
186 |     margin: 8px 8px;
187 |     padding: 10px;
188 |     box-sizing: border-box;
189 |     justify-content: center;
190 |     align-items: center;
191 | }
192 | 
193 | .extras-one {
194 |     display: flex;
195 |     width: 40px;
196 |     height: 40px;
197 |     border: none;
198 |     justify-content: center;
199 |     align-items: center;
200 |     background-color: transparent;
201 | }
202 | 
203 | #chat-display {
204 |     width: 100%;
205 |     height: 550px;
206 |     padding: 5px;
207 |     box-sizing: border-box;
208 |     margin-top: 8px;
209 |     border-radius: 10px;
210 |     overflow-y: scroll;
211 | }
212 | 
213 | .user-input-outside {
214 |     display: flex;
215 |     width: 100%;
216 |     height: calc(100% - 45px - 550px);
217 |     box-sizing: border-box;
218 |     padding: 10px;
219 |     margin-bottom: 10px;
220 |     justify-content: center;
221 |     align-items: center;
222 | }
223 | 
224 | .user-input-area {
225 |     display: flex;
226 |     width: 700px;
227 |     height: 70px;
228 |     background-color: #fff;
229 |     padding: 5px;
230 |     border-radius: 10px;
231 |     box-sizing: border-box;
232 |     justify-content: space-between;
233 |     align-items: center;
234 | }
235 | 
236 | #user-input {
237 |     flex: 1;
238 |     height: calc(100% - 10px);
239 |     resize: none;
240 |     border: none;
241 |     margin: 0px 5px;
242 |     padding: 5px;
243 |     box-sizing: border-box;
244 |     font-size: 16px;
245 | }
246 | 
247 | #send-button {
248 |     background-color: blue;
249 |     color: white;
250 |     border: none;
251 |     border-radius: 5px;
252 |     padding: 8px 10px;
253 |     box-sizing: border-box;
254 |     cursor: pointer;
255 |     margin-right: 10px;
256 | }
257 | 
258 | #model-select {
259 |     color: #fff;
260 |     background-color: #34495e;
261 |     padding: 8px 10px;
262 |     box-sizing: border-box;
263 |     border-radius: 5px;
264 |     margin-right: 5px;
265 | }
266 | 
267 | #sensitivity-level {
268 |     color: #fff;
269 |     background-color: #34495e;
270 |     padding: 8px 10px;
271 |     box-sizing: border-box;
272 |     border-radius: 5px;
273 |     margin-left: 5px;
274 |     margin-right: 5px;
275 | }
276 | 
277 | .user-message-outside {
278 |     display: flex;
279 |     width: 100%;
280 |     min-height: 45px;
281 |     background-color: lightblue;
282 |     justify-content: center;
283 |     align-items: center;
284 | }
285 | 
286 | .chatbot-message-outside {
287 |     display: flex;
288 |     width: 100%;
289 |     min-height: 45px;
290 |     background-color: lightskyblue;
291 |     justify-content: center;
292 |     align-items: center;
293 | }
294 | 
295 | .model-message-outside {
296 |     display: flex;
297 |     width: 100%;
298 |     min-height: 45px;
299 |     justify-content: center;
300 |     align-items: center;
301 | }
302 | 
303 | .user-message {
304 |     display: flex;
305 |     width: 70%;
306 | }
307 | 
308 | .chatbot-message {
309 |     display: flex;
310 |     width: 70%;
311 | }
312 | 
313 | .model-message {
314 |     display: flex;
315 |     width: 70%;
316 | }
317 | 
318 | .message-content {
319 |     flex: 1;
320 |     word-wrap: break-word;
321 |     white-space: pre-line;
322 |     padding: 10px;
323 |     border-radius: 10px;
324 |     align-self: flex-start;
325 |     font-size: 16px;
326 | }
327 | 
328 | .model-message-content {
329 |     flex: 1;
330 |     word-wrap: break-word;
331 |     white-space: pre-line;
332 |     padding: 10px;
333 |     border-radius: 10px;
334 |     align-self: flex-start;
335 |     font-size: 16px;
336 |     color: rgba(128, 128, 128, 0.5);
337 | }
338 | 
339 | .avatar {
340 |     display: flex;
341 |     flex: 0 0 auto;
342 |     width: 45px;
343 |     height: 45px;
344 |     overflow: hidden;
345 |     margin: 10px 10px auto 10px;
346 | }
347 | 
348 | .avatar-img {
349 |     width: 80%;
350 |     height: 80%;
351 |     border-radius: 50%;
352 | }
353 | 
354 | .copy-button {
355 |     float: right;
356 |     background-color: transparent;
357 |     border: none;
358 |     margin-right: 10px;
359 |     cursor: pointer;
360 | }
361 | 
362 | .hide-model-button {
363 |     float: right;
364 |     background-color: transparent;
365 |     border: none;
366 |     margin-right: 10px;
367 |     cursor: pointer;
368 |     color: gray;
369 | }
370 | 
--------------------------------------------------------------------------------
/static_files/index.js:
--------------------------------------------------------------------------------
1 | // Fold button click event listener
2 | document.querySelector('.collapse-chat-btn').addEventListener('click', () => {
3 |     document.querySelector('.chat-history').classList.toggle('collapsed');
4 |     document.querySelector('.chat-panel').classList.toggle('expanded');
5 | });
6 | 
7 | 
8 | // Teen Mode (青少年模式) toggle
9 | document.addEventListener('DOMContentLoaded', function() {
10 |     const toggleButton = document.getElementById('toggleButton');
11 |     const hiddenSL = document.getElementById('sensitivity-level');
12 |     const modelSelect = document.getElementById('model-select');
13 | 
14 |     toggleButton.addEventListener('change', function() {
15 |         if (toggleButton.checked) {
16 |             hiddenSL.style.display = 'none';
17 |             // (the <option> markup inside the template literals below did not survive extraction)
18 |             modelSelect.innerHTML = `
19 | 
20 | 
21 | 
22 |             `;
23 |         }
24 | 
25 |         else {
26 |             hiddenSL.style.display = 'block';
27 | 
28 |             modelSelect.innerHTML = `
29 | 
30 | 
31 | 
32 | 
33 | 
34 |             `;
35 |         }
36 |     });
37 | });
38 | 
39 | 
40 | // Add a message to the chat display
41 | function appendMessage(userType, message, modelMessage = "") {
42 |     const chatDisplay = document.getElementById("chat-display");
43 |     const outsideMessageContainer = document.createElement("div");
44 |     outsideMessageContainer.classList.add(userType === "user" ? "user-message-outside" : "chatbot-message-outside");
45 |     const messageContainer = document.createElement("div");
46 | 
47 |     const messageContent = document.createElement("div");
48 |     messageContent.classList.add("message-content");
49 |     messageContent.innerHTML = message;
50 | 
51 |     const avatar = document.createElement("div");
52 |     avatar.classList.add("avatar");
53 |     const avatarImg = document.createElement("img");
54 |     avatarImg.classList.add("avatar-img");
55 | 
56 |     avatarImg.src = userType === "user" ? "../static/images/user-avatar.svg" : "../static/images/bot-avatar.svg";
57 |     avatar.appendChild(avatarImg);
58 | 
59 |     // Copy button
60 |     const copyButton = document.createElement("button");
61 |     copyButton.textContent = "复制";  // "Copy"
62 |     copyButton.classList.add("copy-button");
63 |     copyButton.addEventListener("click", function() {
64 |         copyToClipboard(message);
65 |         copyButton.textContent = "已复制";  // "Copied"
66 |         setTimeout(function() {
67 |             copyButton.textContent = "复制";
68 |         }, 1500);
69 |     });
70 | 
71 |     // Hide-model-message button
72 |     if (modelMessage !== "" && modelMessage.replace(/\s/g, '') !== message.replace(/\s/g, '')) {
73 |         const hideModelButton = document.createElement("button");
74 |         hideModelButton.textContent = "隐藏模型消息";  // "Hide model message"
75 |         hideModelButton.classList.add("hide-model-button");
76 |         hideModelButton.addEventListener("click", function() {
77 |             outsideModelMessageContainer.style.display = outsideModelMessageContainer.style.display === "none" ? "flex" : "none";
78 |             hideModelButton.textContent = outsideModelMessageContainer.style.display === "none" ? "展开模型消息" : "隐藏模型消息";  // "Show" / "Hide model message"
79 |         });
80 | 
81 |         // Set different style class names based on user type
82 |         messageContainer.classList.add(userType === "user" ? "user-message" : "chatbot-message");
"user-message" : "chatbot-message"); 83 | messageContainer.appendChild(avatar); 84 | messageContainer.appendChild(messageContent); 85 | messageContainer.appendChild(hideModelButton); 86 | messageContainer.appendChild(copyButton); 87 | outsideMessageContainer.appendChild(messageContainer) 88 | 89 | // Model message 90 | // width 100% 91 | const outsideModelMessageContainer = document.createElement("div"); 92 | outsideModelMessageContainer.classList.add("model-message-outside"); 93 | 94 | const modelMessageContainer = document.createElement("div"); 95 | modelMessageContainer.classList.add("model-message"); 96 | 97 | const modelAvatar = document.createElement("div"); 98 | modelAvatar.classList.add("avatar"); 99 | 100 | const modelMessageContent = document.createElement("div"); 101 | modelMessageContent.classList.add("model-message-content"); 102 | modelMessageContent.innerHTML = modelMessage; 103 | 104 | const modelCopyButton = document.createElement("button"); 105 | modelCopyButton.textContent = "复制"; 106 | modelCopyButton.classList.add("copy-button"); 107 | modelCopyButton.addEventListener("click", function() { 108 | copyToClipboard(modelMessage); 109 | modelCopyButton.textContent = "已复制"; 110 | setTimeout(function() { 111 | modelCopyButton.textContent = "复制"; 112 | }, 1500); 113 | }); 114 | 115 | modelMessageContainer.appendChild(modelAvatar); 116 | modelMessageContainer.appendChild(modelMessageContent); 117 | modelMessageContainer.appendChild(modelCopyButton); 118 | outsideModelMessageContainer.appendChild(modelMessageContainer) 119 | 120 | 121 | chatDisplay.appendChild(outsideMessageContainer); 122 | chatDisplay.appendChild(outsideModelMessageContainer); 123 | } 124 | 125 | else { 126 | messageContainer.classList.add(userType === "user" ? "user-message" : "chatbot-message"); 127 | messageContainer.appendChild(avatar); 128 | messageContainer.appendChild(messageContent); 129 | messageContainer.appendChild(copyButton); 130 | 131 | outsideMessageContainer.appendChild(messageContainer) 132 | chatDisplay.appendChild(outsideMessageContainer); 133 | } 134 | } 135 | 136 | 137 | function copyToClipboard(text) { 138 | const textarea = document.createElement("textarea"); 139 | textarea.value = text; 140 | document.body.appendChild(textarea); 141 | textarea.select(); 142 | document.execCommand("copy"); 143 | document.body.removeChild(textarea); 144 | } 145 | 146 | function sendMessage() { 147 | const userInput = document.getElementById("user-input").value; 148 | const selectedLevel = document.getElementById("sensitivity-level").value; 149 | const selectedTag = document.getElementById("toggleButton").checked; 150 | document.getElementById("user-input").value = ""; 151 | fetch("/get_mask", { 152 | method: "POST", 153 | body: new URLSearchParams({ 154 | "user_input": userInput, 155 | "sen_level": selectedLevel, 156 | "ask_tag":selectedTag 157 | }), 158 | headers: { 159 | "Content-Type": "application/x-www-form-urlencoded" 160 | } 161 | }) 162 | .then(response => response.text()) 163 | .then(data => { 164 | appendMessage("user", userInput, data); 165 | getMessage(data); 166 | }) 167 | .catch(error => { 168 | console.error("Error:", error); 169 | }); 170 | } 171 | 172 | function getMessage(data){ 173 | const selected_model = document.getElementById("model-select").value; 174 | fetch("/get_response", { 175 | method: "POST", 176 | body: new URLSearchParams({ 177 | "selected_model": selected_model, 178 | "mask_info": data 179 | }), 180 | headers: { 181 | "Content-Type": "application/x-www-form-urlencoded" 
171 | function getMessage(data) {
172 |     const selected_model = document.getElementById("model-select").value;
173 |     fetch("/get_response", {
174 |         method: "POST",
175 |         body: new URLSearchParams({
176 |             "selected_model": selected_model,
177 |             "mask_info": data
178 |         }),
179 |         headers: {
180 |             "Content-Type": "application/x-www-form-urlencoded"
181 |         }
182 |     })
183 |     .then(response => response.text())
184 |     .then(data => {
185 |         appendMessage("bot", data);
186 |     })
187 |     .catch(error => {
188 |         console.error("Error:", error);
189 |     });
190 | }
191 | 
192 | document.getElementById('send-button').addEventListener('click', sendMessage);
193 | // Enter inserts a newline; Ctrl+Enter sends
194 | document.getElementById('user-input').addEventListener('keydown', function (event) {
195 |     if (event.key === "Enter" && !event.ctrlKey) {
196 |         event.preventDefault();
197 | 
198 |         const startPos = this.selectionStart;
199 |         const endPos = this.selectionEnd;
200 |         this.value = this.value.substring(0, startPos) + "\n" + this.value.substring(endPos);
201 | 
202 |         this.selectionStart = startPos + 1;
203 |         this.selectionEnd = startPos + 1;
204 |     }
205 |     else if (event.key === "Enter" && event.ctrlKey) {
206 |         event.preventDefault();
207 | 
208 |         sendMessage();
209 |     }
210 | });
211 | 
212 | 
213 | function clearChat() {
214 |     var chatDisplay = document.getElementById("chat-display");
215 |     var confirmation = confirm("记录清除后无法恢复,您确定要清除吗?");  // "Cleared records cannot be recovered. Clear anyway?"
216 |     if (confirmation) {
217 |         chatDisplay.innerHTML = "";
218 |     }
219 | }
220 | 
221 | 
222 | function handleExport() {
223 |     var chatDisplay = document.getElementById("chat-display");
224 |     var chatContent = chatDisplay.innerText;
225 |     var blob = new Blob([chatContent], { type: "text/plain" });
226 | 
227 |     var a = document.createElement("a");
228 |     a.href = URL.createObjectURL(blob);
229 |     a.download = "chat_export.txt";
230 |     a.textContent = "Download";
231 |     a.style.display = "none";
232 |     document.body.appendChild(a);
233 |     a.click();
234 |     document.body.removeChild(a);
235 | }
--------------------------------------------------------------------------------