├── README.md
├── requirements.txt
├── multipage.py
├── prophet_script2.py
├── nbeats2.py
├── image_bot2.py
├── data_bot2.py
└── main.py
/README.md:
--------------------------------------------------------------------------------
# forecast-demo

A Streamlit demand-forecasting demo: upload a sales CSV, generate forecasts with Prophet or a pre-trained N-BEATS model, and chat about forecast plots or data files with Gemini.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
streamlit==1.15.2           # note: st.chat_message/st.chat_input in the bot pages may need a newer Streamlit release
pandas==1.3.3
numpy==1.21.2
matplotlib==3.4.3
prophet==1.0.1
holidays==0.11.3.1
tensorflow==2.6.0
Pillow==8.3.2
google-generativeai==0.2.0  # note: genai.GenerativeModel("gemini-1.5-flash") may need a newer release
python-dotenv==0.19.1
streamlit-option-menu       # used by multipage.py
--------------------------------------------------------------------------------
/multipage.py:
--------------------------------------------------------------------------------
import os

import streamlit as st
import google.generativeai as genai
from dotenv import load_dotenv
from streamlit_option_menu import option_menu

# Load the Gemini API key from a local .env file
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))


st.set_page_config(
    page_title="Demand Forecasting App",
    page_icon="📊"
)

import main
import image_bot2
import data_bot2


class MultiApp:

    def __init__(self):
        self.apps = []

    def add_app(self, title, func):
        self.apps.append({
            "title": title,
            "function": func
        })


def run():
    if 'selected_index' not in st.session_state:
        st.session_state.selected_index = 0

    selected = option_menu(
        menu_title='',
        options=['Generate Forecasts', 'Chat with Image', 'Chat with Data'],
        icons=['cloud-arrow-up', 'graph-up-arrow', 'database-check'],
        menu_icon='chat-text-fill',
        default_index=st.session_state.selected_index,
        orientation="horizontal",
        styles={
            "container": {"padding": "0!important", "background-color": "white"},
            "icon": {"color": "black", "font-size": "default"},
            "nav-link": {"color": "black", "font-size": "default", "text-align": "left", "margin": "0px", "--hover-color": "#e8f5e9"},
            "nav-link-selected": {"background-color": "#02ab21", "color": "white"},
        }
    )

    st.session_state.selected_index = ['Generate Forecasts', 'Chat with Image', 'Chat with Data'].index(selected)

    if selected == "Generate Forecasts":
        main.app()
    elif selected == "Chat with Image":
        image_bot2.app()
    elif selected == "Chat with Data":
        data_bot2.app()


run()
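
# How to run locally (a minimal sketch, assuming the pinned dependencies are
# installed, the N-BEATS model path in main.py points to a trained
# "nbeats.keras" file, and a .env file providing GOOGLE_API_KEY exists):
#
#   pip install -r requirements.txt
#   streamlit run multipage.py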
--------------------------------------------------------------------------------
/prophet_script2.py:
--------------------------------------------------------------------------------
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
from prophet import Prophet
from datetime import date


def plot_time_series(timesteps, values, format='-', start=0, end=None, label=None):
    """Plot a slice of a time series and render it in Streamlit."""
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.plot(timesteps[start:end], values[start:end], format, label=label)
    ax.set_xlabel("Timeline")
    ax.set_ylabel("Forecasted Values of Sales")
    if label:
        ax.legend(fontsize=10)
    ax.grid(True)
    st.pyplot(fig)
    return fig


# Functions specific to the Prophet model
def read_process(file):
    """Read the uploaded CSV and return a DataFrame with Prophet's ds/y columns."""
    df = pd.read_csv(file)

    # Try to parse the 'Date' column with the supported formats
    for date_format in ("%d-%m-%Y", "%m/%d/%Y"):
        try:
            df['Date'] = pd.to_datetime(df['Date'], format=date_format)
            break
        except (ValueError, TypeError):
            continue

    # If none of the formats matched, fail with a clear error
    if not pd.api.types.is_datetime64_any_dtype(df['Date']):
        raise ValueError("Date parsing failed for all supported formats")

    # Normalise the 'Date' column to a single string format
    df['Date'] = df['Date'].dt.strftime("%m/%d/%Y")

    # Build the DataFrame Prophet expects: ds (dates) and y (values)
    data = pd.DataFrame()
    data["ds"] = df["Date"]
    data["y"] = df["Sales"]

    return data


def evaluate(df, end_date):
    """Return the number of periods to forecast and the data's spacing in days."""
    df['ds'] = pd.to_datetime(df['ds'])
    start_date = df["ds"].iloc[-1]
    d_1 = df["ds"].iloc[0]
    d_2 = df["ds"].iloc[1]
    if isinstance(end_date, date):
        end_date = pd.Timestamp(end_date)
    diff_dates = (d_2 - d_1).days

    if diff_dates == 1:
        days_selected = (end_date - start_date).days
        st.write(f"Number of Days selected: {days_selected}")
        return days_selected, diff_dates

    if diff_dates == 7:
        weeks_selected = (end_date - start_date).days // 7
        st.write(f"Number of Weeks selected: {weeks_selected}")
        return weeks_selected, diff_dates

    raise ValueError("Unsupported data frequency: expected daily or weekly observations")


def forecast(model, df, timesteps, f):
    """Fit Prophet and plot/return the forecast beyond the last observed date."""
    model.fit(df)
    if f == 1:
        future_df = model.make_future_dataframe(periods=timesteps, freq='D')
    elif f == 7:
        future_df = model.make_future_dataframe(periods=timesteps, freq='W')
    future_forecasts = model.predict(future_df)
    dates = future_forecasts["ds"]
    preds = future_forecasts["yhat"]
    last_idx = df.index[-1]
    fig = plot_time_series(timesteps=dates[last_idx:], values=preds[last_idx:])
    return future_forecasts, last_idx, fig
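
# Minimal usage sketch (not part of the Streamlit app). "sales.csv" and the end
# date are placeholders; the CSV is assumed to have 'Date' and 'Sales' columns,
# which is what read_process() expects. Plotting outside `streamlit run` only
# prints a bare-mode warning.
if __name__ == "__main__":
    history = read_process("sales.csv")
    periods, freq = evaluate(history, date(2019, 7, 6))
    forecasts, last_idx, _ = forecast(Prophet(), history, periods, freq)
    print(forecasts[["ds", "yhat"]].tail())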
--------------------------------------------------------------------------------
/nbeats2.py:
--------------------------------------------------------------------------------
# nbeats2.py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import tensorflow as tf


def plot_time_series(timesteps, values, format='-', start=0, end=None, label=None):
    """Plot a slice of a time series and render it in Streamlit."""
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.plot(timesteps[start:end], values[start:end], format, label=label)
    ax.set_xlabel("Timeline")
    ax.set_ylabel("Forecasted Values of Sales")
    if label:
        ax.legend(fontsize=10)
    ax.grid(True)
    st.pyplot(fig)
    return fig


WINDOW_SIZE = 7   # number of past observations fed to the model
HORIZON = 1       # number of steps predicted per forward pass


class NBeatsBlock(tf.keras.layers.Layer):
    def __init__(self, input_size: int, theta_size: int, horizon: int,
                 n_neurons: int, n_layers: int, **kwargs):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.theta_size = theta_size
        self.horizon = horizon
        self.n_neurons = n_neurons
        self.n_layers = n_layers
        self.hidden = [tf.keras.layers.Dense(n_neurons, activation="relu") for _ in range(n_layers)]
        self.theta_layer = tf.keras.layers.Dense(theta_size, activation='linear', name='theta')

    def call(self, inputs):
        x = inputs
        for layer in self.hidden:
            x = layer(x)
        theta = self.theta_layer(x)
        # Split theta into the backcast (reconstruction of the input window)
        # and the forecast (the next `horizon` values).
        backcast, forecast = theta[:, :self.input_size], theta[:, -self.horizon:]
        return backcast, forecast

    def get_config(self):
        # Makes the layer serializable so a saved .keras model can be reloaded
        # with custom_objects (as main.py does).
        config = super().get_config()
        config.update({
            "input_size": self.input_size,
            "theta_size": self.theta_size,
            "horizon": self.horizon,
            "n_neurons": self.n_neurons,
            "n_layers": self.n_layers,
        })
        return config


def read_and_process_nbeats(df):
    """Turn the raw Date/Sales frame into windowed features (X_all) and targets (y_all)."""
    data = pd.DataFrame()
    data["ds"] = df["Date"]
    data["y"] = df["Sales"]
    data = data.set_index("ds")
    data_nbeats = data.copy()
    for i in range(WINDOW_SIZE):
        data_nbeats[f"y + {i+1}"] = data_nbeats["y"].shift(periods=i+1)
    X_all = data_nbeats.dropna().drop("y", axis=1)
    y_all = data_nbeats.dropna()["y"]
    return X_all, y_all


def make_forecast_dates(df, end_date):
    """Build the future date index, excluding the last observed date itself."""
    start_date = df.iloc[-1]["Date"]
    dates_to_be_forecasted = pd.date_range(start=start_date, end=end_date)
    dates_to_be_forecasted = dates_to_be_forecasted[1:]
    st.write(f"Number of Timesteps selected: {len(dates_to_be_forecasted)}")
    return dates_to_be_forecasted, len(dates_to_be_forecasted)


def make_future_forecast(values, model, into_future, window_size=WINDOW_SIZE) -> list:
    """Autoregressively roll the last window forward to predict `into_future` steps."""
    future_forecast = []
    last_window = np.asarray(values[-window_size:])
    for _ in range(into_future):
        future_pred = model.predict(tf.expand_dims(last_window, axis=0))
        print(f"Predicting on:\n {last_window} -> Prediction: {tf.squeeze(future_pred).numpy()}\n")
        future_forecast.append(tf.squeeze(future_pred).numpy())
        # Append the new prediction and keep only the most recent window
        last_window = np.append(last_window, future_pred)[-window_size:]
    return future_forecast
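
# The app loads a pre-trained "nbeats.keras" in main.py rather than training a
# model here. The sketch below shows one way such a model could be assembled
# from NBeatsBlock; the number of blocks, neurons and layers are illustrative
# assumptions, not the settings of the shipped model.
def build_nbeats_model(n_blocks: int = 4, n_neurons: int = 128, n_layers: int = 4) -> tf.keras.Model:
    theta_size = WINDOW_SIZE + HORIZON
    inputs = tf.keras.Input(shape=(WINDOW_SIZE,), name="stack_input")

    # The first block produces an initial backcast/forecast pair.
    backcast, forecast = NBeatsBlock(WINDOW_SIZE, theta_size, HORIZON, n_neurons, n_layers)(inputs)
    residuals = tf.keras.layers.subtract([inputs, backcast])

    # Each following block works on the residual of the previous backcast and
    # its forecast is added to the running total (double residual stacking).
    for _ in range(n_blocks - 1):
        block_backcast, block_forecast = NBeatsBlock(WINDOW_SIZE, theta_size, HORIZON, n_neurons, n_layers)(residuals)
        residuals = tf.keras.layers.subtract([residuals, block_backcast])
        forecast = tf.keras.layers.add([forecast, block_forecast])

    model = tf.keras.Model(inputs=inputs, outputs=forecast, name="nbeats")
    model.compile(loss="mae", optimizer=tf.keras.optimizers.Adam())
    return model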
--------------------------------------------------------------------------------
/image_bot2.py:
--------------------------------------------------------------------------------
import os

import google.generativeai as genai
import streamlit as st
from PIL import Image
from dotenv import load_dotenv

# Load the Gemini API key from a local .env file
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))


if "conversation_history_image" not in st.session_state:
    st.session_state.conversation_history_image = []


def add_custom_css_image_bot():
    # Inject the CSS for the .user-message and .bot-message chat bubbles.
    st.markdown("""
    """, unsafe_allow_html=True)


def app():
    add_custom_css_image_bot()
    st.title("Image Description and Context Generation")

    # Load and display the image
    uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image.", use_column_width=True)

        # Add a chat message with the suggested starting prompt
        st.chat_message("📈").write("Analyze the trends in this graph.")

        user_prompt = st.chat_input("Enter your prompt here:")

        if user_prompt:
            # Hardcoded default prompt for forecasting graphs and trend analysis
            default_prompt = """
            You are an expert in analyzing forecasting graphs for trend analysis.
            You will receive graphs as input images and must answer questions based on the observed trends, first briefly and then in more detail.
            """

            # Combine the default prompt with the user-provided prompt
            combined_prompt = f"{default_prompt}\n{user_prompt}"

            # Pass the combined prompt and the PIL image to the model
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(
                [combined_prompt, image],
                generation_config=genai.types.GenerationConfig(temperature=1.0),
                stream=True)
            response.resolve()

            st.session_state.conversation_history_image.append(("👦🏻", user_prompt, "user-message"))
            st.session_state.conversation_history_image.append(("🤖", response.text, "bot-message"))

        # Display the conversation history
        for speaker, message, css_class in st.session_state.conversation_history_image:
            st.markdown(f'<div class="{css_class}">{speaker}: {message}</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    app()
--------------------------------------------------------------------------------
/data_bot2.py:
--------------------------------------------------------------------------------
import os

import google.generativeai as genai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv

# Load the Gemini API key from a local .env file
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

if "conversation_history_data" not in st.session_state:
    st.session_state.conversation_history_data = []


def add_custom_css_data_bot():
    # Inject the CSS for the .user-message and .bot-message chat bubbles.
    st.markdown("""
    """, unsafe_allow_html=True)


def app():
    add_custom_css_data_bot()
    st.title("Data Analysis and Context Generation")

    # Upload and display the data file
    uploaded_file = st.file_uploader("Choose a data file...", type=["csv", "xlsx"])
    if uploaded_file is not None:
        # Read the data file
        if uploaded_file.type == "text/csv":
            data = pd.read_csv(uploaded_file)
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            data = pd.read_excel(uploaded_file)

        st.write(data)

        # Add a chat message with the suggested starting prompt
        st.chat_message("📊").write("Analyze the trends in this dataset.")

        user_prompt = st.chat_input("Enter your prompt here:")

        if user_prompt:
            # Hardcoded default prompt for analyzing forecast datasets
            default_prompt = """
            The model will receive a CSV file containing columns of dates and forecast values of sales or prices.
            The model has to understand the data thoroughly and generate responses.
            """

            # Combine the default prompt with the user-provided prompt
            combined_prompt = f"{default_prompt}\n{user_prompt}"

            # Convert the data to text for the model (could be JSON, CSV, etc.)
            data_text = data.to_string()

            # Pass the combined prompt and data text to the model
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(
                [combined_prompt, data_text],
                generation_config=genai.types.GenerationConfig(
                    top_p=0.6,
                    top_k=5,
                    temperature=0.8),
                stream=True)
            response.resolve()

            st.session_state.conversation_history_data.append(("👦🏻", user_prompt, "user-message"))
            st.session_state.conversation_history_data.append(("🤖", response.text, "bot-message"))

        # Display the conversation history
        for speaker, message, css_class in st.session_state.conversation_history_data:
            st.markdown(f'<div class="{css_class}">{speaker}: {message}</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    app()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import time
from datetime import datetime
from io import BytesIO

import pandas as pd
import streamlit as st
import tensorflow as tf
from dotenv import load_dotenv
from prophet import Prophet

from prophet_script2 import read_process, evaluate, forecast
from nbeats2 import read_and_process_nbeats, make_forecast_dates, make_future_forecast, NBeatsBlock, plot_time_series, WINDOW_SIZE


load_dotenv()

# Path to the trained N-BEATS model; adjust this to wherever the model file lives.
NBEATS_MODEL_PATH = "C:/Users/Siddharth/Desktop/woodpeckers/nbeats.keras"


# Helper function to save a Matplotlib figure to bytes
def save_fig_to_bytes(fig):
    img_bytes = BytesIO()
    fig.savefig(img_bytes, format='png')
    img_bytes.seek(0)
    return img_bytes


# Functions to generate forecasts and downloadable files
def generate_prophet_files(uploaded_file, end_date):
    model = Prophet()
    df = read_process(uploaded_file)
    timesteps, freq = evaluate(df, end_date)
    fut, last_idx, fig = forecast(model, df, timesteps, freq)
    download_file = pd.DataFrame()
    download_file["Date"] = fut["ds"][last_idx:]
    download_file["Predictions"] = fut["yhat"][last_idx:]
    download_file = download_file.reset_index(drop=True)
    csv = download_file.to_csv(index=False).encode('utf-8')
    img_bytes = save_fig_to_bytes(fig)
    return csv, img_bytes


def generate_nbeats_files(uploaded_file, end_date):
    nbeats_model = tf.keras.models.load_model(NBEATS_MODEL_PATH, custom_objects={'NBeatsBlock': NBeatsBlock})
    df = pd.read_csv(uploaded_file, parse_dates=["Date"])
    df['Date'] = df['Date'].dt.strftime("%m/%d/%Y")
    _, y_all = read_and_process_nbeats(df)  # only the target series is needed for rolling forecasts
    forecast_dates, n_steps = make_forecast_dates(df, end_date)
    preds = make_future_forecast(y_all, nbeats_model, n_steps, WINDOW_SIZE)
    forecast_df = pd.DataFrame()
    forecast_df["Date"] = forecast_dates
    forecast_df["Predictions"] = preds
    forecast_df = forecast_df.reset_index(drop=True)
    fig = plot_time_series(timesteps=forecast_df["Date"], values=forecast_df["Predictions"])
    csv = forecast_df.to_csv(index=False).encode('utf-8')
    img_bytes = save_fig_to_bytes(fig)
    return csv, img_bytes


# Main application function
def app():
    st.title("Generate Forecasts")
    st.header("File Upload")

    uploaded_file = st.file_uploader("Upload a File", type="csv")
    end_date = st.date_input("Enter Last Date to be Forecasted", datetime(2019, 7, 6))

    model_selection = st.selectbox("Select Model to generate Forecasts", ["Prophet", "N-Beats"], index=None, placeholder="Models..")

    # Unique keys with timestamps to ensure no conflicts
    timestamp = int(time.time())
    csv_key = f'csv_download_button_{model_selection}_{timestamp}'
    img_key = f'img_download_button_{model_selection}_{timestamp}'

    # Generate forecasts and files only if the button is clicked
    if st.button("Generate Forecasts?"):
        if uploaded_file is not None:
            if model_selection == "Prophet":
                st.session_state.csv_data, st.session_state.img_data = generate_prophet_files(uploaded_file, end_date)
                st.success("Forecasts Generated")
            elif model_selection == "N-Beats":
                st.session_state.csv_data, st.session_state.img_data = generate_nbeats_files(uploaded_file, end_date)
                st.success("Forecasts Generated")

    # Display download buttons if files are generated
    if 'csv_data' in st.session_state and 'img_data' in st.session_state:
        st.download_button(
            label="Download Forecasts as CSV",
            data=st.session_state.csv_data,
            file_name='forecasts.csv',
            mime='text/csv',
            key=csv_key
        )
        st.download_button(
            label="Download Forecast as Image",
            data=st.session_state.img_data,
            file_name='forecast.png',
            mime='image/png',
            key=img_key
        )


if __name__ == "__main__":
    app()
--------------------------------------------------------------------------------