├── README.md
├── app.py
├── requirements.txt
└── session_state.py
/README.md:
--------------------------------------------------------------------------------
1 | # streamlit-google-oauth
2 | An example Streamlit application that incorporates Google OAuth 2.0
3 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import os
3 | import asyncio
4 |
5 | from session_state import get
6 | from httpx_oauth.clients.google import GoogleOAuth2
7 |
8 |
async def write_authorization_url(client,
                                  redirect_uri):
    """Ask the OAuth client for the URL that starts the sign-in flow.

    Parameters
    ----------
    client : object
        OAuth client exposing an awaitable ``get_authorization_url``.
    redirect_uri : str
        Address the provider should send the user back to.

    Returns
    -------
    Whatever ``client.get_authorization_url`` resolves to.
    """
    # Request the basic profile and e-mail scopes, with offline access.
    requested_scopes = ["profile", "email"]
    extra_query = {"access_type": "offline"}
    return await client.get_authorization_url(
        redirect_uri,
        scope=requested_scopes,
        extras_params=extra_query,
    )
17 |
18 |
async def write_access_token(client,
                             redirect_uri,
                             code):
    """Exchange an authorization code for an access token.

    Parameters
    ----------
    client : object
        OAuth client exposing an awaitable ``get_access_token``.
    redirect_uri : str
        Redirect address forwarded to the client together with the code.
    code : any
        Authorization code received on the redirect.

    Returns
    -------
    Whatever ``client.get_access_token(code, redirect_uri)`` resolves to.
    """
    return await client.get_access_token(code, redirect_uri)
24 |
25 |
async def get_email(client,
                    token):
    """Look up the id and e-mail address that belong to an access token.

    Parameters
    ----------
    client : object
        OAuth client exposing an awaitable ``get_id_email``.
    token : any
        Value handed straight to ``client.get_id_email``.

    Returns
    -------
    tuple
        ``(user_id, user_email)`` as reported by the client.
    """
    identity = await client.get_id_email(token)
    # Unpack explicitly so anything other than a two-item result fails here.
    user_id, user_email = identity
    return (user_id, user_email)
30 |
31 |
def main(user_id, user_email):
    """Render the page shown to a signed-in user.

    Parameters
    ----------
    user_id : any
        Identifier of the signed-in user (accepted but not displayed).
    user_email : any
        E-mail address interpolated into the greeting.
    """
    greeting = f"You're logged in as {user_email}"
    st.write(greeting)
34 |
35 |
if __name__ == '__main__':
    # OAuth settings are read from the environment; a missing variable raises
    # KeyError here, before anything is written to the page.
    client_id = os.environ['GOOGLE_CLIENT_ID']
    client_secret = os.environ['GOOGLE_CLIENT_SECRET']
    redirect_uri = os.environ['REDIRECT_URI']

    client = GoogleOAuth2(client_id, client_secret)
    authorization_url = asyncio.run(
        write_authorization_url(client=client,
                                redirect_uri=redirect_uri)
    )

    def login_link(text):
        """Return an HTML anchor labelled *text* that points at the sign-in URL."""
        return f'<a target="_self" href="{authorization_url}">{text}</a>'

    session_state = get(token=None)
    if session_state.token is None:
        try:
            # Query parameters map each key to a list of values; the
            # authorization code is the first (only) value under 'code'.
            code = st.experimental_get_query_params()['code'][0]
        except (KeyError, IndexError):
            # No code on the URL yet: the user has not been through the
            # provider's consent screen, so offer the sign-in link.
            # Messages are single-line so Markdown never treats indented
            # continuation lines as a code block.
            st.write(f'Please login using this {login_link("url")}',
                     unsafe_allow_html=True)
        else:
            # Verify token is correct:
            try:
                token = asyncio.run(
                    write_access_token(client=client,
                                       redirect_uri=redirect_uri,
                                       code=code))
            except Exception:
                # Deliberately broad, but no longer a bare ``except`` so that
                # KeyboardInterrupt/SystemExit and other BaseException-based
                # control-flow signals are not swallowed.
                st.write('This account is not allowed or page was refreshed. '
                         f'Please try again: {login_link("url")}',
                         unsafe_allow_html=True)
            else:
                # Check if token has expired:
                if token.is_expired():
                    st.write('Login session has ended, '
                             f'please {login_link("login")} again.',
                             unsafe_allow_html=True)
                else:
                    # Remember the token and identity on the session so that
                    # later runs of the script skip the OAuth round trip.
                    session_state.token = token
                    user_id, user_email = asyncio.run(
                        get_email(client=client,
                                  token=token['access_token'])
                    )
                    session_state.user_id = user_id
                    session_state.user_email = user_email
                    main(user_id=session_state.user_id,
                         user_email=session_state.user_email)
    else:
        # A token is already stored on this session: go straight to the app.
        main(user_id=session_state.user_id,
             user_email=session_state.user_email)
