├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── doc2json.py
├── infra-add.py
├── infra.py
├── requirements.txt
└── settings_mgr.py

/.gitattributes:
--------------------------------------------------------------------------------
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/*
.venv*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023, 2024 Nils Durner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MLX Chat

Chat interface for [MLX](https://github.com/ml-explore/mlx) for on-device language model use on Apple Silicon.
Built on [FastMLX](https://github.com/Blaizzy/fastmlx).

## Features:
* Plaintext file upload
* Chat history download
* File download
  * example: download an ICS calendar file the model has created for you
* Streaming chat

## Using
1. Install FastMLX: `pip3 install mlx-lm fastmlx==0.2.1`
1. Install model(s): run `infra.py` to view supported/installed models; modify & run `infra-add.py` to download & install a new model
1. `pip3 install -r requirements.txt`
1. `python3 ./app.py`
1. Check the console output and open the "local URL" in a browser
1. Enter or select a locally available model to chat with
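
## Smoke test
The app expects a FastMLX server on `localhost:8000`. The following is a minimal sketch of a server check, mirroring the request `app.py` sends (non-streaming for brevity; it assumes the Llama model from the app's dropdown is installed):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",  # any installed model
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 50,
    },
)
print(resp.json())
```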
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import gradio as gr
import base64
import os
import requests
import json
import fitz
from PIL import Image
import io
from settings_mgr import generate_download_settings_js, generate_upload_settings_js

from doc2json import process_docx

dump_controls = False
log_to_console = False

temp_files = []

def encode_image(image_data):
    """Builds a base64 data URL for image data in one of the four
    supported image formats: png, jpeg, gif, and webp.

    Args:
        image_data: The raw image bytes.

    Returns:
        A string containing the data URL.
    """

    # Get the first few bytes of the image data; six bytes cover the longest
    # signature checked below (GIF89a).
    magic_number = image_data[:6]

    # Check the magic number to determine the image type.
    if magic_number.startswith(b'\x89PNG'):
        image_type = 'png'
    elif magic_number.startswith(b'\xFF\xD8'):
        image_type = 'jpeg'
    elif magic_number.startswith(b'GIF89a'):
        image_type = 'gif'
    elif magic_number.startswith(b'RIFF'):
        if image_data[8:12] == b'WEBP':
            image_type = 'webp'
        else:
            # Unknown image type.
            raise Exception("Unknown image type")
    else:
        # Unknown image type.
        raise Exception("Unknown image type")

    return f"data:image/{image_type};base64,{base64.b64encode(image_data).decode('utf-8')}"

def process_pdf_img(pdf_fn: str):
    pdf = fitz.open(pdf_fn)
    message_parts = []

    for page in pdf.pages():
        # Create a transformation matrix for rendering at a reduced scale
        mat = fitz.Matrix(0.6, 0.6)

        # Render the page to a pixmap
        pix = page.get_pixmap(matrix=mat, alpha=False)

        # Convert pixmap to PIL Image
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

        # Convert PIL Image to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()

        # Encode image to base64
        base64_encoded = base64.b64encode(img_byte_arr).decode('utf-8')

        # Construct the data URL
        image_url = f"data:image/png;base64,{base64_encoded}"

        # Append the message parts
        message_parts.append({
            "type": "text",
            "text": f"Page {page.number} of file '{pdf_fn}'"
        })
        message_parts.append({
            "type": "image_url",
            "image_url": {
                "url": image_url,
                "detail": "high"
            }
        })

    pdf.close()

    return message_parts

def encode_file(fn: str) -> dict:
    user_msg_parts = {}
    text_content = ""

    # docx/PDF handling is disabled for now:
    # if fn.endswith(".docx"):
    #     user_msg_parts.append({"type": "text", "text": process_docx(fn)})
    # elif fn.endswith(".pdf"):
    #     user_msg_parts.extend(process_pdf_img(fn))
    if False:
        pass
    else:
        with open(fn, mode="rb") as f:
            content = f.read()

        isImage = False
        if isinstance(content, bytes):
            try:
                # try to add as image
                content = encode_image(content)
                isImage = True
            except:
                # not an image, try text
                text_content = text_content + content.decode('utf-8', 'replace')
        else:
            text_content = text_content + str(content)

        if isImage:
            pass
            # user_msg_parts.append({"type": "image_url",
            #                        "image_url": {"url": content}})
        else:
            user_msg_parts["text"] = text_content

    return user_msg_parts

def undo(history):
    history.pop()
    return history

def dump(history):
    return str(history)

def load_settings():
    # Dummy Python function, actual loading is done in JS
    pass

def save_settings(prompt, temp, tokens, model):
    # Dummy Python function, actual saving is done in JS
    pass

def process_values_js():
    return """
    () => {
        return ["system_prompt"];
    }
    """

def bot(message, history, system_prompt, temperature, max_tokens, model):
    try:
        if False:
            pass
        else:
            if log_to_console:
                print(f"bot history: {str(history)}")

            history_openai_format = []
            user_msg_parts = ""
            if system_prompt:
                user_msg_parts = system_prompt + "\n"
            for human, assi in history:
                if human is not None:
                    if type(human) is tuple:
                        fc = encode_file(human[0])
                        if fc.get("text"):
                            user_msg_parts = user_msg_parts + fc["text"]
                    else:
                        user_msg_parts = user_msg_parts + human

                if assi is not None:
                    if user_msg_parts:
                        history_openai_format.append({"role": "user", "content": user_msg_parts})
                        user_msg_parts = ""

                    history_openai_format.append({"role": "assistant", "content": assi})

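            # The typed text and any uploaded file contents for this turn are
            # appended below into one combined user message, keeping the
            # conversation to strictly alternating user/assistant roles.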
            if message['text']:
                user_msg_parts = user_msg_parts + message['text']
            if message['files']:
                for file in message['files']:
                    fc = encode_file(file['path'])
                    if fc.get("text"):
                        user_msg_parts = user_msg_parts + fc["text"]
            history_openai_format.append({"role": "user", "content": user_msg_parts})
            user_msg_parts = ""

            if log_to_console:
                print(f"br_prompt: {str(history_openai_format)}")

            url = "http://localhost:8000/v1/chat/completions"
            headers = {"Content-Type": "application/json"}
            data = {
                "model": model,
                "messages": history_openai_format,
                "max_tokens": max_tokens,
                "temperature": temperature,  # forwarded on the assumption that the OpenAI-compatible endpoint honors it
                "stream": True
            }

            response = requests.post(url, headers=headers, json=data, stream=True)

            if response.status_code != 200:
                raise gr.Error(f"Error: Received status code {response.status_code} {response.text}")

            full_content = ""

            try:
                for line in response.iter_lines():
                    if line:
                        line = line.decode('utf-8')
                        if line.startswith('data: '):
                            event_data = line[6:]  # Remove 'data: ' prefix
                            if event_data == '[DONE]':
                                break
                            try:
                                chunk_data = json.loads(event_data)
                                # deltas without a content field (e.g. role-only chunks) are skipped
                                content = chunk_data['choices'][0]['delta'].get('content')
                                if content:
                                    full_content += content
                                    yield full_content
                            except json.JSONDecodeError:
                                raise gr.Error(f"Failed to decode JSON: {event_data}")
                            except (KeyError, IndexError):
                                raise gr.Error(f"Unexpected data structure: {chunk_data}")

            except requests.exceptions.RequestException as e:
                raise gr.Error(f"An error occurred: {e}")

            if log_to_console:
                print(f"br_result: {str(full_content)}")

    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")
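# The streaming loop in bot() expects OpenAI-style server-sent events, e.g.:
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
#   data: [DONE]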

def import_history(history, file):
    with open(file.name, mode="rb") as f:
        content = f.read()

    if isinstance(content, bytes):
        content = content.decode('utf-8', 'replace')
    else:
        content = str(content)
    os.remove(file.name)

    # Deserialize the JSON content
    import_data = json.loads(content)

    # Check if 'history' key exists for backward compatibility:
    # newer exports are {"history": [...], "system_prompt": "..."},
    # older exports are a bare history list.
    if 'history' in import_data:
        history = import_data['history']
        sys_prompt = import_data.get('system_prompt', '')  # Set default if not present
    else:
        # Assume it's an old format with only history data
        history = import_data
        sys_prompt = system_prompt.value

    return history, sys_prompt  # Return system prompt value to be set in the UI

with gr.Blocks(delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# MLX Chat (Nils' Version™️)")
    with gr.Accordion("Startup"):
        gr.Markdown("""Use of this interface is permitted under the terms and conditions of the
            [MIT license](https://github.com/ndurner/oai_chat/blob/main/LICENSE).
            Third-party terms and conditions apply, particularly
            those of the LLM vendor and hosting provider (e.g. Hugging Face). This app and the AI models may make mistakes, so verify any outputs.""")

    model = gr.Dropdown(label="Model", value="meta-llama/Meta-Llama-3.1-8B-Instruct", allow_custom_value=True, elem_id="model",
                        choices=["meta-llama/Meta-Llama-3.1-8B-Instruct", "google/gemma-2-9b-it"])
    system_prompt = gr.TextArea("You are a helpful yet diligent AI assistant. Answer faithfully and factually correct. Respond with 'I do not know' if uncertain.",
                                label="System Prompt", lines=3, max_lines=250, elem_id="system_prompt")
    temp = gr.Slider(0, 2, label="Temperature", elem_id="temp", value=1)
    max_tokens = gr.Slider(1, 16384, label="Max. Tokens", elem_id="max_tokens", value=800)
    save_button = gr.Button("Save Settings")
    load_button = gr.Button("Load Settings")
    dl_settings_button = gr.Button("Download Settings")
    ul_settings_button = gr.Button("Upload Settings")

    load_button.click(load_settings, js="""
        () => {
            let elems = ['#system_prompt textarea', '#temp input', '#max_tokens input', '#model'];
            elems.forEach(elem => {
                let item = document.querySelector(elem);
                let event = new InputEvent('input', { bubbles: true });
                item.value = localStorage.getItem(elem.split(" ")[0].slice(1)) || '';
                item.dispatchEvent(event);
            });
        }
    """)

    save_button.click(save_settings, [system_prompt, temp, max_tokens, model], js="""
        (sys, temp, ntok, model) => {
            localStorage.setItem('system_prompt', sys);
            localStorage.setItem('temp', document.querySelector('#temp input').value);
            localStorage.setItem('max_tokens', document.querySelector('#max_tokens input').value);
            localStorage.setItem('model', model);
        }
    """)

    control_ids = [
        ('system_prompt', '#system_prompt textarea'),
        ('temp', '#temp input'),
        ('max_tokens', '#max_tokens input'),
        ('model', '#model')]
    controls = [system_prompt, temp, max_tokens, model]

    dl_settings_button.click(None, controls, js=generate_download_settings_js("oai_chat_settings.bin", control_ids))
    ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))

    chat = gr.ChatInterface(fn=bot, multimodal=True, additional_inputs=controls, retry_btn=None, autofocus=False)
    chat.textbox.file_count = "multiple"
    chatbot = chat.chatbot
    chatbot.show_copy_button = True
    chatbot.height = 350

    if dump_controls:
        with gr.Row():
            dmp_btn = gr.Button("Dump")
            txt_dmp = gr.Textbox("Dump")
            dmp_btn.click(dump, inputs=[chatbot], outputs=[txt_dmp])

    with gr.Accordion("Import/Export", open=False):
        import_button = gr.UploadButton("History Import")
        export_button = gr.Button("History Export")
        export_button.click(lambda: None, [chatbot, system_prompt], js="""
            (chat_history, system_prompt) => {
                const export_data = {
                    history: chat_history,
                    system_prompt: system_prompt
                };
                const history_json = JSON.stringify(export_data);
                const blob = new Blob([history_json], {type: 'application/json'});
                const url = URL.createObjectURL(blob);
                const a = document.createElement('a');
                a.href = url;
                a.download = 'chat_history.json';
                document.body.appendChild(a);
                a.click();
                document.body.removeChild(a);
                URL.revokeObjectURL(url);
            }
        """)
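        # "File download" scans the last assistant message for a fenced code
        # block with an optional file name (```name.ext ... ```) and offers
        # its body as a download.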
        dl_button = gr.Button("File download")
        dl_button.click(lambda: None, [chatbot], js="""
            (chat_history) => {
                // Attempt to extract content enclosed in backticks with an optional filename
                const contentRegex = /```(\\S*\\.(\\S+))?\\n?([\\s\\S]*?)```/;
                const match = contentRegex.exec(chat_history[chat_history.length - 1][1]);
                if (match && match[3]) {
                    // Extract the content and the file extension
                    const content = match[3];
                    const fileExtension = match[2] || 'txt'; // Default to .txt if extension is not found
                    const filename = match[1] || `download.${fileExtension}`;
                    // Create a Blob from the content
                    const blob = new Blob([content], {type: `text/${fileExtension}`});
                    // Create a download link for the Blob
                    const url = URL.createObjectURL(blob);
                    const a = document.createElement('a');
                    a.href = url;
                    // If the filename from the chat history doesn't have an extension, append the default
                    a.download = filename.includes('.') ? filename : `${filename}.${fileExtension}`;
                    document.body.appendChild(a);
                    a.click();
                    document.body.removeChild(a);
                    URL.revokeObjectURL(url);
                } else {
                    // Inform the user if the content is malformed or missing
                    alert('Sorry, the file content could not be found or is in an unrecognized format.');
                }
            }
        """)
        import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot, system_prompt])

    demo.unload(lambda: [os.remove(file) for file in temp_files])

demo.queue().launch()
--------------------------------------------------------------------------------
/doc2json.py:
--------------------------------------------------------------------------------
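"""Convert a .docx document into compact JSON: revision IDs, layout-only
elements, and common fonts are stripped so that text, notable formatting,
and metadata remain."""
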
from collections import defaultdict
import json
import zipfile
from lxml import etree

# Define common fonts to ignore
common_fonts = {
    'Times New Roman',
    'Arial',
    'Calibri',
    # Add any other common fonts here
}

# Define elements to ignore
ignored_elements = {
    'proofErr',
    'bookmarkStart',
    'bookmarkEnd',
    'lastRenderedPageBreak',
    'webHidden',
    'numPr',
    'pBdr',
    'ind',
    'spacing',
    'jc',
    'tabs',
    'sectPr',
    'pgMar',
    # Add any other elements to ignore here
}

# Define attributes to ignore
ignored_attributes = {
    'rsidR',
    'rsidRPr',
    'rsidRDefault',
    'rsidDel',
    'rsidP',
    'rsidTr',
    'paraId',
    'textId',
    # Add any other attributes to ignore here
}

# Define metadata elements to ignore
ignored_metadata_elements = {
    'application',
    'docSecurity',
    'scaleCrop',
    'linksUpToDate',
    'charactersWithSpaces',
    'hiddenSlides',
    'mmClips',
    'notes',
    'words',
    'characters',
    'pages',
    'lines',
    'paragraphs',
    'company',
    'template',
    # Add any other metadata elements to ignore here
}

def remove_ignored_elements(tree):
    """Remove all ignored elements from the XML tree, except highlights,
    and fix the text encoding of the remaining elements."""
    for elem in tree.xpath(".//*"):
        tag_without_ns = elem.tag.split('}')[-1]
        if tag_without_ns in ignored_elements:
            elem.getparent().remove(elem)
        elif elem.tag == '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}rPr':  # Check for highlights in rPr
            if not any(child.tag.endswith('highlight') for child in elem):
                elem.getparent().remove(elem)
        else:
            # Remove ignored attributes
            for attr in list(elem.attrib):
                attr_without_ns = attr.split('}')[-1]
                if attr_without_ns in ignored_attributes or attr_without_ns.startswith('rsid'):
                    del elem.attrib[attr]
    # Decode the text correctly for each XML element
    for elem in tree.xpath(".//text()"):
        elem_text = elem.strip()
        encoded_text = bytes(elem_text, 'utf-8').decode('utf-8', 'ignore')
        parent = elem.getparent()
        if parent is not None:
            parent.text = encoded_text
    return tree

def etree_to_dict(t):
    """Convert an lxml etree to a nested dictionary, excluding ignored namespaces and attributes."""
    tag = t.tag.split('}')[-1]  # Remove namespace URI
    if tag in ignored_elements:
        return None

    d = {tag: {} if t.attrib else None}
    children = list(t)
    if children:
        dd = defaultdict(list)
        for dc in filter(None, map(etree_to_dict, children)):
            for k, v in dc.items():
                dd[k].append(v)
        d = {tag: {k: v[0] if len(v) == 1 else v for k, v in dd.items()}}

    if t.attrib:
        # Filter out common fonts and ignored attributes
        filtered_attribs = {}
        for k, v in t.attrib.items():
            k = k.split('}')[-1]  # Remove namespace URI
            if k in ('ascii', 'hAnsi', 'cs', 'eastAsia'):
                if v not in common_fonts:
                    filtered_attribs[k] = v
            elif k not in ignored_attributes and not k.startswith('rsid'):
                filtered_attribs[k] = v
        d[tag].update(filtered_attribs)

    if t.text:
        text = t.text.strip()
        # Here we ensure that the text encoding is correctly handled
        text = bytes(text, 'utf-8').decode('utf-8', 'ignore')
        if children or t.attrib:
            if text:
                d[tag]['#text'] = text
        else:
            d[tag] = text

    if not t.attrib and not children and not t.text:
        return None

    return d

def extract_metadata(docx):
    """Extract metadata from the document properties, ignoring specified elements."""
    metadata = {}
    with docx.open('docProps/core.xml') as core_xml:
        xml_content = core_xml.read()
        core_tree = etree.XML(xml_content)
        for child in core_tree:
            tag = child.tag.split('}')[-1]  # Get tag without namespace
            if tag not in ignored_metadata_elements:
                metadata[tag] = child.text
    return metadata

def process_docx(file_path):
    # Load the document with zipfile and lxml
    with zipfile.ZipFile(file_path) as docx:
        metadata = extract_metadata(docx)
        with docx.open('word/document.xml') as document_xml:
            xml_content = document_xml.read()
            document_tree = etree.XML(xml_content)

    # Remove the ignored elements
    document_tree = remove_ignored_elements(document_tree)

    # Convert the rest of the XML tree to a dictionary
    document_dict = etree_to_dict(document_tree)
    document_dict['metadata'] = metadata  # Add metadata to the document dictionary

    docx_json = json.dumps(document_dict, ensure_ascii=False, indent=2)

    return docx_json
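
# Example usage (hypothetical file name):
#   print(process_docx("example.docx"))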
--------------------------------------------------------------------------------
/infra-add.py:
--------------------------------------------------------------------------------
import requests

# Adjust model_name to the Hugging Face repo of the model to download & install
url = "http://localhost:8000/v1/models"
params = {
    "model_name": "google/gemma-2-9b-it",
}

response = requests.post(url, params=params)
print(response.json())
--------------------------------------------------------------------------------
/infra.py:
--------------------------------------------------------------------------------
import requests

url = "http://localhost:8000/v1/supported_models"
response = requests.get(url)
print("-- Supported\n")
print(response.json())

url = "http://localhost:8000/v1/models"
response = requests.get(url)
print("\n\n-- Models\n")
print(response.json())
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
gradio >= 4.38.1
requests
lxml
PyMuPDF
Pillow
--------------------------------------------------------------------------------
/settings_mgr.py:
--------------------------------------------------------------------------------
def generate_download_settings_js(dl_fn, control_ids):
    js_code = """
    async (""" + ", ".join([f'{ctrl[0]}' for ctrl in control_ids]) + """) => {
        const password = prompt("Please enter a password for encryption", " ");
        if (!password) {
            alert("No password provided. Cancelling download.");
            return;
        }

        let settings = {""" + ", ".join([f'"{ctrl[0]}": {ctrl[0]}' for ctrl in control_ids]) + """};
        const settingsStr = JSON.stringify(settings);
        const textEncoder = new TextEncoder();
        const encodedSettings = textEncoder.encode(settingsStr);
        const salt = crypto.getRandomValues(new Uint8Array(16));
        const passwordBuffer = textEncoder.encode(password);
        const keyMaterial = await crypto.subtle.importKey('raw', passwordBuffer, {name: 'PBKDF2'}, false, ['deriveKey']);
        const key = await crypto.subtle.deriveKey(
            {name: 'PBKDF2', salt: salt, iterations: 100000, hash: 'SHA-256'},
            keyMaterial,
            {name: 'AES-GCM', length: 256},
            false,
            ['encrypt']
        );
        const iv = crypto.getRandomValues(new Uint8Array(12));
        const encryptedSettings = await crypto.subtle.encrypt({name: 'AES-GCM', iv: iv}, key, encodedSettings);
        const blob = new Blob([salt, iv, new Uint8Array(encryptedSettings)], {type: 'application/octet-stream'});
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url;
        a.download = '""" + dl_fn + """';
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
    }"""
    return js_code
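# On-disk layout of the settings file produced above: 16-byte PBKDF2 salt,
# 12-byte AES-GCM IV, then the ciphertext. generate_upload_settings_js()
# slices the file apart along the same offsets.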

def generate_upload_settings_js(control_ids):
    js_code = """
    async () => {
        const input = document.createElement('input');
        input.type = 'file';
        input.onchange = async e => {
            const file = e.target.files[0];
            if (!file) {
                alert("No file selected.");
                return;
            }

            const password = prompt("Please enter the password for decryption", " ");
            if (!password) {
                alert("No password provided. Cancelling upload.");
                return;
            }

            const arrayBuffer = await file.arrayBuffer();
            const salt = arrayBuffer.slice(0, 16);
            const iv = arrayBuffer.slice(16, 28);
            const encryptedData = arrayBuffer.slice(28);
            const textEncoder = new TextEncoder();
            const passwordBuffer = textEncoder.encode(password);
            const keyMaterial = await crypto.subtle.importKey('raw', passwordBuffer, {name: 'PBKDF2'}, false, ['deriveKey']);
            const key = await crypto.subtle.deriveKey(
                {name: 'PBKDF2', salt: salt, iterations: 100000, hash: 'SHA-256'},
                keyMaterial,
                {name: 'AES-GCM', length: 256},
                false,
                ['decrypt']
            );

            try {
                const decryptedData = await crypto.subtle.decrypt({name: 'AES-GCM', iv: iv}, key, encryptedData);
                const textDecoder = new TextDecoder();
                const settingsStr = textDecoder.decode(decryptedData);
                const settings = JSON.parse(settingsStr);
                """ + "\n".join([f'document.querySelector("{ctrl[1]}").value = settings["{ctrl[0]}"];' for ctrl in control_ids]) + """
                """ + "\n".join([f'document.querySelector("{ctrl[1]}").dispatchEvent(new InputEvent("input", {{ bubbles: true }}));' for ctrl in control_ids]) + """
            } catch (err) {
                alert("Failed to decrypt. Check your password and try again.");
                console.error("Decryption failed:", err);
            }
        };
        input.click();
    }"""
    return js_code
--------------------------------------------------------------------------------