├── Satisfy-Regular.ttf ├── README.md ├── gemini_image.py ├── image_new.py └── aac_gemma3.py /Satisfy-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ushareng/AAC-using-Gemma3/main/Satisfy-Regular.ttf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AAC tool for Nonverbal Autism using Gemma 3 (Keras ) and Gemini 2 | 3 | # Use Case 4 | AAC tools helps Nonverbal autistic individuals to communicate . 5 | 6 | # Setup 7 | - Set the api key with ```export GOOGLE_API_KEY=``` 8 | - Authenticate google cloud from cli or from cloud console 9 | 10 | # Demo 11 | 12 | 13 | https://github.com/user-attachments/assets/993dae99-5343-46c0-ab4a-084c73356b1b 14 | 15 | 16 | 17 | # Flowchart 18 | ![aac_tool_flowchart_sample_corrected drawio (1)](https://github.com/user-attachments/assets/301d15bc-426a-4d0e-affa-81f271c3048e) 19 | -------------------------------------------------------------------------------- /gemini_image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import mimetypes 4 | import io 5 | from google import genai 6 | from google.genai import types 7 | 8 | 9 | def generate_flashcard_image(word): 10 | """Generates a vector image with the given word for flashcards and returns base64-encoded image data.""" 11 | client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY")) 12 | 13 | model = "gemini-2.0-flash-exp-image-generation" 14 | 15 | prompt = ( 16 | "Create a vector image for the word given below to be used in flash cards. " 17 | "Create an image of smaller size. Write the name of the word followed by the image. " 18 | "The image should be there along with the word, not just a description. 
" 19 | "Keep the image dimensions in a 2:3 ratio.DO NOT GENERATE TEXT BASED IMAGES\n" 20 | f"Word - {word}" 21 | ) 22 | 23 | contents = [ 24 | types.Content( 25 | role="user", 26 | parts=[types.Part.from_text(text=prompt)], 27 | ), 28 | ] 29 | generate_content_config = types.GenerateContentConfig( 30 | temperature=0, 31 | response_modalities=["image", "text"], 32 | #response_mime_type="image/png", # Ensures image output 33 | ) 34 | 35 | for chunk in client.models.generate_content_stream( 36 | model=model, contents=contents, config=generate_content_config 37 | ): 38 | if ( 39 | chunk.candidates 40 | and chunk.candidates[0].content 41 | and chunk.candidates[0].content.parts 42 | ): 43 | inline_data = chunk.candidates[0].content.parts[0].inline_data 44 | if inline_data: 45 | img_bytes = inline_data.data 46 | 47 | # Convert binary image data to base64 48 | return base64.b64encode(img_bytes).decode("utf-8") 49 | 50 | return None 51 | 52 | 53 | if __name__ == "__main__": 54 | word = "Apple" 55 | 56 | 57 | -------------------------------------------------------------------------------- /image_new.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFont, ImageDraw 2 | import requests 3 | import io 4 | import base64 5 | import random 6 | import os 7 | 8 | def get_font(font_source: str, size: int): 9 | """Loads a font from a URL or local file with a given size.""" 10 | if font_source.startswith("http"): 11 | try: 12 | response = requests.get(font_source, stream=True) 13 | response.raise_for_status() 14 | font_bytes = io.BytesIO(response.content) 15 | return ImageFont.truetype(font_bytes, size) 16 | except Exception as e: 17 | print(f" Error loading font from {font_source}: {e}") 18 | return ImageFont.load_default() 19 | else: 20 | try: 21 | return ImageFont.truetype(font_source, size) 22 | except Exception as e: 23 | print(f" Error loading local font {font_source}: {e}") 24 | return ImageFont.load_default() 25 
| 26 | def find_best_font_size(text, font_source, card_size, padding=20): 27 | """Finds the best font size dynamically for a given word.""" 28 | font_size = 400 29 | steps = [20] * 10 + [10] * 6 + [5] * 8 + [2] * 50 30 | for step in steps: 31 | font = get_font(font_source, font_size) 32 | text_width, text_height = ImageDraw.Draw(Image.new("RGB", (1, 1))).textbbox((0, 0), text, font=font)[2:] 33 | if text_width + padding < card_size[0] and text_height + padding < card_size[1]: 34 | return font_size 35 | font_size -= step 36 | return 50 37 | 38 | def create_gradient_bg(size, color1, color2): 39 | """Creates a gradient background.""" 40 | img = Image.new("RGB", size, color1) 41 | draw = ImageDraw.Draw(img) 42 | for i in range(size[1]): 43 | r = int(color1[0] + (color2[0] - color1[0]) * (i / size[1])) 44 | g = int(color1[1] + (color2[1] - color1[1]) * (i / size[1])) 45 | b = int(color1[2] + (color2[2] - color1[2]) * (i / size[1])) 46 | draw.line([(0, i), (size[0], i)], fill=(r, g, b)) 47 | return img 48 | 49 | def generate_flashcard(text: str): 50 | """Creates a flashcard for a given text with random styling.""" 51 | CARD_SIZE = (600, 400) 52 | OUTPUT_FOLDER = "flashcards" 53 | os.makedirs(OUTPUT_FOLDER, exist_ok=True) 54 | 55 | FONTS = { 56 | "Berkshire Swash": "https://github.com/google/fonts/raw/main/ofl/berkshireswash/BerkshireSwash-Regular.ttf", 57 | "Bungee Tint": "https://github.com/google/fonts/raw/main/ofl/bungeetint/BungeeTint-Regular.ttf", 58 | # "Concert One": "https://github.com/google/fonts/raw/main/ofl/concertone/ConcertOne-Regular.ttf", 59 | "Cookie": "https://github.com/google/fonts/raw/main/ofl/cookie/Cookie-Regular.ttf", 60 | "Courgette": "https://github.com/google/fonts/raw/main/ofl/courgette/Courgette-Regular.ttf", 61 | # "Gravitas One": "GravitasOne-Regular.ttf", 62 | # "Lilita One": "https://github.com/google/fonts/raw/main/ofl/lilitaone/LilitaOne-Regular.ttf", 63 | "Protest Riot": 
"https://github.com/google/fonts/raw/main/ofl/protestriot/ProtestRiot-Regular.ttf", 64 | "Satisfy": "Satisfy-Regular.ttf", 65 | "Yatra One": "https://github.com/google/fonts/raw/main/ofl/yatraone/YatraOne-Regular.ttf" 66 | } 67 | 68 | font_name, font_source = random.choice(list(FONTS.items())) 69 | color1 = random.choice([(255, 87, 34), (33, 150, 243), (76, 175, 80)]) 70 | color2 = random.choice([(255, 193, 7), (233, 30, 99), (156, 39, 176)]) 71 | 72 | img = create_gradient_bg(CARD_SIZE, color1, color2) 73 | best_font_size = find_best_font_size(text, font_source, CARD_SIZE) 74 | font = get_font(font_source, best_font_size) 75 | draw = ImageDraw.Draw(img) 76 | 77 | bbox = draw.textbbox((0, 0), text, font=font) 78 | text_width = bbox[2] - bbox[0] 79 | text_height = bbox[3] - bbox[1] 80 | text_baseline = bbox[1] 81 | 82 | x = (CARD_SIZE[0] - text_width) // 2 83 | y = (CARD_SIZE[1] - text_height) // 2 - text_baseline // 2 84 | draw.text((x, y), text, font=font, fill="white") 85 | 86 | # file_path = os.path.join(OUTPUT_FOLDER, f"{text}.png") 87 | # img.save(file_path) 88 | 89 | buffered = io.BytesIO() 90 | img.save(buffered, format="PNG") 91 | img_bytes = buffered.getvalue() 92 | img_base64 = base64.b64encode(img_bytes).decode('utf-8') 93 | 94 | print(f"✅ Flashcard created: font {font_name} and size {best_font_size}") 95 | 96 | return img_base64 97 | 98 | -------------------------------------------------------------------------------- /aac_gemma3.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import random 3 | import string 4 | import io 5 | import os 6 | import base64 7 | from streamlit_extras.stylable_container import stylable_container 8 | import keras_hub 9 | from image_new import generate_flashcard 10 | from gemini_image import generate_flashcard_image 11 | from google.cloud import texttospeech 12 | import tensorflow as tf 13 | 14 | st.set_page_config(layout="wide") 15 | 16 | client_speech = 
texttospeech.TextToSpeechClient()
audio_config = texttospeech.AudioConfig(
    # LINEAR16 produces WAV (PCM) bytes — serve them as audio/wav below.
    audio_encoding=texttospeech.AudioEncoding.LINEAR16,
    speaking_rate=1,
)

voice = texttospeech.VoiceSelectionParams(
    language_code="en-US",
    name="en-US-Studio-O",
)

gemma_model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_12b_text")

def get_items(clicked_texts):
    """Ask Gemma for the 12 most likely next words and build card items.

    Each item dict has 'id', 'label', and 'image_url' (a data URI, or "").
    Returns an empty list if the model reply cannot be parsed as a list.
    """
    sentence = " ".join(clicked_texts) if clicked_texts else ""

    if not sentence:
        prompt = "Give me 12 common words for an AAC app. Output as a Python list: ['word1', 'word2', ...]"
    else:
        prompt = (
            f"This is an AAC application. Given the sentence: '{sentence}', what are the 12 most likely next words the user might want to pick? "
            "Reply only with a Python list: ['word1', 'word2', ...]"
        )

    response = gemma_model.generate(prompt, max_length=100)
    start = response.find('[')
    end = response.find(']')
    if start == -1 or end <= start:
        # Model did not return a parseable list; show nothing instead of
        # crashing on response[start + 1:end] with find()'s -1 sentinel.
        return []
    word_list = response[start + 1:end].strip().split(',')

    # Clean up the words: strip whitespace and surrounding quotes.
    final_words = [word.strip().strip("'\"") for word in word_list]

    next_items = []
    for i, word in enumerate(final_words):
        img_base64 = generate_flashcard_image(word)

        if not img_base64:  # Fallback to generate_flashcard() if no image was generated
            img_base64 = generate_flashcard(word)

        # generate_flashcard emits PNG bytes (presumably Gemini does too —
        # TODO confirm), so label the data URI as PNG rather than jpeg.
        img_url = f"data:image/png;base64,{img_base64}" if img_base64 else ""

        next_items.append({
            "id": i,
            "label": word,
            "image_url": img_url,
        })

    return next_items

# Session state for clicked texts
if "clicked_texts" not in st.session_state:
    st.session_state.clicked_texts = []

# Function to handle button clicks
def on_click(label):
    """Append the tapped word to the sentence being built."""
    st.session_state.clicked_texts.append(label)


def create_button_grid(items, columns=4):
    """Render *items* as a grid of image+button cards, *columns* per row."""
    rows = len(items) // columns + (len(items) % columns > 0)
    with st.container():
        for row in range(rows):
            cols = st.columns(columns)
            for col in range(columns):
                idx = row * columns + col
                if idx < len(items):
                    item = items[idx]
                    with cols[col]:
                        with st.container(border=True):
                            st.image(item['image_url'], use_container_width=True)
                            st.button(
                                item['label'],
                                # Stable keys: a fresh random key every rerun
                                # resets widget state and can collide.
                                key=f"button_{item['id']}_{item['label']}",
                                on_click=on_click,
                                args=(item['label'],),
                                use_container_width=True,
                            )

def text_to_speech(text):
    """Synthesize *text* with Cloud TTS and return a BytesIO of WAV audio.

    NOTE(review): the original body called gTTS, which is never imported in
    this module and would raise NameError; it now uses the module's
    Cloud Text-to-Speech client and config instead.
    """
    input_text = texttospeech.SynthesisInput(text=text)
    response = client_speech.synthesize_speech(
        request={"input": input_text, "voice": voice, "audio_config": audio_config}
    )
    audio = io.BytesIO(response.audio_content)
    audio.seek(0)
    return audio


def _clear_sentence():
    """Callback: reset the sentence before the next render pass."""
    st.session_state.clicked_texts = []


# Main UI
def main():
    st.title("AAC Tool for Autism Using Keras Gemma 3")

    # Style
    # NOTE(review): the markup between the triple quotes was stripped in this
    # revision of the file — restore the original <style> block if available.
    st.markdown(
        """
        """,
        unsafe_allow_html=True,
    )

    col1, col2 = st.columns([6, 1])
    with col1:
        st.text_input("Your sentence so far:", value=" ".join(st.session_state.clicked_texts), key="text_bar")
    with col2:
        # on_click callbacks run before the rerun renders, so the cleared
        # sentence shows immediately (clearing after render left stale text).
        st.button("🗑️ Clear", use_container_width=True, on_click=_clear_sentence)

    # Display AAC options
    items = get_items(st.session_state.clicked_texts)[:12]  # show only top 12
    if st.button("Generate Audio of Text", type="primary", use_container_width=True):
        text = st.session_state.clicked_texts
        if text:
            input_text = texttospeech.SynthesisInput(text=" ".join(text))
            response = client_speech.synthesize_speech(
                request={"input": input_text, "voice": voice, "audio_config": audio_config}
            )
            # LINEAR16 output is WAV, so the data URI and format must say wav.
            st.audio(
                f"data:audio/wav;base64,{base64.b64encode(response.audio_content).decode()}", format="audio/wav"
            )
    create_button_grid(items, columns=4)

if __name__ == "__main__":
    main()