├── static ├── user_pwd.csv ├── label.csv └── history.csv ├── views ├── __pycache__ │ └── rubbish.cpython-310.pyc └── rubbish.py ├── README.md ├── LICENSE └── app.py /static/user_pwd.csv: -------------------------------------------------------------------------------- 1 | scq,123 2 | njh,123 3 | njh1,123 4 | njh12,123 5 | njh123,123 6 | -------------------------------------------------------------------------------- /static/label.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzunjh/refuse-classification/HEAD/static/label.csv -------------------------------------------------------------------------------- /static/history.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzunjh/refuse-classification/HEAD/static/history.csv -------------------------------------------------------------------------------- /views/__pycache__/rubbish.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzunjh/refuse-classification/HEAD/views/__pycache__/rubbish.cpython-310.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # refuse-classification 2 | 基于深度学习的垃圾分类系统(模型使用ONNX导入) 3 | 前端地址: https://github.com/MxianD/Rubbish_Classify 4 | 5 | ONNX垃圾分类模型的网盘地址 6 | 链接:https://pan.baidu.com/s/1ne7Y68izkcoJytFwPH--XA?pwd=fkx6 7 | 提取码:fkx6 8 | 下载后请将其存放在static文件夹下(后端从 static/e10_resnet50.onnx 加载模型) 9 | 前端启动命令 10 | ``` 11 | 安装依赖项 12 | npm i 13 | 14 | 运行前端工程 15 | npm run dev 16 | ``` 17 | 后端安装完包以后,直接运行app.py文件即可 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 家辉酱 4 | 5 | Permission is hereby granted, free of charge, to any 
person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /views/rubbish.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | 7 | model_path = 'static/e10_resnet50.onnx' 8 | ort_session = ort.InferenceSession(model_path, providers=ort.get_available_providers()) 9 | 10 | 11 | def getRubbish(img): 12 | predictions = [] 13 | img = cv2.imdecode(np.fromfile(img, dtype=np.uint8), 1) 14 | img = cv2.resize(img, (224, 224)) 15 | img = np.transpose(img, (2, 0, 1)) 16 | img = np.expand_dims(img, 0) 17 | img = img.astype(np.float32) 18 | img /= 255 19 | ort_inputs = {ort_session.get_inputs()[0].name: img} 20 | ort_outputs = ort_session.run(None, ort_inputs) 21 | 22 | # 获取预测结果 23 | predictions.extend(ort_outputs[0]) 24 | 25 | # 将预测结果转换为numpy数组 26 | predictions_np = np.array(predictions) 27 | 28 | # 对预测结果应用softmax 29 | predictions_softmax = softmax(predictions_np) 30 | 31 | # 获取预测标签 32 | predicted_labels = np.argmax(predictions_softmax, axis=1) 33 | return predicted_labels 34 | 35 | 36 | def softmax(x): 37 | exp_x = np.exp(x) 38 | softmax_x = exp_x / np.sum(exp_x, axis=1, keepdims=True) 39 | return softmax_x 40 | 41 | 42 | def validate_login(username, password): 43 | with open('static/user_pwd.csv', 'r') as file: 44 | reader = csv.reader(file) 45 | for row in reader: 46 | if row[0] == username and row[1] == password: 47 | return True 48 | return False 49 | 50 | 51 | def register_user(username, password): 52 | with open('static/user_pwd.csv', 'a', newline='') as file: 53 | writer = csv.writer(file) 54 | writer.writerow([username, password]) 55 | 56 | 57 | def get_user_info(username): 58 | user_data = {} 59 | with open('static/user_pwd.csv', mode='r') as csvfile: 60 | reader = csv.DictReader(csvfile) 61 | for row in reader: 62 | if row['username'] == username: 63 | user_data = row 64 | break 65 | return user_data 66 | 67 | 68 | def 
update_password(username, old_password, new_password): 69 | success = False 70 | users = [] 71 | with open('static/user_pwd.csv', mode='r') as csvfile: 72 | reader = csv.DictReader(csvfile) 73 | for row in reader: 74 | if row['username'] == username: 75 | if row['password'] == old_password: 76 | row['password'] = new_password 77 | success = True 78 | users.append(row) 79 | 80 | if success: 81 | with open('static/user_pwd.csv', mode='w', newline='') as csvfile: 82 | fieldnames = ['username', 'password'] 83 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 84 | writer.writeheader() 85 | for user in users: 86 | writer.writerow(user) 87 | 88 | return success 89 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import ast 5 | from datetime import datetime 6 | from math import ceil 7 | import pandas as pd 8 | from flask import Flask, request, jsonify 9 | from flask import json 10 | 11 | from views.rubbish import validate_login, register_user, getRubbish, update_password, get_user_info 12 | 13 | app = Flask(__name__) 14 | if __name__ == '__main__': 15 | app.run() 16 | 17 | app = Flask(__name__) 18 | 19 | # 全局字典,存储用户的分类进度 20 | progress = {} 21 | 22 | 23 | # 登录功能 24 | @app.route('/login', methods=['POST']) 25 | def login(): 26 | username = request.form.get('username') 27 | password = request.form.get('password') 28 | 29 | if validate_login(username, password): 30 | return jsonify({'success': True, 'message': '登录成功'}) 31 | else: 32 | return jsonify({'success': False, 'message': '用户名或密码错误'}) 33 | 34 | 35 | # 注册功能 36 | @app.route('/register', methods=['POST']) 37 | def register(): 38 | username = request.form.get('username') 39 | password = request.form.get('password') 40 | 41 | if validate_login(username, password): 42 | return jsonify({'success': False, 'message': '用户名已存在'}) 43 | else: 44 | 
register_user(username, password) 45 | return jsonify({'success': True, 'message': '注册成功'}) 46 | 47 | 48 | # 获取用户个人信息 49 | @app.route('/user_info', methods=['GET']) 50 | def user_info(): 51 | username = request.args.get('username') 52 | 53 | user = get_user_info(username) 54 | if user: 55 | return jsonify({'success': True, 'user': user}) 56 | else: 57 | return jsonify({'success': False, 'message': '用户不存在'}) 58 | 59 | 60 | # 修改密码 61 | @app.route('/change_password', methods=['POST']) 62 | def change_password(): 63 | username = request.form.get('username') 64 | old_password = request.form.get('old_password') 65 | new_password = request.form.get('new_password') 66 | 67 | result = update_password(username, old_password, new_password) 68 | if result: 69 | return jsonify({'success': True, 'message': '密码修改成功'}) 70 | else: 71 | return jsonify({'success': False, 'message': '原密码错误'}) 72 | 73 | 74 | # 垃圾批量异步处理 75 | @app.route('/rubbish', methods=['POST']) 76 | def rubbish(): 77 | df = pd.read_csv('static/label.csv', encoding='gbk', header=None, names=["name", "class"]) 78 | labels = df.iloc[:, 0] 79 | files = request.files.getlist('file') # 获取多个文件 80 | 81 | async def process_file(file, username): 82 | rest = await asyncio.get_event_loop().run_in_executor(None, getRubbish, file) # 异步执行分类任务 83 | rest = labels[rest].values[0] 84 | bgroup, sgroup = rest.split('/') 85 | formatted_rest = {"Bgroup": bgroup, "Sgroup": sgroup} 86 | 87 | # 更新进度 88 | if username in progress: 89 | progress[username] += 1 90 | else: 91 | progress[username] = 1 92 | 93 | return formatted_rest 94 | 95 | async def process_files(fileIO, username): 96 | tasks = [] 97 | for file in fileIO: 98 | task = asyncio.create_task(process_file(file, username)) 99 | tasks.append(task) 100 | return await asyncio.gather(*tasks) 101 | 102 | username = request.form.get('username') 103 | loop = asyncio.new_event_loop() 104 | asyncio.set_event_loop(loop) 105 | results = loop.run_until_complete(process_files(files, username)) 106 
| 107 | # 将垃圾分类记录存储在history.csv文件中 108 | image_paths = [os.path.normpath(os.path.join('C:/Users/27877/Desktop/validate', file.filename)) for file in files] 109 | 110 | history_file = 'static/history.csv' 111 | 112 | # 检查文件是否存在 113 | if os.path.exists(history_file): 114 | # 检查文件是否为空 115 | if os.stat(history_file).st_size > 0: 116 | history = pd.read_csv(history_file, encoding='gbk') 117 | else: 118 | # 如果文件为空,创建一个空的DataFrame 119 | history = pd.DataFrame(columns=['username', 'result', 'image_path', 'process_time']) 120 | else: 121 | # 如果文件不存在,创建一个空的DataFrame 122 | history = pd.DataFrame(columns=['username', 'result', 'image_path', 'process_time']) 123 | 124 | # 获取当前时间 125 | current_time = datetime.now().strftime("%Y/%m/%d %H:%M:%S") 126 | 127 | new_records = [ 128 | {'username': username, 'result': json.dumps(result, ensure_ascii=False), 'image_path': path, 129 | 'process_time': current_time} for 130 | result, path in 131 | zip(results, image_paths)] 132 | new_records_df = pd.DataFrame(new_records) 133 | history = pd.concat([history, new_records_df], ignore_index=True) 134 | history.to_csv('static/history.csv', index=False, encoding='gbk') 135 | 136 | return jsonify(results) 137 | 138 | 139 | # 添加进度查询接口 140 | @app.route('/progress', methods=['GET']) 141 | def get_progress(): 142 | username = request.args.get('username') 143 | if username and username in progress: 144 | return jsonify({"progress": progress[username]}) 145 | else: 146 | return jsonify({"error": "User not found"}), 404 147 | 148 | 149 | # 查询用户的垃圾分类记录 150 | @app.route('/history', methods=['GET']) 151 | def get_history(): 152 | username = request.args.get('username') 153 | page = int(request.args.get('page', 1)) 154 | per_page = int(request.args.get('per_page', 10)) 155 | date = request.args.get('date', None) 156 | bgroup = request.args.get('bgroup', None) 157 | 158 | history = pd.read_csv('static/history.csv', encoding='gbk') 159 | user_history = history[history['username'] == username] 160 | 161 | # 
根据日期筛选 162 | if date: 163 | user_history = user_history[user_history['process_time'].apply( 164 | lambda x: datetime.strptime(x.split(' ')[0], '%Y/%m/%d').date() == datetime.strptime(date, 165 | '%Y-%m-%d').date())] 166 | 167 | # 根据 result 中的 Bgroup 筛选 168 | if bgroup: 169 | user_history = user_history[ 170 | user_history['result'].apply(lambda x: bgroup in ast.literal_eval(x).get('Bgroup', ''))] 171 | 172 | total_records = len(user_history) 173 | total_pages = ceil(total_records / per_page) 174 | 175 | start = (page - 1) * per_page 176 | end = start + per_page 177 | user_history = user_history.iloc[start:end] 178 | 179 | user_records = [] 180 | for _, row in user_history.iterrows(): 181 | record = { 182 | 'image_path': row['image_path'], 183 | 'result': ast.literal_eval(row['result']), 184 | 'process_time': row['process_time'] 185 | } 186 | user_records.append(record) 187 | 188 | return jsonify({ 189 | 'success': True, 190 | 'records': user_records, 191 | 'total_records': total_records, 192 | 'total_pages': total_pages, 193 | 'current_page': page, 194 | 'per_page': per_page 195 | }) 196 | --------------------------------------------------------------------------------