91 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==0.81.1
2 | httpx-oauth==0.3.5
3 |
--------------------------------------------------------------------------------
/session_state.py:
--------------------------------------------------------------------------------
1 | """Hack to add per-session state to Streamlit.
2 | Usage
3 | -----
4 | >>> import session_state as SessionState
5 | >>>
6 | >>> session_state = SessionState.get(user_name='', favorite_color='black')
7 | >>> session_state.user_name
8 | ''
9 | >>> session_state.user_name = 'Mary'
10 | >>> session_state.favorite_color
11 | 'black'
12 | Since you set user_name above, next time your script runs this will be the
13 | result:
14 | >>> session_state = get(user_name='', favorite_color='black')
15 | >>> session_state.user_name
16 | 'Mary'
17 | """
18 |
19 | try:
20 | import streamlit.ReportThread as ReportThread
21 | from streamlit.server.Server import Server
22 | except Exception:
23 | # Streamlit >= 0.65.0
24 | import streamlit.report_thread as ReportThread
25 | from streamlit.server.server import Server
26 |
27 |
class SessionState:
    """Plain attribute container used to hold per-session values."""

    def __init__(self, **kwargs):
        """Create a state object with one attribute per keyword argument.

        Parameters
        ----------
        **kwargs : any
            Initial attribute names and their default values.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name
        ''
        >>> session_state.user_name = 'Mary'
        >>> session_state.user_name
        'Mary'
        >>> session_state.favorite_color
        'black'
        """
        for name in kwargs:
            setattr(self, name, kwargs[name])
45 |
46 |
def get(**kwargs):
    """Gets a SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're
        creating a new one. They are ignored when the session already
        carries a state object.

    Returns
    -------
    SessionState
        The state object stored on the current Streamlit session.

    Raises
    ------
    RuntimeError
        If no report context is available or no session matches it.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be
    the result:

    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'
    """
    # Hack to get the session object from Streamlit.
    not_found_message = (
        "Oh noes. Couldn't get your Streamlit Session object. "
        'Are you doing something fancy with threads?')

    ctx = ReportThread.get_report_ctx()
    if ctx is None:
        # Without a context the attribute comparisons below would fail with
        # an opaque AttributeError; raise the intended error instead.
        raise RuntimeError(not_found_message)

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()
    sessions = [session_info.session for session_info in session_infos]

    this_session = None

    # Prefer an exact match on the session id when both the context and the
    # session expose one. The attribute heuristics below can be satisfied by
    # more than one session (the last match wins), which would attach one
    # user's state to another user's session.
    # NOTE(review): assumes ``ctx.session_id`` and ``session.id`` hold the
    # same identifier on recent Streamlit versions -- confirm; when either is
    # missing the legacy heuristics are used unchanged.
    session_id = getattr(ctx, 'session_id', None)
    if session_id is not None:
        for s in sessions:
            if getattr(s, 'id', None) == session_id:
                this_session = s
                break

    if this_session is None:
        for s in sessions:
            if (
                # Streamlit < 0.54.0
                (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
                or
                # Streamlit >= 0.54.0
                (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
                or
                # Streamlit >= 0.65.2
                (not hasattr(s, '_main_dg')
                 and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
            ):
                this_session = s

    if this_session is None:
        raise RuntimeError(not_found_message)

    # Got the session object! Now let's attach some state into it.
    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state
108 |
--------------------------------------------------------------------------------