├── README.md
├── app.py
├── groqcloud_darkmode.png
└── requirements.txt


/README.md:
--------------------------------------------------------------------------------
# Stock Market Analysis with Llama 3 Function Calling

Welcome to the Stock Market Analyst! This is a Streamlit web application that leverages the yfinance API to provide insights into stocks and their prices. The application uses the Llama 3 model on Groq in conjunction with LangChain to call functions based on the user prompt.

## Key Functions

- **get_stock_info(symbol, key)**: This function fetches various information about a given stock symbol. The information can be anything from the company's address to its financial ratios. The 'key' parameter specifies the type of information to fetch.

- **get_historical_price(symbol, start_date, end_date)**: This function fetches the historical stock prices for a given symbol from a specified start date to an end date. The returned data is a DataFrame with the date and closing price of the stock.

- **plot_price_over_time(historical_price_dfs)**: This function takes a list of DataFrames (each containing historical price data for a stock) and plots the prices over time using Plotly. The plot is displayed in the Streamlit app.

- **call_functions(llm_with_tools, user_prompt)**: This function takes the user's question, invokes the appropriate tool (either get_stock_info or get_historical_price), and generates a response. If the user asked for historical prices, it also calls plot_price_over_time to generate a plot.

## Function Calling

The function calling in this application is handled by the Groq API, abstracted with LangChain. When the user asks a question, the application invokes the appropriate tool with parameters based on the user's question. The tool's output is then used to generate a response.

## Usage

1. Clone the repository to your local machine.

2. Install the required dependencies listed in the **requirements.txt** file.

3. Run the application using Streamlit with the command `streamlit run app.py`.

4. In the application, enter your question about a stock in the text input field. For example, "What is the current price of Google stock?" or "Show me the historical prices of Amazon and Tesla over the past year."

5. If you want to provide additional context for the language model, you can do so in the sidebar.

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
from langchain_groq import ChatGroq
import os
import yfinance as yf
import pandas as pd

from langchain_core.tools import tool
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage

from datetime import date
import plotly.graph_objects as go

import streamlit as st

@tool
def get_stock_info(symbol, key):
    '''Return the correct stock info value given the appropriate symbol and key.
    Infer valid key from the user prompt; it must be one of the following:

    address1, city, state, zip, country, phone, website, industry, industryKey, industryDisp, sector, sectorKey, sectorDisp, longBusinessSummary, fullTimeEmployees, companyOfficers, auditRisk, boardRisk, compensationRisk, shareHolderRightsRisk, overallRisk, governanceEpochDate, compensationAsOfEpochDate, maxAge, priceHint, previousClose, open, dayLow, dayHigh, regularMarketPreviousClose, regularMarketOpen, regularMarketDayLow, regularMarketDayHigh, dividendRate, dividendYield, exDividendDate, beta, trailingPE, forwardPE, volume, regularMarketVolume, averageVolume, averageVolume10days, averageDailyVolume10Day, bid, ask, bidSize, askSize, marketCap, fiftyTwoWeekLow, fiftyTwoWeekHigh, priceToSalesTrailing12Months, fiftyDayAverage, twoHundredDayAverage, currency, enterpriseValue, profitMargins, floatShares, sharesOutstanding, sharesShort, sharesShortPriorMonth, sharesShortPreviousMonthDate, dateShortInterest, sharesPercentSharesOut, heldPercentInsiders, heldPercentInstitutions, shortRatio, shortPercentOfFloat, impliedSharesOutstanding, bookValue, priceToBook, lastFiscalYearEnd, nextFiscalYearEnd, mostRecentQuarter, earningsQuarterlyGrowth, netIncomeToCommon, trailingEps, forwardEps, pegRatio, enterpriseToRevenue, enterpriseToEbitda, 52WeekChange, SandP52WeekChange, lastDividendValue, lastDividendDate, exchange, quoteType, symbol, underlyingSymbol, shortName, longName, firstTradeDateEpochUtc, timeZoneFullName, timeZoneShortName, uuid, messageBoardId, gmtOffSetMilliseconds, currentPrice, targetHighPrice, targetLowPrice, targetMeanPrice, targetMedianPrice, recommendationMean, recommendationKey, numberOfAnalystOpinions, totalCash, totalCashPerShare, ebitda, totalDebt, quickRatio, currentRatio, totalRevenue, debtToEquity, revenuePerShare, returnOnAssets, returnOnEquity, freeCashflow, operatingCashflow, earningsGrowth, revenueGrowth, grossMargins, ebitdaMargins, operatingMargins, financialCurrency,
    trailingPegRatio

    If asked generically for 'stock price', use currentPrice
    '''
    data = yf.Ticker(symbol)
    stock_info = data.info
    return stock_info[key]


@tool
def get_historical_price(symbol, start_date, end_date):
    """
    Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'.
    - symbol (str): Stock ticker symbol.
    - end_date (date): Typically today unless a specific end date is provided. End date MUST be greater than start date
    - start_date (date): Set explicitly, or calculated as 'end_date - date interval' (for example, if prompted 'over the past 6 months', date interval = 6 months so start_date would be 6 months earlier than today's date). Default to '1900-01-01' if vaguely asked for historical price. Start date must always be before the current date
    """

    data = yf.Ticker(symbol)
    hist = data.history(start=start_date, end=end_date)
    hist = hist.reset_index()
    hist[symbol] = hist['Close']
    return hist[['Date', symbol]]

def plot_price_over_time(historical_price_dfs):

    full_df = pd.DataFrame(columns = ['Date'])
    for df in historical_price_dfs:
        full_df = full_df.merge(df, on = 'Date', how = 'outer')

    # Create a Plotly figure
    fig = go.Figure()

    # Dynamically add a trace for each stock symbol in the DataFrame
    for column in full_df.columns[1:]:  # Skip the first column since it's the date
        fig.add_trace(go.Scatter(x=full_df['Date'], y=full_df[column], mode='lines+markers', name=column))


    # Update the layout to add titles and format axis labels
    fig.update_layout(
        title='Stock Price Over Time: ' + ', '.join(full_df.columns.tolist()[1:]),
        xaxis_title='Date',
        yaxis_title='Stock Price (USD)',
        yaxis_tickprefix='$',
        yaxis_tickformat=',.2f',
        xaxis=dict(
            tickangle=-45,
            nticks=20,
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            showgrid=True,  # Enable y-axis grid lines
            gridcolor='lightgrey',  # Set grid line color
        ),
        legend_title_text='Stock Symbol',
        plot_bgcolor='gray',  # Set plot background to gray
        paper_bgcolor='gray',  # Set overall figure background to gray
        legend=dict(
            bgcolor='gray',  # Optional: Set legend background to gray
            bordercolor='black'
        )
    )

    # Show the figure
    st.plotly_chart(fig, use_container_width=True)

def call_functions(llm_with_tools, user_prompt):
    system_prompt = 'You are a helpful finance assistant that analyzes stocks and stock prices. Today is {today}'.format(today = date.today())

    messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]
    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)
    historical_price_dfs = []
    symbols = []
    for tool_call in ai_msg.tool_calls:
        # Normalize the tool name once so the lookup and the branch below agree
        tool_name = tool_call["name"].lower()
        selected_tool = {"get_stock_info": get_stock_info, "get_historical_price": get_historical_price}[tool_name]
        tool_output = selected_tool.invoke(tool_call["args"])
        if tool_name == 'get_historical_price':
            historical_price_dfs.append(tool_output)
            symbols.append(tool_output.columns[1])
        else:
            # Convert the value to text before wrapping it in a ToolMessage
            messages.append(ToolMessage(str(tool_output), tool_call_id=tool_call["id"]))

    if len(historical_price_dfs) > 0:
        plot_price_over_time(historical_price_dfs)

        symbols = ' and '.join(symbols)
        messages.append(ToolMessage('Tell the user that a historical stock price chart for {symbols} has been generated.'.format(symbols=symbols), tool_call_id=0))

    return llm_with_tools.invoke(messages).content



def main():

    llm = ChatGroq(groq_api_key = os.getenv('GROQ_API_KEY'),model = 'llama3-70b-8192')

    tools = [get_stock_info, get_historical_price]
    llm_with_tools = llm.bind_tools(tools)

    # Display the Groq logo
    spacer, col = st.columns([5, 1])
    with col:
        st.image('groqcloud_darkmode.png')

    # Display the title and introduction of the application
    st.title("Groqing the Stock Market with Llama 3")
    multiline_text = """
    Try asking "What is the current price of Meta stock?" or "Show me the historical prices of Apple vs Microsoft stock over the past 6 months."
    """

    st.markdown(multiline_text, unsafe_allow_html=True)

    # Add customization options to the sidebar
    st.sidebar.title('Customization')
    # NOTE: additional_context is read here but is not passed to call_functions below
    additional_context = st.sidebar.text_input('Enter additional summarization context for the LLM here (e.g., write it in Spanish):')

    # Get the user's question
    user_question = st.text_input("Ask a question about a stock or multiple stocks:")

    if user_question:
        response = call_functions(llm_with_tools, user_question)
        st.write(response)



if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/groqcloud_darkmode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/definitive-io/llama3-function-calling/e28b487a1194cc9320922ba9c0bd74d4b0fd2fa0/groqcloud_darkmode.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
streamlit
pandas
numpy
groq
langchain_community
langchain_groq
yfinance
plotly
langchain_core

--------------------------------------------------------------------------------
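The tool-dispatch loop in `call_functions` (app.py) can be tried in isolation, without a Groq API key or network access. The following is a minimal sketch under stated assumptions, not the application's actual code: the two tool functions are hypothetical stubs that return canned values, `dispatch` is an illustrative name that does not exist in the repository, and the tool calls are hand-written dicts shaped like the `name`/`args`/`id` entries the app reads from `ai_msg.tool_calls`.

```python
# Hypothetical stand-ins for the @tool functions in app.py. They return canned
# values so the sketch runs without yfinance, LangChain, or a Groq API key.
def get_stock_info(symbol, key):
    return {"currentPrice": 123.45}.get(key)


def get_historical_price(symbol, start_date, end_date):
    return [(start_date, 100.0), (end_date, 110.0)]


TOOLS = {
    "get_stock_info": get_stock_info,
    "get_historical_price": get_historical_price,
}


def dispatch(tool_calls):
    """Run each tool call and split the results the way call_functions does."""
    tool_messages = []  # (tool_call_id, text) pairs that would go back to the model
    price_series = {}   # symbol -> historical rows that would go to the chart
    for call in tool_calls:
        name = call["name"].lower()  # normalize once, reuse for lookup and branch
        output = TOOLS[name](**call["args"])
        if name == "get_historical_price":
            price_series[call["args"]["symbol"]] = output
        else:
            tool_messages.append((call["id"], str(output)))
    return tool_messages, price_series


# Hand-written tool calls; the mixed-case name checks the normalization step.
calls = [
    {"name": "get_stock_info",
     "args": {"symbol": "META", "key": "currentPrice"},
     "id": "call_1"},
    {"name": "Get_Historical_Price",
     "args": {"symbol": "AAPL", "start_date": "2024-01-01", "end_date": "2024-06-30"},
     "id": "call_2"},
]
messages, series = dispatch(calls)
print(messages)
print(sorted(series))
```

Stock-info results become text messages for the model's second pass, while historical prices are set aside for plotting, which mirrors the two branches in the app. Lowercasing the name once keeps the dictionary lookup and the branch condition consistent when the model returns a differently cased tool name.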