├── model
└── sen_model.pkl
├── static_files
├── images
│ ├── IMG0
│ ├── logo.png
│ ├── logo1.png
│ ├── logo2.png
│ ├── logo3.png
│ ├── bot-avatar.png
│ ├── user-avatar.png
│ ├── send.svg
│ ├── clear.svg
│ ├── export.svg
│ ├── fold.svg
│ ├── user-avatar.svg
│ └── bot-avatar.svg
├── index.css
└── index.js
├── Judgment_sensitivity
├── data
│ └── info
│ │ ├── test.txt
│ │ └── train.txt
├── predict.py
├── dataset.py
├── model.py
└── train.py
├── FirewaLLM_server.png
├── FirewaLLM_fronted.png
├── cloud3.py
├── cloud1.py
├── app.py
├── model.py
├── cosSIM.py
├── README.md
├── cloud2.py
├── feature.py
├── templates_files
└── index.html
└── mark.py
/model/sen_model.pkl:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/static_files/images/IMG0:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Judgment_sensitivity/data/info/test.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Judgment_sensitivity/data/info/train.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/FirewaLLM_server.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/FirewaLLM_server.png
--------------------------------------------------------------------------------
/FirewaLLM_fronted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/FirewaLLM_fronted.png
--------------------------------------------------------------------------------
/static_files/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo.png
--------------------------------------------------------------------------------
/static_files/images/logo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo1.png
--------------------------------------------------------------------------------
/static_files/images/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo2.png
--------------------------------------------------------------------------------
/static_files/images/logo3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/logo3.png
--------------------------------------------------------------------------------
/static_files/images/bot-avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/bot-avatar.png
--------------------------------------------------------------------------------
/static_files/images/user-avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ysy1216/FirewaLLM/HEAD/static_files/images/user-avatar.png
--------------------------------------------------------------------------------
/cloud3.py:
--------------------------------------------------------------------------------
1 | #cloud3 gpt3.5
2 | import openai
3 |
def cloud_model3(content):
    """Send ``content`` to the OpenAI chat-completion API and return the reply.

    Args:
        content: Text of the user prompt, sent as a single ``user`` message.

    Returns:
        The message content of the first choice in the API response.
    """
    import os

    # Prefer a key supplied through the environment so that real credentials
    # never have to be written into the source; the original placeholder is
    # kept as the fallback so existing setups behave as before.
    openai.api_key = os.environ.get("OPENAI_API_KEY", "EnterYourKey")
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0301",  # alternatives: gpt-3.5-turbo-0301, text-davinci-003
        messages=[
            {"role": "user", "content": content}
        ],
        temperature=0.7,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.5,
    )
    return response.choices[0].message.content
18 |
--------------------------------------------------------------------------------
/static_files/images/send.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/static_files/images/clear.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/static_files/images/export.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/static_files/images/fold.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Judgment_sensitivity/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model import BertTextClassifier
3 | from transformers import BertTokenizer, BertConfig
# "Y" means sensitive, "N" means insensitive.
labels = ["Y", "N"]
bert_config = BertConfig.from_pretrained('bert-base-chinese')
# Build the classifier.
# NOTE(review): the BertTextClassifier in the sibling model.py names its first
# parameter ``bert_model_name`` and passes it to ``from_pretrained`` -- confirm
# that handing it a BertConfig here is what the trained checkpoint expects.
model = BertTextClassifier(bert_config, len(labels))
# Load the trained weights (torch.load unpickles the file: only load
# checkpoints from a trusted source).
model.load_state_dict(torch.load('models/best_model.pkl', map_location=torch.device('cpu')))
# The device never changes, so move the model once instead of per request.
model.to('cpu')
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

print('classification: ')
while True:
    try:
        text = input('Input: ')
    except (EOFError, KeyboardInterrupt):
        # End of input / Ctrl-C: leave the loop instead of dying with a traceback.
        break
    if not text:
        print('Please enter valid text content!')
        continue

    token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)

    # Wrap each encoded field in a batch dimension of size one.
    input_ids = torch.tensor([token['input_ids']], dtype=torch.long)
    attention_mask = torch.tensor([token['attention_mask']], dtype=torch.long)
    token_type_ids = torch.tensor([token['token_type_ids']], dtype=torch.long)

    # Inference only: no gradients are needed, so do not record them.
    with torch.no_grad():
        predicted = model(
            input_ids,
            attention_mask,
            token_type_ids,
        )
    pred_label = torch.argmax(predicted, dim=1).item()

    print('Label:', labels[pred_label])
--------------------------------------------------------------------------------
/Judgment_sensitivity/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from transformers import BertTokenizer
4 |
5 |
class InfoDataset(Dataset):
    """Dataset of ``text_label`` lines tokenized with the BERT tokenizer.

    Each non-empty line of the input file is ``<text>_<label>`` where the
    label is one of ``self.labels`` ("Y" or "N"). Every sample is stored as
    ``(input_ids, token_type_ids, attention_mask, label_id)``.
    """

    def __init__(self, filename):
        """Create the tokenizer and eagerly load every sample of ``filename``."""
        self.labels = ["Y", "N"]
        self.labels_id = list(range(len(self.labels)))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.data = []
        self.load_data(filename)

    def load_data(self, filename):
        """Read ``filename`` and append one tokenized sample per non-empty line.

        Raises:
            ValueError: if a line has no ``_`` separator or its label is not
                in ``self.labels``.
        """
        print('Loading data from:', filename)
        with open(filename, 'r', encoding='utf-8') as rf:
            for line in rf:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                # The label is the part after the LAST underscore, so the
                # text itself may contain underscores.
                text, label = line.rsplit('_', 1)
                label_id = self.labels.index(label)
                token = self.tokenizer(
                    text,
                    add_special_tokens=True,
                    padding='max_length',
                    truncation=True,
                    max_length=512
                )
                input_ids = torch.tensor(token['input_ids'])
                token_type_ids = torch.tensor(token['token_type_ids'])
                attention_mask = torch.tensor(token['attention_mask'])

                self.data.append((input_ids, token_type_ids, attention_mask, label_id))

    def __getitem__(self, index):
        """Return the ``(input_ids, token_type_ids, attention_mask, label_id)`` tuple at ``index``."""
        input_ids, token_type_ids, attention_mask, label_id = self.data[index]
        return input_ids, token_type_ids, attention_mask, label_id

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
39 |
--------------------------------------------------------------------------------
/cloud1.py:
--------------------------------------------------------------------------------
1 | #cloud1 BaiDuModel
2 | import requests
3 | import json
4 | from cosSIM import similarity
5 |
def cloud_model1(question):
    """Query the Baidu chat endpoint and return the answer closest to ``question``.

    The request is posted to the Baidu ``unit/service/v3/chat`` URL below. Up
    to five ``full_answer`` candidates are read from the first action of the
    first response, and the candidate with the highest cosine similarity to
    the question (see ``cosSIM.similarity``) is returned.

    Args:
        question: The user question to send.

    Returns:
        The best-matching answer text, or an empty string when the service
        returned no usable candidate.
    """
    url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=EnterYourToken"
    payload = json.dumps({
        "log_id": "1234567890",
        "version": "3.0",
        "service_id": "S99567",
        "session_id": "",
        "request": {
            "query": question,
            "terminal_id": "1234567890"
        },
        "dialog_state": {
            "contexts": {
                "SYS_REMEMBERED_SKILLS": [
                    ""
                ]
            }
        }
    })
    headers = {
        'Content-Type': 'application/json'
    }
    # A timeout keeps a stalled connection from blocking the caller forever.
    response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
    data = response.json()

    # Walk result -> responses[0] -> actions[0] -> options, tolerating any
    # level that is missing or empty.
    result = data.get("result") or {}
    responses = result.get("responses") or [{}]
    actions = responses[0].get("actions") or [{}]
    options = actions[0].get("options") or []

    # Keep at most five candidates and drop options without an answer text,
    # instead of assuming that exactly five well-formed options exist.
    answer_list = []
    for option in options[:5]:
        full_answer = (option.get("info") or {}).get("full_answer")
        if full_answer:
            answer_list.append(full_answer)
    print("找到的相关信息:", answer_list)
    if not answer_list:
        return ""

    # max() keeps the first candidate on ties; the previous ``pos = -1``
    # sentinel silently selected the LAST candidate when every score was 0.
    best_answer = max(answer_list, key=lambda answer: similarity(question, answer))
    print("最终回复:", best_answer)
    return best_answer
