├── README.md ├── app.py ├── requirements.txt └── session_state.py /README.md: -------------------------------------------------------------------------------- 1 | # streamlit-google-oauth 2 | An example Streamlit application that incorporates Google OAuth 2.0 3 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | import asyncio 4 | 5 | from session_state import get 6 | from httpx_oauth.clients.google import GoogleOAuth2 7 | 8 | 9 | async def write_authorization_url(client, 10 | redirect_uri): 11 | authorization_url = await client.get_authorization_url( 12 | redirect_uri, 13 | scope=["profile", "email"], 14 | extras_params={"access_type": "offline"}, 15 | ) 16 | return authorization_url 17 | 18 | 19 | async def write_access_token(client, 20 | redirect_uri, 21 | code): 22 | token = await client.get_access_token(code, redirect_uri) 23 | return token 24 | 25 | 26 | async def get_email(client, 27 | token): 28 | user_id, user_email = await client.get_id_email(token) 29 | return user_id, user_email 30 | 31 | 32 | def main(user_id, user_email): 33 | st.write(f"You're logged in as {user_email}") 34 | 35 | 36 | if __name__ == '__main__': 37 | client_id = os.environ['GOOGLE_CLIENT_ID'] 38 | client_secret = os.environ['GOOGLE_CLIENT_SECRET'] 39 | redirect_uri = os.environ['REDIRECT_URI'] 40 | 41 | client = GoogleOAuth2(client_id, client_secret) 42 | authorization_url = asyncio.run( 43 | write_authorization_url(client=client, 44 | redirect_uri=redirect_uri) 45 | ) 46 | 47 | session_state = get(token=None) 48 | if session_state.token is None: 49 | try: 50 | code = st.experimental_get_query_params()['code'] 51 | except: 52 | st.write(f'''

53 | Please login using this url

''', 55 | unsafe_allow_html=True) 56 | else: 57 | # Verify token is correct: 58 | try: 59 | token = asyncio.run( 60 | write_access_token(client=client, 61 | redirect_uri=redirect_uri, 62 | code=code)) 63 | except: 64 | st.write(f'''

65 | This account is not allowed or page was refreshed. 66 | Please try again: url

''', 68 | unsafe_allow_html=True) 69 | else: 70 | # Check if token has expired: 71 | if token.is_expired(): 72 | if token.is_expired(): 73 | st.write(f'''

74 | Login session has ended, 75 | please 76 | login again.

77 | ''') 78 | else: 79 | session_state.token = token 80 | user_id, user_email = asyncio.run( 81 | get_email(client=client, 82 | token=token['access_token']) 83 | ) 84 | session_state.user_id = user_id 85 | session_state.user_email = user_email 86 | main(user_id=session_state.user_id, 87 | user_email=session_state.user_email) 88 | else: 89 | main(user_id=session_state.user_id, 90 | user_email=session_state.user_email) 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==0.81.1 2 | httpx-oauth==0.3.5 3 | -------------------------------------------------------------------------------- /session_state.py: -------------------------------------------------------------------------------- 1 | """Hack to add per-session state to Streamlit. 2 | Usage 3 | ----- 4 | >>> import SessionState 5 | >>> 6 | >>> session_state = SessionState.get(user_name='', favorite_color='black') 7 | >>> session_state.user_name 8 | '' 9 | >>> session_state.user_name = 'Mary' 10 | >>> session_state.favorite_color 11 | 'black' 12 | Since you set user_name above, next time your script runs this will be the 13 | result: 14 | >>> session_state = get(user_name='', favorite_color='black') 15 | >>> session_state.user_name 16 | 'Mary' 17 | """ 18 | 19 | try: 20 | import streamlit.ReportThread as ReportThread 21 | from streamlit.server.Server import Server 22 | except Exception: 23 | # Streamlit >= 0.65.0 24 | import streamlit.report_thread as ReportThread 25 | from streamlit.server.server import Server 26 | 27 | 28 | class SessionState(object): 29 | def __init__(self, **kwargs): 30 | """A new SessionState object. 31 | Parameters 32 | ---------- 33 | **kwargs : any 34 | Default values for the session state. 35 | Example 36 | ------- 37 | >>> session_state = SessionState(user_name='', favorite_color='black') 38 | >>> session_state.user_name = 'Mary' 39 | '' 40 | >>> session_state.favorite_color 41 | 'black' 42 | """ 43 | for key, val in kwargs.items(): 44 | setattr(self, key, val) 45 | 46 | 47 | def get(**kwargs): 48 | """Gets a SessionState object for the current session. 49 | Creates a new object if necessary. 50 | Parameters 51 | ---------- 52 | **kwargs : any 53 | Default values you want to add to the session state, if we're creating a 54 | new one. 55 | Example 56 | ------- 57 | >>> session_state = get(user_name='', favorite_color='black') 58 | >>> session_state.user_name 59 | '' 60 | >>> session_state.user_name = 'Mary' 61 | >>> session_state.favorite_color 62 | 'black' 63 | Since you set user_name above, next time your script runs this will be the 64 | result: 65 | >>> session_state = get(user_name='', favorite_color='black') 66 | >>> session_state.user_name 67 | 'Mary' 68 | """ 69 | # Hack to get the session object from Streamlit. 70 | 71 | ctx = ReportThread.get_report_ctx() 72 | 73 | this_session = None 74 | 75 | current_server = Server.get_current() 76 | if hasattr(current_server, '_session_infos'): 77 | # Streamlit < 0.56 78 | session_infos = Server.get_current()._session_infos.values() 79 | else: 80 | session_infos = Server.get_current()._session_info_by_id.values() 81 | 82 | for session_info in session_infos: 83 | s = session_info.session 84 | if ( 85 | # Streamlit < 0.54.0 86 | (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg) 87 | or 88 | # Streamlit >= 0.54.0 89 | (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue) 90 | or 91 | # Streamlit >= 0.65.2 92 | (not hasattr(s, 93 | '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr) 94 | ): 95 | this_session = s 96 | 97 | if this_session is None: 98 | raise RuntimeError( 99 | "Oh noes. Couldn't get your Streamlit Session object. " 100 | 'Are you doing something fancy with threads?') 101 | 102 | # Got the session object! Now let's attach some state into it. 103 | 104 | if not hasattr(this_session, '_custom_session_state'): 105 | this_session._custom_session_state = SessionState(**kwargs) 106 | 107 | return this_session._custom_session_state 108 | --------------------------------------------------------------------------------