47 |
48 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, request
2 | from mark import fun_1
3 | from cloud1 import cloud_model1
4 | from cloud2 import cloud_model2
5 | from cloud3 import cloud_model3
6 | import os
# The repository keeps its assets in ``static_files`` and its page template in
# ``templates_files``; Flask looks in ``static`` and ``templates`` unless told
# otherwise, so point it at the directories that actually exist.
app = Flask(
    __name__,
    static_url_path='/static',
    static_folder='static_files',
    template_folder='templates_files',
)
# Route outbound requests through a local proxy, but only when the
# environment has not already configured one.
os.environ.setdefault("HTTP_PROXY", "http://127.0.0.1:7890")
os.environ.setdefault("HTTPS_PROXY", "http://127.0.0.1:7890")


@app.route('/')
def index():
    """Serve the chat front-end page."""
    return render_template('index.html')
15 |
16 |
# Submit user text and get the masked (desensitized) text back.
@app.route("/get_mask", methods=['POST', 'GET'])
def get_mask():
    """Mask the sensitive parts of the submitted text.

    Form fields read: ``user_input`` (text), ``sen_level`` (integer masking
    level) and ``ask_tag`` (mode tag forwarded to ``fun_1``).

    Returns:
        The masked text, ``"Goodbye!"`` when the input is ``"bye"``, or a
        400 response when ``sen_level`` is missing or not an integer.
    """
    user_input = request.form['user_input']
    # NOTE(review): this writes the raw, unmasked user text to stdout.
    print("脱敏前", user_input)
    if user_input == "bye":
        return "Goodbye!"
    # ``type=int`` yields None instead of raising when the value is absent or
    # not a number, so a bad request becomes a 400 rather than a 500.
    selected_sen_level = request.form.get("sen_level", type=int)
    if selected_sen_level is None:
        return "Invalid or missing 'sen_level'", 400
    selected_tag = request.form["ask_tag"]
    masked_text = fun_1(user_input, selected_sen_level, selected_tag)
    print("脱敏后", masked_text)
    return masked_text
# Forward the masked text to the selected cloud model and return its reply.
@app.route("/get_response", methods=['POST','GET'])
def get_response():
    """Send the masked text to the cloud model chosen by the client.

    Form fields read: ``selected_model`` (``model1``/``model2``/``model3``)
    and ``mask_info`` (the masked text).

    Returns:
        The model reply as text, or a 400 response for an unknown model name.
    """
    selected_cloud_model = request.form['selected_model']
    masked_text = request.form['mask_info']
    cloud_models = {
        'model1': cloud_model1,
        'model2': cloud_model2,
        'model3': cloud_model3,
    }
    handler = cloud_models.get(selected_cloud_model)
    if handler is None:
        # Previously an unknown name left ``response`` unbound and crashed.
        return "Unknown model selection", 400
    response = handler(masked_text)
    # cloud_model2 may hand back an exception object; Flask needs text.
    return str(response)
if __name__ == '__main__':
    # Debug mode enables the interactive debugger and must not be exposed on
    # a reachable host. It stays on by default (as before) but can now be
    # switched off with FLASK_DEBUG=0.
    app.run(debug=os.environ.get("FLASK_DEBUG", "1") != "0")
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import BertModel
4 |
# Plain BERT classifier
class BertTextClassifier(nn.Module):
    """BERT encoder followed by a linear layer, returning class probabilities."""

    def __init__(self, bert_config, num_labels):
        """Build the encoder from ``bert_config`` and a ``num_labels``-way head."""
        super().__init__()
        # Encoder built from the configuration object.
        self.bert = BertModel(config=bert_config)
        # Linear head mapping the encoder's hidden size to the label count.
        self.classifier = nn.Linear(bert_config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return softmax probabilities over the labels for each input row."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 1 of the encoder output is the pooled [CLS] representation.
        scores = self.classifier(encoded[1])
        return torch.softmax(scores, dim=1)
23 |
24 |
# BERT encoder + bidirectional LSTM
class BertLstmClassifier(nn.Module):
    """BERT encoder, two-layer BiLSTM and linear head, returning probabilities."""

    def __init__(self, bert_config, num_labels):
        """Build encoder, BiLSTM and classification head from ``bert_config``."""
        super().__init__()
        width = bert_config.hidden_size
        self.bert = BertModel(config=bert_config)
        self.lstm = nn.LSTM(
            input_size=width,
            hidden_size=width,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Both LSTM directions are concatenated, hence twice the hidden size.
        self.classifier = nn.Linear(width * 2, num_labels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return softmax probabilities over the labels for each input row."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state
        lstm_out, _ = self.lstm(hidden_states)
        # Classify from the LSTM output at the final sequence position.
        final_step = lstm_out[:, -1, :]
        return self.softmax(self.classifier(final_step))
--------------------------------------------------------------------------------
/cosSIM.py:
--------------------------------------------------------------------------------
1 | import jieba
2 | import math
3 |
def tokenize(text):
    """Segment ``text`` with jieba in full mode and drop empty tokens."""
    segments = jieba.cut(text, cut_all=True)
    return list(filter(lambda segment: segment != '', segments))
def build_word_dict(tokenized_texts):
    """Map every distinct token to a unique integer index.

    Indices are assigned in order of first appearance, so the mapping is
    reproducible from run to run (iterating a ``set`` of strings is not).

    Args:
        tokenized_texts: Iterable of token sequences.

    Returns:
        Dict from token to index, with indices ``0 .. n-1``.
    """
    word_dict = {}
    for tokens in tokenized_texts:
        for word in tokens:
            # setdefault leaves the index of an already-seen word untouched.
            word_dict.setdefault(word, len(word_dict))
    return word_dict
def text_to_code(text, word_dict):
    """Return the bag-of-words count vector of ``text`` over ``word_dict``.

    ``text`` is a sequence of tokens; position ``word_dict[token]`` of the
    result holds how often that token occurs. A token missing from
    ``word_dict`` raises ``KeyError``.
    """
    counts = [0] * len(word_dict)
    for token in text:
        slot = word_dict[token]
        counts[slot] = counts[slot] + 1
    return counts
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors, rounded to two decimals.

    Yields ``0.0`` when the division by the product of the norms raises
    ``ZeroDivisionError`` (i.e. one of the vectors is all zeros).
    """
    dot = sum(a * b for a, b in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(a * a for a in vec1))
    magnitude2 = math.sqrt(sum(b * b for b in vec2))
    try:
        score = dot / (magnitude1 * magnitude2)
    except ZeroDivisionError:
        return 0.0
    return round(score, 2)
31 |
def similarity(sentence1, sentence2):
    """Print and return the cosine similarity between two sentences.

    Both sentences are tokenized, encoded as count vectors over their joint
    vocabulary, and compared with ``cosine_similarity``.
    """
    token_lists = [tokenize(sentence1), tokenize(sentence2)]
    vocabulary = build_word_dict(token_lists)
    # Encode both sentences over the shared vocabulary.
    first_code, second_code = (text_to_code(tokens, vocabulary) for tokens in token_lists)
    score = cosine_similarity(first_code, second_code)
    print("与问题的余弦相似度:", score)
    return score
def main():
    """Demo: compare one question against two candidate answers."""
    question = "what is Adams's phone number"
    candidates = (
        "Adams's phone number is 15925526729",
        "Mitchell's phone number is 18764284516",
    )
    for candidate in candidates:
        similarity(question, candidate)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/static_files/images/user-avatar.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Judgment_sensitivity/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import BertModel, BertConfig
3 | from transformers import BertForSequenceClassification
# Plain BERT classifier
class BertTextClassifier(nn.Module):
    """Pretrained BERT encoder followed by a linear head; returns raw logits."""

    def __init__(self, bert_model_name, num_labels):
        """
        Args:
            bert_model_name (str): Name or path of the BERT model to load.
            num_labels (int): Number of target classes.
        """
        super().__init__()
        # Load the configuration first, then the encoder with that config.
        config = BertConfig.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name, config=config)
        # Linear head mapping the encoder's hidden size to the label count.
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the unnormalized class scores (logits) for each input row."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 1 of the encoder output feeds the classification head.
        return self.classifier(encoded[1])
# BERT encoder + bidirectional LSTM
class BertLstmClassifier(nn.Module):
    """Pretrained BERT encoder, two-layer BiLSTM and linear head; returns logits.

    The previous version read ``hidden_size`` from the ``BertConfig`` class
    instead of a loaded config, fed the LSTM the classification logits of
    ``BertForSequenceClassification`` rather than token features, and
    returned a vector of twice the hidden size instead of ``num_labels``
    scores. The layer layout now mirrors the top-level ``model.py``
    (``bert`` / ``lstm`` / ``classifier``).
    """

    def __init__(self, bert_model_name, num_labels):
        """
        Args:
            bert_model_name (str): Name or path of the BERT model to load.
            num_labels (int): Number of target classes.
        """
        super().__init__()
        bert_config = BertConfig.from_pretrained(bert_model_name)
        # Plain encoder: the LSTM needs per-token hidden states, not logits.
        self.bert = BertModel.from_pretrained(bert_model_name, config=bert_config)
        self.lstm = nn.LSTM(input_size=bert_config.hidden_size, hidden_size=bert_config.hidden_size, num_layers=2,
                            batch_first=True, bidirectional=True)
        # Both LSTM directions are concatenated, hence twice the hidden size.
        self.classifier = nn.Linear(bert_config.hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the unnormalized class scores (logits) for each input row."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Run the BiLSTM over the per-token hidden states.
        lstm_output, _ = self.lstm(outputs.last_hidden_state)
        # Classify from the LSTM output at the final sequence position.
        return self.classifier(lstm_output[:, -1, :])
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FirewaLLM
2 | By calling FirewaLLM, users can ensure the accuracy of the large model while greatly reducing the risk of privacy leakage when interacting with it.
3 |
4 | With the development of large models, more and more users are using them for interactive Q&A. However, as the volume of data grows, privacy issues have attracted widespread attention. We therefore built a small privacy-protection model. First, it judges whether the statements in a user's question are sensitive; second, it processes the sensitive statements accordingly, so that the exposure of user privacy is greatly reduced when interacting with the large model. Finally, it preserves the accuracy of the large-model interaction by restoring the sensitive information. In our experiments, users interact with large models by calling this small model.
5 |
6 | The FirewaLLM framework not only protects user data privacy, but also protects the accuracy of the interaction model.
7 |
8 |
9 | # Usage
10 |
11 | 1. Train FirewaLLM so that it can recognize sensitive information.
12 | 2. Open the FirewaLLM interface for interaction, and users can choose different large models for interaction.
13 | ```python
14 | python FirewaLLM/app.py
15 | ```
16 | 2. Open the FirewaLLM interface for interaction, and users can choose different large models for interaction.
17 | 3. The effect is as follows:
18 | First, while the user interacts with a large model, FirewaLLM uses a file filtering function to process sensitive information.
19 | 
20 | Second, FirewaLLM exchanges data with the user-specified large model. Finally, the answer is processed to recover the sensitive information, and the answer with the highest similarity to the question is returned.
21 | 
22 | 4. It can be intuitively seen that when a user inputs sensitive information, FirewaLLM will perform different levels of sensitive information processing operations on the sensitive information.
23 | 5. In the end, FirewaLLM will perform sensitivity restoration processing, returning the results with the highest similarity between the answers of the large model and the problem, in order to protect privacy while ensuring the accuracy of the overall process.
24 |
--------------------------------------------------------------------------------
/cloud2.py:
--------------------------------------------------------------------------------
1 | #cloud2 TencentModel
2 | import json
3 | from tencentcloud.common import credential
4 | from tencentcloud.common.profile.client_profile import ClientProfile
5 | from tencentcloud.common.profile.http_profile import HttpProfile
6 | from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
7 | from tencentcloud.tbp.v20190627 import tbp_client, models
8 |
def cloud_model2(user_input):
    """Send ``user_input`` to the Tencent TBP bot and return its reply text.

    Args:
        user_input: The text forwarded as ``InputText``.

    Returns:
        The ``ResponseText`` field of the service response, or the error
        message as a string when the SDK raises ``TencentCloudSDKException``.
    """
    import os

    try:
        # Credentials are read from the environment so that real keys never
        # have to be written into the source; the original placeholders stay
        # as fallbacks so existing setups behave as before.
        SecretId = os.environ.get("TENCENTCLOUD_SECRET_ID", "EnterYourSecretId")
        SecretKey = os.environ.get("TENCENTCLOUD_SECRET_KEY", "EnterYourSecretKey")
        cred = credential.Credential(SecretId, SecretKey)
        # HTTP options: the service endpoint.
        httpProfile = HttpProfile()
        httpProfile.endpoint = "tbp.ap-guangzhou.tencentcloudapi.com"

        # Client options wrapping the HTTP profile.
        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile
        # Client for the TBP product (an empty region string is passed).
        client = tbp_client.TbpClient(cred, "", clientProfile)

        # Each interface has its own request object.
        req = models.TextProcessRequest()
        params = {
            "BotId": "58d826e3-73af-4ddb-887d-350475e88a9a",
            "BotEnv": "release",
            "TerminalId": "user1",
            "SessionAttributes": "True",
            "InputText": user_input,
            "PlatformType": "True",
            "PlatformId": "True"
        }
        req.from_json_string(json.dumps(params))

        # ``resp`` is the TextProcessResponse matching the request above.
        resp = client.TextProcess(req)
        # Extract the reply text from the JSON form of the response.
        response_text = json.loads(resp.to_json_string())["ResponseText"]
        return response_text

    except TencentCloudSDKException as err:
        # Return text rather than the exception object so callers (e.g. the
        # Flask view) always receive a string.
        return str(err)
--------------------------------------------------------------------------------
/feature.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import jieba
3 | from sklearn.feature_extraction.text import TfidfVectorizer
4 | import pandas as pd
5 |
# Load the corpus files and segment them with jieba
def importdata(tag, flag):
    """Return the corpus as a list of space-joined, jieba-segmented documents.

    When ``tag == "true"`` only ``data/Info/info2.txt`` is read, segmented in
    search mode if ``flag`` is truthy and with ``lcut`` otherwise. For any
    other tag, ``info3.txt`` .. ``info13.txt`` are each read and segmented
    with ``lcut``.
    """
    if tag == "true":
        with open("data/Info/info2.txt", encoding="UTF-8") as corpus_file:
            corpus = corpus_file.read()
        segment = jieba.cut_for_search if flag else jieba.lcut
        return [" ".join(segment(corpus))]

    documents = []
    for index in range(3, 14):
        with open(f"data/Info/info{index}.txt", encoding="UTF-8") as corpus_file:
            corpus = corpus_file.read()
        documents.append(" ".join(jieba.lcut(corpus)))
    return documents
26 |
# Load the stop-word lists used to drop uninformative tokens
def stop_word():
    """Return the concatenated lines of ``data/StopWord/StopWord1-4.txt``."""
    collected = []
    for index in range(1, 5):
        with open(f"data/StopWord/StopWord{index}.txt", encoding="UTF-8") as handle:
            collected += handle.read().split("\n")
    return collected
36 |
# Segment the input sentence with jieba when requested
def process_sentence(flag, sentence):
    """Return a one-element list holding the sentence prepared for TF-IDF.

    With a truthy ``flag`` the sentence is segmented in jieba search mode and
    the tokens are joined by spaces; otherwise it is passed through as is.
    """
    if not flag:
        return [sentence]
    return [" ".join(jieba.lcut_for_search(sentence))]
45 |
# Use TF-IDF to obtain the corpus features
def vectorize_data(data, sentence_jieba):
    """Fit TF-IDF on the corpus and vectorize the input sentence(s).

    Args:
        data: List of space-separated corpus documents.
        sentence_jieba: List of space-separated sentences to transform.

    Returns:
        DataFrame with one row per sentence and one column per vocabulary
        term, holding the TF-IDF weights.
    """
    vectorizer = TfidfVectorizer(stop_words=stop_word())
    # Only the fitted vocabulary/IDF weights are needed; the dense corpus
    # matrix that used to be built here was never read.
    vectorizer.fit(data)
    feature_names = vectorizer.get_feature_names_out()
    weights = vectorizer.transform(sentence_jieba).toarray()
    return pd.DataFrame(weights, columns=feature_names)
# Compare the input sentence with the features of the corpus
def getFeature(tag, sentence, flag=True):
    """Return the corpus terms that weigh at least 0.20 in ``sentence``.

    The sentence is vectorized against the corpus selected by ``tag``. Terms
    are visited in descending TF-IDF weight (at most the top 50) and
    collected until the first weight below 0.20. If nothing is found with
    segmentation enabled, the lookup is retried once with ``flag=False``.

    Args:
        tag: Corpus selector forwarded to ``importdata``.
        sentence: The sentence to analyze.
        flag: Whether to segment the sentence (and the ``"true"`` corpus) in
            jieba search mode.

    Returns:
        List of matching terms, highest weight first.
    """
    data = importdata(tag, flag)
    st = process_sentence(flag, sentence)
    X_pd = vectorize_data(data, st)
    # Order the columns (terms) by their weight in the single sentence row.
    X_pd_sort = X_pd.sort_values(by=0, axis=1, ascending=False)
    word_list = []
    # Slicing tolerates vocabularies with fewer than 50 terms, where the
    # fixed ``range(0, 50)`` indexing used to raise IndexError.
    for word, weight in X_pd_sort.iloc[0, :50].items():
        if weight < 0.20:
            break
        if word not in word_list:
            word_list.append(word)
    if not word_list and flag:
        word_list = getFeature(tag, sentence, False)
    return word_list
73 |
--------------------------------------------------------------------------------
/templates_files/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | FirewaLLM
7 |
8 |
9 |
10 |
11 |
12 |
13 |

14 |
15 |
16 |
17 |
18 |
19 |
20 | 青少年模式:
21 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
38 |
39 |
40 |
41 |
42 |
43 |
60 |
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/static_files/images/bot-avatar.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mark.py:
--------------------------------------------------------------------------------
1 | # mark.py
2 |
3 | import re
4 | import jieba
5 | import torch
6 | from model import BertTextClassifier, BertLstmClassifier
7 | from transformers import BertTokenizer, BertConfig, AutoTokenizer
8 | import random
9 | from feature import getFeature, stop_word
10 |
class MaskHandler:
    """Holds the local sensitivity classifier and masks sensitive words."""

    def __init__(self, model_path):
        """Load the classifier weights from ``model_path`` and the tokenizers."""
        # Local BERT-based classifier; "Y" = sensitive, "N" = insensitive.
        self.labels = ["Y", "N"]
        self.bert_config = BertConfig.from_pretrained('bert-base-chinese')
        self.model = BertLstmClassifier(self.bert_config, len(self.labels))
        # torch.load unpickles the file: only load checkpoints from a trusted
        # source. strict=False tolerates checkpoint keys that do not match
        # the model.
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
        self.model.eval()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

        # Tokenizer of the RoBERTa model.
        self.cloud_tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")

    def classic(self, query):
        """Return the words of ``query`` that the classifier labels sensitive.

        The query is segmented with jieba and each word is classified on its
        own. Words predicted as "Y" are collected once each, in order. (The
        prediction used to be computed and then discarded, so the result was
        always empty.)
        """
        sensitive_words = []
        for tmp_text in jieba.lcut(query):
            token = self.tokenizer(tmp_text, add_special_tokens=True, padding='max_length', truncation=True,
                                   max_length=512)
            input_ids = torch.tensor([token['input_ids']], dtype=torch.long)
            attention_mask = torch.tensor([token['attention_mask']], dtype=torch.long)
            token_type_ids = torch.tensor([token['token_type_ids']], dtype=torch.long)
            # Inference only: do not record gradients.
            with torch.no_grad():
                predicted = self.model(input_ids, attention_mask, token_type_ids)
            pred_label = torch.argmax(predicted, dim=1).item()
            if self.labels[pred_label] == "Y" and tmp_text not in sensitive_words:
                sensitive_words.append(tmp_text)

        return sensitive_words

    # Desensitization algorithm
    def mask_sensitive_info(self, text, sensitive, level, tag):
        """Replace characters of each sensitive word in ``text`` with ``*``.

        Args:
            text: The text to mask.
            sensitive: Iterable of sensitive words to look for.
            level: Masking level; only applied to words of 8+ characters.
            tag: ``"true"`` masks every character of a word; any other value
                masks ``int(len(word) / 10 + level + 1)`` randomly chosen
                positions (capped at the word length).

        Returns:
            The text with every case-insensitive occurrence of each word
            replaced by its masked form.
        """
        for word in sensitive:
            len_word = len(word)
            if len_word == 0:
                # An empty word has nothing to mask and would match everywhere.
                continue
            tmp_level = 0 if len_word < 8 else level
            if tag == "true":
                length = len_word
            else:
                length = int(len_word / 10 + tmp_level + 1)
            # random.sample cannot draw more positions than the word has.
            length = max(0, min(length, len_word))
            positions = set(random.sample(range(len_word), length))
            masked_sensitive = "".join(
                '*' if index in positions else char for index, char in enumerate(word)
            )
            # Match the word literally (it may contain regex metacharacters)
            # and insert the replacement verbatim (it may contain backslashes).
            text = re.sub(re.escape(word), lambda _match: masked_sensitive, text, flags=re.IGNORECASE)
        return text
61 |
# Split sentences at punctuation marks
# One character class instead of a long alternation. The old pattern ended in
# ``(|),·``: the bare parentheses form a capturing group (so re.split also
# returns the group's value, None, between the pieces) and the last three
# marks were fused into a single alternative.
_SPLIT_PATTERN = re.compile(
    r'[,./;\'`\[\]<>?:"{}~!@#$%^&()\-=_+ ,。、;‘’【】:!\u3000…()·]'
)


def fun_splite(text):
    """Split ``text`` at ASCII and full-width punctuation marks and spaces.

    Leading and trailing whitespace is stripped first. Adjacent delimiters
    produce empty strings in the result.
    """
    return _SPLIT_PATTERN.split(text.strip())
67 |
# Remove stop words from a sentence
def fun_splitein(text):
    """Return ``text`` with stop words and tab tokens removed.

    The stripped text is segmented with jieba; every token that is neither in
    the stop-word list nor a tab is kept, and the kept tokens are
    concatenated without separators.
    """
    tokens = jieba.lcut(text.strip())
    stopwords = stop_word()
    kept = [token for token in tokens if token not in stopwords and token != "\t"]
    return "".join(kept)
79 |
# Use the model to decide whether a sentence is sensitive
def fun_isSen(maskHandler, text, tag):
    """Return True when ``text`` is judged sensitive under ``tag == "false"``.

    Args:
        maskHandler: Object exposing ``tokenizer`` and ``model``.
        text: The sentence to classify.
        tag: Mode tag; only ``"false"`` can produce True.

    Returns:
        True if the tag is ``"false"`` and the score of class 0 (sensitive)
        exceeds 0.5, otherwise False.
    """
    # NOTE(review): for any other tag this always returns False, which makes
    # the non-"false" branch guarded by this result in fun_1 unreachable --
    # confirm the intended behavior for tag "true".
    if tag != "false":
        # The answer is already known, so skip the model call.
        return False
    token = maskHandler.tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
    input_ids = torch.tensor([token['input_ids']], dtype=torch.long)
    attention_mask = torch.tensor([token['attention_mask']], dtype=torch.long)
    token_type_ids = torch.tensor([token['token_type_ids']], dtype=torch.long)
    # Inference only: do not record gradients.
    with torch.no_grad():
        predicted = maskHandler.model(input_ids, attention_mask, token_type_ids)
        # NOTE(review): the classifiers in model.py already return softmax
        # probabilities; for two classes this second softmax keeps the 0.5
        # decision boundary but compresses the printed values.
        output = torch.softmax(predicted, dim=1)
    print(output)
    # Column 0 is the sensitive class, column 1 the insensitive one.
    return output[:, 0].item() > 0.5
95 |
# Invert the insensitive words returned by TF-IDF to get the sensitive ones
def getSen(nosen, text):
    """Return the distinct multi-character words of ``text`` not in ``nosen``.

    ``text`` is segmented with jieba; words of length one, words listed in
    ``nosen`` and repeats are skipped. Order of first appearance is kept.
    """
    sen = []
    for word in jieba.lcut(text):
        if word in nosen or len(word) <= 1 or word in sen:
            continue
        sen.append(word)
    return sen
# For multi-sentence input, judge each sentence and mask the sensitive ones
_MASK_HANDLER = None


def _get_mask_handler():
    """Return the shared MaskHandler, loading the model on first use only."""
    global _MASK_HANDLER
    if _MASK_HANDLER is None:
        _MASK_HANDLER = MaskHandler("model/sen_model.pkl")  # Sensitive model
    return _MASK_HANDLER


def fun_1(text, selected_sen_level, tag):
    """Mask the sensitive words of every sensitive sentence in ``text``.

    The text is split at punctuation; each non-empty segment judged
    sensitive by ``fun_isSen`` has its sensitive words determined and masked
    in the full text.

    Args:
        text: The user input, possibly several sentences.
        selected_sen_level: Masking level passed to ``mask_sensitive_info``.
        tag: Mode tag (``"false"`` or ``"true"``) passed to the helpers.

    Returns:
        The text with the detected sensitive words masked (also printed).
    """
    # Reuse one handler: building it loads the model and tokenizers, which
    # used to happen again on every call.
    maskHandler = _get_mask_handler()
    tmp = text
    for tmp_text in fun_splite(text):
        if not tmp_text:
            # Adjacent delimiters yield empty segments with nothing to mask.
            continue
        sen_fea = []
        if fun_isSen(maskHandler, tmp_text, tag):
            if tag == "false":
                # Stop-word removal is only needed on this path.
                text_stop = fun_splitein(tmp_text)
                sen_fea = getSen(getFeature(tag, text_stop), text_stop)
            else:
                sen_fea = getFeature(tag, tmp_text, True)
        tmp = maskHandler.mask_sensitive_info(tmp, sen_fea, selected_sen_level, tag)
    print(tmp)
    return tmp
123 |
if __name__ == '__main__':
    # Use a descriptive name: binding ``str`` here would shadow the builtin
    # for the rest of the module.
    sample_text = "..."
    fun_1(sample_text, 1, "false")
127 |
128 |
--------------------------------------------------------------------------------
/Judgment_sensitivity/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | from transformers import BertTokenizer, AdamW, BertConfig
6 | from torch.utils.data import DataLoader
7 | from model import BertTextClassifier, BertLstmClassifier
8 | from dataset import InfoDataset
9 | from tqdm import tqdm
10 | from sklearn import metrics
11 | import pandas as pd
12 | import numpy as np
13 |
def main():
    """Train the BERT+LSTM classifier and keep the best-scoring weights.

    Loads the training set from ``data/info/train.txt`` and the validation
    set from ``data/info/test.txt``, trains for a fixed number of epochs and,
    after every epoch, prints the training accuracy/loss, the validation loss
    and a classification report.  Whenever the validation micro-F1 beats the
    best value seen so far, the model's ``state_dict`` is written to
    ``models/sen_model.pkl``.
    """
    # Parameter settings
    batch_size = 64
    epochs = 50
    learning_rate = 1e-4

    # Nothing else in this script provides `device`, so pick it here:
    # use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Obtain dataset
    train_dataset = InfoDataset("data/info/train.txt")
    valid_dataset = InfoDataset("data/info/test.txt")

    # Generate batches.  One training loader is enough for all epochs: with
    # shuffle=True it draws a fresh order every time it is iterated.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # Read BERT's configuration file
    bert_config = BertConfig.from_pretrained('bert-base-chinese')
    num_labels = len(train_dataset.labels)

    # Model, optimizer and loss function
    model = BertLstmClassifier(bert_config, num_labels).to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    best_f1 = 0

    # Checkpoint directory; exist_ok avoids the check-then-create race.
    os.makedirs('models', exist_ok=True)

    for epoch in range(1, epochs + 1):
        total_loss = 0     # Sum of per-batch losses for the epoch
        total_correct = 0  # Correct predictions for the epoch
        total_samples = 0  # Samples processed in the epoch

        # ---- training ----
        model.train()
        train_bar = tqdm(train_dataloader, ncols=100, desc='Epoch {} train'.format(epoch))

        for input_ids, token_type_ids, attention_mask, label_id in train_bar:
            # Move the batch to the selected device
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            attention_mask = attention_mask.to(device)
            label_id = label_id.to(device)

            optimizer.zero_grad()

            output = model(input_ids, token_type_ids, attention_mask)

            loss = criterion(output, label_id)
            batch_loss = loss.item()
            total_loss += batch_loss

            loss.backward()
            optimizer.step()

            # Accuracy bookkeeping; input_ids.size(0) is the batch size
            predicted_labels = torch.argmax(output, dim=1)
            total_correct += torch.sum(predicted_labels == label_id).item()
            total_samples += input_ids.size(0)

            train_bar.set_postfix(loss=batch_loss)

        average_loss = total_loss / len(train_dataloader)
        average_accuracy = total_correct / total_samples

        print('Train ACC: {:.4f}\tLoss: {:.4f}'.format(average_accuracy, average_loss))

        # ---- validation ----
        model.eval()
        total_loss = 0
        pred_labels = []
        true_labels = []

        with torch.no_grad():
            valid_bar = tqdm(valid_dataloader, ncols=100, desc='Epoch {} valid'.format(epoch))
            for input_ids, token_type_ids, attention_mask, label_id in valid_bar:
                input_ids = input_ids.to(device)
                token_type_ids = token_type_ids.to(device)
                attention_mask = attention_mask.to(device)
                label_id = label_id.to(device)

                output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                               attention_mask=attention_mask)

                loss = criterion(output, label_id)
                batch_loss = loss.item()
                total_loss += batch_loss

                pred_label = torch.argmax(output, dim=1)
                acc = torch.sum(pred_label == label_id).item() / len(pred_label)
                valid_bar.set_postfix(loss=batch_loss, acc=acc)

                pred_labels.extend(pred_label.cpu().numpy().tolist())
                true_labels.extend(label_id.cpu().numpy().tolist())

        average_loss = total_loss / len(valid_dataloader)
        print('Validation Loss: {:.4f}'.format(average_loss))

        # Classification report
        report = metrics.classification_report(true_labels, pred_labels,
                                               labels=valid_dataset.labels_id,
                                               target_names=valid_dataset.labels)
        print('* Classification Report:')
        print(report)

        f1 = metrics.f1_score(true_labels, pred_labels,
                              labels=valid_dataset.labels_id, average='micro')

        # Keep only the weights of the best epoch so far
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'models/sen_model.pkl')
123 | #Check if the dataset meets the specifications
def test_1():
    """Check that every line of the training file has the ``text_label`` layout.

    Reads ``data/info/train.txt`` and, for each line, prints the fields
    obtained by splitting the stripped line on ``'_'`` and then the label
    (the second field).

    Raises:
        ValueError: if a line does not split into exactly two fields; the
            message names the file and the 1-based line number.
    """
    path = "data/info/train.txt"
    with open(path, encoding="utf8") as f:
        lines = f.readlines()
    for line_no, line in enumerate(tqdm(lines, ncols=100), start=1):
        fields = line.strip().split('_')
        print(fields)
        if len(fields) != 2:
            # Report where the bad record is instead of failing with a bare
            # tuple-unpacking error.
            raise ValueError(
                '{}:{}: expected "text_label", got {} field(s)'.format(
                    path, line_no, len(fields)))
        label = fields[1]
        print(label)
# Script entry point: run the training loop.
if __name__ == '__main__':
    # test_1()  # optional: check the dataset line format before training
    main()
--------------------------------------------------------------------------------
/static_files/index.css:
--------------------------------------------------------------------------------
/* Page reset: drop the default spacing and set the base font. */
body {
    padding: 0;
    margin: 0;
    font-family: Arial, sans-serif;
}

/* Absolutely positioned flex box that covers the whole viewport. */
.container {
    display: flex;
    position: absolute;
    width: 100vw;
    height: 100vh;
    top: 0;
    left: 0;
    overflow: hidden;
}

/* Chat history management. */
/* 250px wide, full-height dark column; width changes are animated. */
.chat-history {
    width: 250px;
    height: 100vh;
    background-color: #444;
    left: 0;
    position: relative;
    transition: width 0.3s;
}

/* Main chat area: a growing flex item, sized to leave room for a 250px column. */
.chat-panel {
    flex: 1;
    width: calc(100% - 250px);
    height: 100vh;
    background-color: #dae7e6;
    box-sizing: border-box;
    overflow: hidden;
}

/* Thin, faint horizontal divider with 20px side margins. */
.separator {
    height: 1px;
    background-color: rgba(255, 255, 255, 0.08);
    margin: 0 20px;
}


/* 45px-high row whose content is centred on both axes. */
.chat-controls {
    display: flex;
    width: 100%;
    height: 45px;
    padding: 5px;
    box-sizing: border-box;
    margin-top: 8px;
    margin-bottom: 8px;
    justify-content: center;
    align-items: center;
}

/* Fixed 550px-high area that scrolls when its content overflows. */
/* NOTE(review): no display: flex is set here, so justify-content and
   align-items only take effect if another rule makes this a flex/grid
   container - confirm. */
.function-module {
    width: 100%;
    height: 550px;
    padding: 5px;
    box-sizing: border-box;
    margin-top: 8px;
    justify-content: center;
    align-items: center;
    overflow: auto;
}

/* Takes the height left after a 45px row and a 550px area (the heights
   used by .chat-controls and .function-module above). */
.settings {
    display: flex;
    width: 100%;
    padding: 5px;
    height: calc(100% - 45px - 550px);
    box-sizing: border-box;
    margin-top: 8px;
}

/* 45px-high row with centred content. */
.mode {
    display: flex;
    width: 100%;
    height: 45px;
    margin-top: 5px;
    justify-content: center;
    align-items: center;
}

/* White text label with 5px padding. */
.mode-label {
    position: relative;
    padding: 5px;
    box-sizing: border-box;
    color: white;
}
90 |
/* Toggle switch: an 80x34 box that positions the track (.slider). */
.switch {
    position: relative;
    display: inline-block;
    width: 80px;
    height: 34px;
}

/* Hide the native checkbox visually but keep it in the focus order.
   With display: none the input could never receive focus, so the toggle
   was unreachable by keyboard and the input:focus rule below never matched. */
.switch input {
    opacity: 0;
    width: 0;
    height: 0;
}

/* Track: fills the whole .switch box, pill-shaped, animated on change. */
.slider {
    position: absolute;
    cursor: pointer;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    background-color: #ccc;
    -webkit-transition: .4s;
    transition: .4s;
    border-radius: 34px;
}

/* Knob: a 26px circle inset 4px from the left and bottom of the track. */
.slider:before {
    position: absolute;
    content: "";
    height: 26px;
    width: 26px;
    left: 4px;
    bottom: 4px;
    background-color: white;
    -webkit-transition: .4s;
    transition: .4s;
    border-radius: 50%;
}

/* Checked state: blue track. */
input:checked + .slider {
    background-color: #2196F3;
}

/* Keyboard focus indicator on the track. */
input:focus + .slider {
    box-shadow: 0 0 1px #2196F3;
}

/* Checked state: move the knob to the right end
   (80px track - 26px knob - 2 x 4px inset = 46px). */
input:checked + .slider:before {
    transform: translateX(46px);
}
139 |
/* 40x40 clickable button with centred content. */
.collapse-chat-btn {
    display: flex;
    width: 40px;
    height: 40px;
    margin: 8px 8px;
    padding: 10px;
    box-sizing: border-box;
    border-radius: 5px;
    justify-content: center;
    align-items: center;
    cursor: pointer;
}

/* State class toggled on the chat panel by the fold button. */
.chat-panel.expanded {
    margin-left: 0;
    width: calc(100vw - 250px);
}

/* State class toggled on the history column: shrink it to nothing. */
.chat-history.collapsed {
    width: 0;
    overflow: hidden;
}

/* NOTE(review): this only matches a .chat-panel inside a .container that
   directly follows a collapsed .chat-history - confirm against the markup
   that such a sibling .container exists. */
.chat-history.collapsed + .container .chat-panel {
    width: 100%;
}

/* Single definition of .control-panel.  An earlier rule for the same
   selector (display: inline-block; width: 100%; height: 40px) had every
   declaration overridden by this one, so it was removed as dead code. */
.control-panel {
    display: flex;
    width: 100%;
    height: 45px;
    margin: 8px 8px;
    justify-content: center;
    align-items: center;
}
181 |
/* 40px-high centred row, 40px narrower than its parent. */
.extras {
    display: flex;
    width: calc(100% - 40px);
    height: 40px;
    margin: 8px 8px;
    padding: 10px;
    box-sizing: border-box;
    justify-content: center;
    align-items: center;
}

/* Borderless, transparent 40x40 box with centred content. */
.extras-one {
    display: flex;
    width: 40px;
    height: 40px;
    border: none;
    justify-content: center;
    align-items: center;
    background-color: transparent;
}

/* Message list: fixed 550px high, always shows a vertical scrollbar. */
#chat-display {
    width: 100%;
    height: 550px;
    padding: 5px;
    box-sizing: border-box;
    margin-top: 8px;
    border-radius: 10px;
    overflow-y: scroll;
}

/* Wrapper that takes the height left after a 45px row and the 550px
   message list, and centres the input box inside it. */
.user-input-outside {
    display: flex;
    width: 100%;
    height: calc(100% - 45px - 550px);
    box-sizing: border-box;
    padding: 10px;
    margin-bottom: 10px;
    justify-content: center;
    align-items: center;
}

/* White, rounded 700x70 box; children are spread along the row. */
.user-input-area {
    display: flex;
    width: 700px;
    height: 70px;
    background-color: #fff;
    padding: 5px;
    border-radius: 10px;
    box-sizing: border-box;
    justify-content: space-between;
    align-items: center;
}

/* Text input: grows to fill the row, no border, not resizable. */
#user-input {
    flex: 1;
    height: calc(100% - 10px);
    resize: none;
    border: none;
    margin: 0px 5px;
    padding: 5px;
    box-sizing: border-box;
    font-size: 16px;
}

/* Blue send button with white text. */
#send-button {
    background-color: blue;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 8px 10px;
    box-sizing: border-box;
    cursor: pointer;
    margin-right: 10px;
}
257 |
/* Shared look of the two drop-downs.  Each rule used to declare
   border-radius twice (6px, then 5px); only the last one applied, so the
   effective 5px value is kept and the dead 6px declaration is dropped. */
#model-select,
#sensitivity-level {
    color: #fff;
    background-color: #34495e;
    border-radius: 5px;
    padding: 8px 10px;
    box-sizing: border-box;
    margin-right: 5px;
}

/* The sensitivity selector additionally keeps a gap on its left side. */
#sensitivity-level {
    margin-left: 5px;
}
278 |
/* Full-width rows that centre one message; only the background differs. */
.user-message-outside,
.chatbot-message-outside,
.model-message-outside {
    display: flex;
    width: 100%;
    min-height: 45px;
    justify-content: center;
    align-items: center;
}

.user-message-outside {
    background-color: lightblue;
}

.chatbot-message-outside {
    background-color: lightskyblue;
}

/* The message itself takes 70% of the row and lays its parts out in a line. */
.user-message,
.chatbot-message,
.model-message {
    display: flex;
    width: 70%;
}

/* Message text: fills the remaining width, wraps long words, keeps line breaks. */
.message-content,
.model-message-content {
    flex: 1;
    word-wrap: break-word;
    white-space: pre-line;
    padding: 10px;
    border-radius: 10px;
    align-self: flex-start;
    font-size: 16px;
}

/* Model text is rendered in a faded grey. */
.model-message-content {
    color: rgba(128, 128, 128, 0.5);
}

/* Fixed 45x45 avatar slot pinned to the top of the message. */
.avatar {
    display: flex;
    flex: 0 0 auto;
    width: 45px;
    height: 45px;
    overflow: hidden;
    margin: 10px 10px auto 10px;
}

/* Round image occupying 80% of the avatar slot. */
.avatar-img {
    width: 80%;
    height: 80%;
    border-radius: 50%;
}

/* Borderless text buttons shown with each message. */
.copy-button,
.hide-model-button {
    float: right;
    background-color: transparent;
    border: none;
    margin-right: 10px;
    cursor: pointer;
}

.hide-model-button {
    color: gray;
}
--------------------------------------------------------------------------------
/static_files/index.js:
--------------------------------------------------------------------------------
1 | // Fold button click event listener
// Fold button: flip the "collapsed" state of the history column and the
// matching "expanded" state of the chat panel on every click.
document.querySelector('.collapse-chat-btn').addEventListener('click', function () {
    const historyPane = document.querySelector('.chat-history');
    const chatPane = document.querySelector('.chat-panel');

    historyPane.classList.toggle('collapsed');
    chatPane.classList.toggle('expanded');
});
6 |
7 |
8 | // Teenage mode
// Teenage mode
// Once the DOM is ready, wire the #toggleButton checkbox so that every
// change shows or hides #sensitivity-level and replaces the contents of
// #model-select.
// NOTE(review): the markup inside the two template literals below is not
// visible in this view (presumably the option list for each mode) - confirm
// against the original file before editing them.
document.addEventListener('DOMContentLoaded', function() {
    const toggleButton = document.getElementById('toggleButton');
    const hiddenSL = document.getElementById('sensitivity-level');
    const modelSelect = document.getElementById('model-select');

    toggleButton.addEventListener('change', function() {
        if (toggleButton.checked) {
            // Mode switched on: hide the sensitivity-level control.
            hiddenSL.style.display = 'none';

            modelSelect.innerHTML = `



            `;
        }

        else {
            // Mode switched off: show the sensitivity-level control again.
            hiddenSL.style.display = 'block';

            modelSelect.innerHTML = `





            `;
        }
    });
});
38 |
39 |
40 | // Add message to chat-display
/**
 * Append one chat message to #chat-display.
 *
 * When `modelMessage` is non-empty and differs from `message` (ignoring all
 * whitespace), a second row with the model text is appended right below the
 * message, together with a button that hides or shows that row.
 *
 * @param {string} userType - "user" renders the user's bubble and avatar;
 *     any other value renders the chatbot's.
 * @param {string} message - Text shown in the main bubble.
 * @param {string} [modelMessage=""] - Text shown in the secondary row.
 */
function appendMessage(userType, message, modelMessage = "") {
    const chatDisplay = document.getElementById("chat-display");
    const isUser = userType === "user";

    // Builds a copy button that copies `text` and briefly confirms the copy.
    const makeCopyButton = function (text) {
        const button = document.createElement("button");
        button.textContent = "复制";
        button.classList.add("copy-button");
        button.addEventListener("click", function () {
            copyToClipboard(text);
            button.textContent = "已复制";
            setTimeout(function () {
                button.textContent = "复制";
            }, 1500);
        });
        return button;
    };

    // Outer row (full width) and the message box inside it.
    const outsideMessageContainer = document.createElement("div");
    outsideMessageContainer.classList.add(isUser ? "user-message-outside" : "chatbot-message-outside");

    const messageContainer = document.createElement("div");
    messageContainer.classList.add(isUser ? "user-message" : "chatbot-message");

    const avatar = document.createElement("div");
    avatar.classList.add("avatar");
    const avatarImg = document.createElement("img");
    avatarImg.classList.add("avatar-img");
    avatarImg.src = isUser ? "../static/images/user-avatar.svg" : "../static/images/bot-avatar.svg";
    avatar.appendChild(avatarImg);

    const messageContent = document.createElement("div");
    messageContent.classList.add("message-content");
    // textContent, not innerHTML: the text comes from the user or from the
    // server and must never be interpreted as markup.
    messageContent.textContent = message;

    messageContainer.appendChild(avatar);
    messageContainer.appendChild(messageContent);

    const showModelMessage = modelMessage !== "" &&
        modelMessage.replace(/\s/g, '') !== message.replace(/\s/g, '');

    if (!showModelMessage) {
        messageContainer.appendChild(makeCopyButton(message));
        outsideMessageContainer.appendChild(messageContainer);
        chatDisplay.appendChild(outsideMessageContainer);
        return;
    }

    // Secondary row with the model text.  It is created before the hide
    // button so the button's handler refers to an already-defined element.
    const outsideModelMessageContainer = document.createElement("div");
    outsideModelMessageContainer.classList.add("model-message-outside");

    const modelMessageContainer = document.createElement("div");
    modelMessageContainer.classList.add("model-message");

    // Empty avatar slot: keeps the model text aligned with the text above.
    const modelAvatar = document.createElement("div");
    modelAvatar.classList.add("avatar");

    const modelMessageContent = document.createElement("div");
    modelMessageContent.classList.add("model-message-content");
    modelMessageContent.textContent = modelMessage;

    modelMessageContainer.appendChild(modelAvatar);
    modelMessageContainer.appendChild(modelMessageContent);
    modelMessageContainer.appendChild(makeCopyButton(modelMessage));
    outsideModelMessageContainer.appendChild(modelMessageContainer);

    // Hide model message button
    const hideModelButton = document.createElement("button");
    hideModelButton.textContent = "隐藏模型消息";
    hideModelButton.classList.add("hide-model-button");
    hideModelButton.addEventListener("click", function () {
        const wasHidden = outsideModelMessageContainer.style.display === "none";
        outsideModelMessageContainer.style.display = wasHidden ? "flex" : "none";
        hideModelButton.textContent = wasHidden ? "隐藏模型消息" : "展开模型消息";
    });

    messageContainer.appendChild(hideModelButton);
    messageContainer.appendChild(makeCopyButton(message));
    outsideMessageContainer.appendChild(messageContainer);

    chatDisplay.appendChild(outsideMessageContainer);
    chatDisplay.appendChild(outsideModelMessageContainer);
}
135 |
136 |
/**
 * Copy `text` to the system clipboard.
 *
 * Uses the asynchronous Clipboard API when the browser exposes it and falls
 * back to the legacy textarea + execCommand("copy") approach otherwise (or
 * when the Clipboard API rejects, e.g. in a non-secure context).
 *
 * @param {string} text - Text to place on the clipboard.
 */
function copyToClipboard(text) {
    // Legacy path: select the text inside a temporary textarea and copy it.
    const legacyCopy = function () {
        const textarea = document.createElement("textarea");
        textarea.value = text;
        document.body.appendChild(textarea);
        try {
            textarea.select();
            document.execCommand("copy");
        } finally {
            // Always remove the helper element, even if the copy throws.
            document.body.removeChild(textarea);
        }
    };

    if (navigator.clipboard && typeof navigator.clipboard.writeText === "function") {
        navigator.clipboard.writeText(text).catch(legacyCopy);
        return;
    }
    legacyCopy();
}
145 |
/**
 * Send the text in #user-input to /get_mask, show it in the chat together
 * with the returned text, and forward that text to getMessage().
 *
 * Blank input is ignored.  A non-2xx answer is treated as an error and
 * logged instead of being displayed.
 */
function sendMessage() {
    const inputBox = document.getElementById("user-input");
    const userInput = inputBox.value;

    // Nothing to send: avoid a pointless round trip and an empty bubble.
    if (userInput.trim() === "") {
        return;
    }

    const selectedLevel = document.getElementById("sensitivity-level").value;
    const selectedTag = document.getElementById("toggleButton").checked;
    inputBox.value = "";

    fetch("/get_mask", {
        method: "POST",
        body: new URLSearchParams({
            "user_input": userInput,
            "sen_level": selectedLevel,
            // URLSearchParams serialises the boolean as "true" / "false".
            "ask_tag": selectedTag
        }),
        headers: {
            "Content-Type": "application/x-www-form-urlencoded"
        }
    })
    .then(response => {
        // fetch() only rejects on network failure; check the status too so
        // an error page is not shown as if it were the masked text.
        if (!response.ok) {
            throw new Error("/get_mask responded with HTTP " + response.status);
        }
        return response.text();
    })
    .then(maskedText => {
        appendMessage("user", userInput, maskedText);
        getMessage(maskedText);
    })
    .catch(error => {
        console.error("Error:", error);
    });
}
171 |
/**
 * Post `data` and the model chosen in #model-select to /get_response and
 * append the returned text to the chat as a bot message.
 *
 * A non-2xx answer is treated as an error and logged instead of being
 * displayed.
 *
 * @param {string} data - Text sent as the "mask_info" form field.
 */
function getMessage(data){
    const selected_model = document.getElementById("model-select").value;
    fetch("/get_response", {
        method: "POST",
        body: new URLSearchParams({
            "selected_model": selected_model,
            "mask_info": data
        }),
        headers: {
            "Content-Type": "application/x-www-form-urlencoded"
        }
    })
    .then(response => {
        // fetch() only rejects on network failure; check the status too so
        // an error page is not shown as the bot's reply.
        if (!response.ok) {
            throw new Error("/get_response responded with HTTP " + response.status);
        }
        return response.text();
    })
    .then(reply => {
        appendMessage("bot", reply);
    })
    .catch(error => {
        console.error("Error:", error);
    });
}
193 |
// Clicking the send button submits the current input.
document.getElementById('send-button').addEventListener('click', sendMessage);

// Keyboard handling for the input box:
//   Enter        -> insert a line break at the caret (replacing any selection)
//   Ctrl + Enter -> send the message
document.getElementById('user-input').addEventListener('keydown', function (event) {
    if (event.key !== "Enter") {
        return;
    }

    event.preventDefault();

    if (event.ctrlKey) {
        sendMessage();
        return;
    }

    const caretStart = this.selectionStart;
    const caretEnd = this.selectionEnd;
    const before = this.value.substring(0, caretStart);
    const after = this.value.substring(caretEnd);
    this.value = before + "\n" + after;

    // Put the caret right after the inserted line break.
    const caretAfterBreak = caretStart + 1;
    this.selectionStart = caretAfterBreak;
    this.selectionEnd = caretAfterBreak;
});
213 |
214 |
/**
 * Ask the user for confirmation and, if given, remove every message from
 * #chat-display.
 */
function clearChat() {
    const chatDisplay = document.getElementById("chat-display");
    if (!confirm("记录清除后无法恢复,您确定要清除吗?")) {
        return;
    }
    chatDisplay.innerHTML = "";
}
222 |
223 |
/**
 * Download the visible text of #chat-display as "chat_export.txt".
 *
 * The text is wrapped in a Blob and downloaded through a temporary hidden
 * link; the link and the blob's object URL are both released afterwards.
 */
function handleExport() {
    const chatDisplay = document.getElementById("chat-display");
    const blob = new Blob([chatDisplay.innerText], { type: "text/plain" });
    const objectUrl = URL.createObjectURL(blob);

    const link = document.createElement("a");
    link.href = objectUrl;
    link.download = "chat_export.txt";
    link.textContent = "Download";
    link.style.display = "none";
    document.body.appendChild(link);

    try {
        link.click();
    } finally {
        document.body.removeChild(link);
        // Release the blob once the click has been dispatched; without this
        // every export keeps its blob alive until the page is closed.
        setTimeout(function () {
            URL.revokeObjectURL(objectUrl);
        }, 0);
    }
}
--------------------------------------------------------------------------------