├── .gitignore ├── Dockerfile ├── README.md ├── requirements.txt └── src ├── app.py ├── intro.py └── neo4j_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9.5-slim-buster 2 | 3 | EXPOSE 8501 4 | 5 | WORKDIR /app 6 | 7 | COPY requirements.txt . 8 | RUN pip install -U pip 9 | RUN pip install --no-cache-dir -r requirements.txt 10 | 11 | RUN apt-get update && apt-get install -y nano 12 | 13 | COPY ./src /examples 14 | 15 | CMD streamlit run /examples/app.py -- "bolt://sandbox.ip.address:7687" "neo4j" "sandbox-user-password" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neo4j_streamlit 2 | ### Written by: Dr. Clair J.
Sullivan, Data Science Advocate, Neo4j 3 | #### email: clair.sullivan@neo4j.com 4 | #### Twitter: @CJLovesData1 5 | #### Last updated: 2021-06-10 6 | 7 | ## Introduction 8 | 9 | The purpose of this code is to demonstrate how to create a [Streamlit](https://streamlit.io) dashboard for visualizing embeddings created by the [Graph Data Science](https://dev.neo4j.com/graph_data_science) library in [Neo4j](https://dev.neo4j.com/neo4j). This will hopefully help you develop an intuitive feel for these embeddings and for how the different hyperparameters affect the overall results. 10 | 11 | For more information on graph embeddings, you can read my blog post [Getting Started with Graph Embeddings in Neo4j](https://dev.neo4j.com/intro_graph_emb_tds) in Towards Data Science (written in May 2021). 12 | 13 | ## How to get started 14 | 15 | 1. Create a [Neo4j Sandbox](https://dev.neo4j.com/sandbox) 16 | - Select "Launch a Free Instance" 17 | - Select "New Project" 18 | - Choose "Graph Data Science": this creates an instance pre-populated with the Game of Thrones graph we will use to get started 19 | - Open the drop-down on the right of the instance and select "Connection details" 20 | - Record the Bolt URL and password 21 | 2. Build the Streamlit container 22 | - Edit the `CMD` line of the Dockerfile to include the Bolt URL and password from the previous step 23 | - From the root directory of this repo, type `docker build -t neo_stream .` 24 | - Once built, type `docker run -p 8501:8501 -v $PWD/src:/examples neo_stream` 25 | 3. Using a browser, navigate to the IP address that Streamlit prints on startup 26 | - This typically will be something like `http://172.17.0.2:8501`; because the port is published with `-p 8501:8501`, `http://localhost:8501` should also work 27 | 4. In the side panel, provide a name for the in-memory graph to be created and click "Create in-memory graph" 28 | 5. Have fun with the functionality in the main area of the dashboard!
How well can you tune the hyperparameters to maximize the cluster separation between the living (blue) and the dead (red) characters??? 29 | 30 | **Note:** You can also navigate to `https://bolt.URL.address:7474` to interact with the graph directly via Cypher. 31 | 32 | ## Major caveats!!! 33 | 34 | 1. This is _very much_ a work in progress! As such, only two types of graph embeddings are included (FastRP and node2vec), and not all of the hyperparameters have been added yet. Be sure to watch for future modifications and regularly pull the main branch of this repo to keep up with the current version. 35 | 2. The in-memory graph that is created is only the undirected, monopartite graph `(Person)-[:INTERACTS]-(Person)`. Future versions of this code will allow you to create additional graphs. 36 | 3. This is a very small graph! It has only about 2,600 nodes and just under 17,000 relationships, and we make it even smaller by restricting the nodes and relationships as described above. Given this, we do not expect the embeddings to be of especially high quality. In future work, we will use a larger graph and discuss how to optimize embeddings.
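`app.py` builds the projection from caveat 2 by formatting the user-typed graph name straight into the Cypher string, so a name containing a quote breaks the query. Below is a minimal sketch of a parameterized alternative. It assumes the GDS 1.x `gds.graph.create` procedure that `app.py` already calls and the `parameters` argument of `Neo4jConnection.query` in `src/neo4j_utils.py`; the helper name `build_create_graph_call` is hypothetical and not part of this repo.

```python
# Hypothetical helper (not part of this repo): builds the undirected
# (Person)-[:INTERACTS]-(Person) projection as a parameterized query,
# so the graph name is sent as a Cypher parameter rather than spliced in.
from typing import Any, Dict, Tuple

CREATE_GRAPH_QUERY = """
CALL gds.graph.create(
    $graph_name,
    'Person',
    {
        INTERACTS_WITH: {
            type: 'INTERACTS',
            orientation: 'UNDIRECTED'
        }
    }
)
YIELD graphName, nodeCount, relationshipCount
RETURN graphName, nodeCount, relationshipCount
"""


def build_create_graph_call(graph_name: str) -> Tuple[str, Dict[str, Any]]:
    """Return (cypher, parameters) for projecting the Person/INTERACTS graph."""
    name = graph_name.strip()
    if not name:
        raise ValueError("graph name must be a non-empty string")
    return CREATE_GRAPH_QUERY, {"graph_name": name}


if __name__ == "__main__":
    cypher, params = build_create_graph_call("got-interactions")
    print(params)
    # With the repo's connection object this would run as:
    #   neo4j_utils.query(cypher, parameters=params)
```

Yielding the columns by name also means the result can be read as `result[0]['nodeCount']` instead of relying on positional indexes such as `result[0][3]`.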
37 | 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | neo4j==4.2.1 2 | numpy==1.20.3 3 | pandas==1.2.4 4 | scikit-learn==0.24.1 5 | streamlit==0.82.0 -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import altair as alt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from sklearn.manifold import TSNE 8 | 9 | import streamlit as st 10 | 11 | from neo4j_utils import Neo4jConnection 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Add uri, user, and pwd for Neo4j connection.') 15 | parser.add_argument('uri', type=str, help='Bolt URI, e.g. bolt://sandbox.ip.address:7687') 16 | parser.add_argument('user', type=str, help='Neo4j user name, e.g. neo4j') 17 | parser.add_argument('pwd', type=str, help='Neo4j password') 18 | # Positional arguments are always required, so no default values are given. 19 | args = parser.parse_args() 20 | 21 | neo4j_utils = Neo4jConnection(uri=args.uri, user=args.user, pwd=args.pwd) 22 | 23 | st.set_page_config(layout="wide") 24 | 25 | ############################## 26 | # 27 | # Sidebar content 28 | # 29 | ############################## 30 | 31 | def get_node_labels(): 32 | 33 | label_ls = [] 34 | label_type_query = """CALL db.labels()""" 35 | result = neo4j_utils.query(label_type_query) 36 | for el in result: 37 | #st.write(el[0]) 38 | label_ls.append(el[0]) 39 | return label_ls 40 | 41 | 42 | def get_rel_types(): 43 | 44 | rel_ls = [] 45 | rel_type_query = """CALL db.relationshipTypes()""" 46 | result = neo4j_utils.query(rel_type_query) 47 | for el in result: 48 | rel_ls.append(el[0]) 49 | return rel_ls 50 | 51 | 52 | def get_graph_list(): 53 | 54 | graph_ls = [] 55 | list_graph_query = """CALL gds.graph.list()""" 56 | existing_graphs = neo4j_utils.query(list_graph_query) 57 | if existing_graphs: 58 | for el in existing_graphs: 59 | graph_ls.append(el['graphName']) 60 | return graph_ls
61 | 62 | ############################## 63 | 64 | ##### Get listing of graphs 65 | 66 | intro_text = """ 67 | # Introduction 68 | 69 | This is an embedding visualizer for the Neo4j Graph Data Science Game of Thrones graph. 70 | It is intended to be run in a free [Neo4j Sandbox](https://dev.neo4j.com/sandbox) instance. 71 | See the repository [README](https://github.com/cj2001/social_media_streamlit/blob/main/README.md) 72 | for more information on how to create a Sandbox and populate it with the graph. 73 | 74 | The graph we will be working with is the monopartite, undirected graph `(Person)-[:INTERACTS]-(Person)`. 75 | Using this graph, we will explore embeddings produced by the FastRP and node2vec algorithms. Most, 76 | but not all, of the hyperparameters are included so you can get a feel for how each affects the 77 | overall embedding results. The goal is to compare the embeddings of dead (label = 0) and 78 | living (label = 1) characters, in the hope of producing clearly separable clusters.
79 | 80 | **This is not an all-inclusive approach and much will be added to this dashboard over time!!!** 81 | """ 82 | 83 | st.sidebar.markdown(intro_text) 84 | 85 | st.sidebar.markdown("""---""") 86 | 87 | st.sidebar.header('Graph management') 88 | 89 | if st.sidebar.button('Get graph list'): 90 | graph_ls = get_graph_list() 91 | if len(graph_ls) > 0: 92 | for el in graph_ls: 93 | st.sidebar.write(el) 94 | else: 95 | st.sidebar.write('There are currently no graphs in memory.') 96 | 97 | st.sidebar.markdown("""---""") 98 | 99 | ##### Create in-memory graphs 100 | 101 | create_graph = st.sidebar.text_input('Name of graph to be created: ') 102 | if st.sidebar.button('Create in-memory graph'): 103 | 104 | create_graph_query = """CALL gds.graph.create( 105 | '%s', 106 | 'Person', 107 | { 108 | INTERACTS_WITH: { 109 | type: 'INTERACTS', 110 | orientation: 'UNDIRECTED' 111 | } 112 | } 113 | ) 114 | """ % (create_graph) 115 | result = neo4j_utils.query(create_graph_query) 116 | st.sidebar.write(f'Graph {result[0][2]} has {result[0][3]} nodes and {result[0][4]} relationships.') 117 | 118 | st.sidebar.markdown("""---""") 119 | 120 | ##### Drop in-memory graph 121 | 122 | drop_graph = st.sidebar.selectbox('Choose a graph to drop: ', get_graph_list()) 123 | if st.sidebar.button('Drop in-memory graph'): 124 | drop_graph_query = """CALL gds.graph.drop('{}')""".format(drop_graph) 125 | result = neo4j_utils.query(drop_graph_query) 126 | st.sidebar.write(f'Graph {result[0][0]} has been dropped.') 127 | 128 | st.sidebar.markdown("""---""") 129 | 130 | ############################## 131 | # 132 | # Main panel content 133 | # 134 | ############################## 135 | 136 | def create_graph_df(): 137 | 138 | df_query = """MATCH (n) RETURN n.name, n.frp_emb, n.n2v_emb""" 139 | df = pd.DataFrame([dict(_) for _ in neo4j_utils.query(df_query)]) 140 | 141 | return df 142 | 143 | 144 | def create_tsne_plot(emb_name='p.n2v_emb', n_components=2): 145 | 146 | tsne_query = """MATCH
(p:Person) RETURN p.name AS name, p.death_year AS death_year, {} AS vec 147 | """.format(emb_name) 148 | df = pd.DataFrame([dict(_) for _ in neo4j_utils.query(tsne_query)]) 149 | df['is_alive'] = np.where(df['death_year'].isnull(), 1, 0) 150 | 151 | X_emb = TSNE(n_components=n_components).fit_transform(list(df['vec'])) 152 | 153 | tsne_df = pd.DataFrame(data={ 154 | 'x': [value[0] for value in X_emb], 155 | 'y': [value[1] for value in X_emb], 156 | 'label': df['is_alive'] 157 | }) 158 | 159 | return tsne_df 160 | 161 | ############################## 162 | 163 | 164 | col1, col2 = st.beta_columns((1, 2)) 165 | 166 | ##### 167 | # 168 | # Embedding column (col1) 169 | # 170 | ##### 171 | 172 | with col1: 173 | #emb_graph = st.text_input('Enter graph name for embedding creation:') 174 | emb_graph = st.selectbox('Choose a graph for embedding creation: ', get_graph_list()) 175 | 176 | ##### FastRP embedding creation 177 | 178 | with st.beta_expander('FastRP embedding creation'): 179 | st.markdown("Description of hyperparameters can be found [here](https://neo4j.com/docs/graph-data-science/current/algorithms/fastrp/#algorithms-embeddings-fastrp)") 180 | frp_dim = st.slider('FastRP embedding dimension', value=4, min_value=2, max_value=50) 181 | frp_it_weight1 = st.slider('Iteration weight 1', value=0., min_value=0., max_value=1.) 182 | frp_it_weight2 = st.slider('Iteration weight 2', value=1., min_value=0., max_value=1.) 183 | frp_it_weight3 = st.slider('Iteration weight 3', value=1., min_value=0., max_value=1.) 184 | frp_norm = st.slider('FastRP normalization strength', value=0., min_value=-1., max_value=1.)
185 | frp_seed = st.slider('Random seed', value=42, min_value=1, max_value=99) 186 | 187 | if st.button('Create FastRP embedding'): 188 | frp_query = """CALL gds.fastRP.write('%s', { 189 | embeddingDimension: %d, 190 | iterationWeights: [%f, %f, %f], 191 | normalizationStrength: %f, 192 | randomSeed: %d, 193 | writeProperty: 'frp_emb' 194 | }) 195 | """ % (emb_graph, frp_dim, frp_it_weight1, 196 | frp_it_weight2, frp_it_weight3, frp_norm, 197 | frp_seed) 198 | result = neo4j_utils.query(frp_query) 199 | 200 | ##### node2vec embedding creation 201 | 202 | with st.beta_expander('node2vec embedding creation'): 203 | st.markdown("Description of hyperparameters can be found [here](https://neo4j.com/docs/graph-data-science/current/algorithms/node2vec/)") 204 | n2v_dim = st.slider('node2vec embedding dimension', value=4, min_value=2, max_value=50) 205 | n2v_walk_length = st.slider('Walk length', value=80, min_value=2, max_value=160) 206 | n2v_walks_node = st.slider('Walks per node', value=10, min_value=2, max_value=50) 207 | n2v_io_factor = st.slider('inOutFactor', value=1.0, min_value=0.001, max_value=1.0, step=0.05) 208 | n2v_ret_factor = st.slider('returnFactor', value=1.0, min_value=0.001, max_value=1.0, step=0.05) 209 | n2v_neg_samp_rate = st.slider('negativeSamplingRate', value=10, min_value=5, max_value=20) 210 | n2v_iterations = st.slider('Number of training iterations', value=1, min_value=1, max_value=10) 211 | n2v_init_lr = st.select_slider('Initial learning rate', value=0.01, options=[0.001, 0.005, 0.01, 0.05, 0.1]) 212 | n2v_min_lr = st.select_slider('Minimum learning rate', value=0.0001, options=[0.0001, 0.0005, 0.001, 0.005, 0.01]) 213 | n2v_walk_bs = st.slider('Walk buffer size', value=1000, min_value=100, max_value=2000) 214 | n2v_seed = st.slider('Random seed:', value=42, min_value=1, max_value=99) 215 | 216 | if st.button('Create node2vec embedding'): 217 | n2v_query = """CALL gds.beta.node2vec.write('%s', { 218 | embeddingDimension: %d, 219 | walkLength:
%d, 220 | walksPerNode: %d, 221 | inOutFactor: %f, 222 | returnFactor: %f, 223 | negativeSamplingRate: %d, 224 | iterations: %d, 225 | initialLearningRate: %f, 226 | minLearningRate: %f, 227 | walkBufferSize: %d, 228 | randomSeed: %d, 229 | writeProperty: 'n2v_emb' 230 | }) 231 | """ % (emb_graph, n2v_dim, n2v_walk_length, 232 | n2v_walks_node, n2v_io_factor, n2v_ret_factor, 233 | n2v_neg_samp_rate, n2v_iterations, n2v_init_lr, 234 | n2v_min_lr, n2v_walk_bs, n2v_seed) 235 | result = neo4j_utils.query(n2v_query) 236 | 237 | st.markdown("---") 238 | 239 | if st.button('Show embeddings'): 240 | df = create_graph_df() 241 | st.dataframe(df) 242 | 243 | if st.button('Drop embeddings'): 244 | neo4j_utils.query('MATCH (n) REMOVE n.frp_emb') 245 | neo4j_utils.query('MATCH (n) REMOVE n.n2v_emb') 246 | 247 | ##### 248 | # 249 | # t-SNE column (col2) 250 | # 251 | ##### 252 | 253 | with col2: 254 | st.header('t-SNE') 255 | 256 | plt_emb = st.selectbox('Choose an embedding to plot: ', ['FastRP', 'node2vec']) 257 | if plt_emb == 'FastRP': 258 | emb_name = 'p.frp_emb' 259 | else: 260 | emb_name = 'p.n2v_emb' 261 | 262 | if st.button('Plot embeddings'): 263 | 264 | tsne_df = create_tsne_plot(emb_name=emb_name) 265 | ch_alt = alt.Chart(tsne_df).mark_point().encode( 266 | x='x', 267 | y='y', 268 | color=alt.Color('label:O', scale=alt.Scale(range=['red', 'blue'])) 269 | ).properties(width=800, height=800) 270 | st.altair_chart(ch_alt, use_container_width=True) 271 | 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /src/intro.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import numpy 3 | import streamlit as st 4 | 5 | st.title("Streamlit intro") 6 | st.write("Examples below were taken from: https://streamlit.io/docs/getting_started.html#get-started") 7 | 8 | 9 | st.subheader("Here's our first attempt at using data to create a table:") 10 | 11 | st.code(""" 12 | df = 
pandas.DataFrame({ 13 | 'first column': [1, 2, 3, 4], 14 | 'second column': [10, 20, 30, 40] 15 | }) 16 | """, language='python') 17 | df = pandas.DataFrame({ 18 | 'first column': [1, 2, 3, 4], 19 | 'second column': [10, 20, 30, 40] 20 | }) 21 | # here we use 'magic', 22 | # any time that Streamlit sees a variable or a literal value on its own line, 23 | # it automatically writes that to your app using st.write() 24 | df 25 | 26 | 27 | 28 | st.subheader("Draw a line chart:") 29 | st.code(""" 30 | chart_data = pandas.DataFrame( 31 | numpy.random.randn(20, 3), 32 | columns=['a', 'b', 'c']) 33 | st.line_chart(chart_data) 34 | """, language='python') 35 | 36 | chart_data = pandas.DataFrame( 37 | numpy.random.randn(20, 3), 38 | columns=['a', 'b', 'c']) 39 | st.line_chart(chart_data) 40 | 41 | 42 | st.subheader("Let’s use NumPy to generate some sample data and plot it on a map of San Francisco:") 43 | 44 | st.code(""" 45 | map_data = pandas.DataFrame( 46 | numpy.random.randn(1000, 2) / [50, 50] + [37.76, -122.4], 47 | columns=['lat', 'lon']) 48 | st.map(map_data) 49 | """, language='python') 50 | 51 | map_data = pandas.DataFrame( 52 | numpy.random.randn(1000, 2) / [50, 50] + [37.76, -122.4], 53 | columns=['lat', 'lon']) 54 | st.map(map_data) 55 | 56 | 57 | 58 | st.title("Add interactivity with widgets") 59 | 60 | st.subheader("Use checkboxes to show/hide data") 61 | 62 | st.code(""" 63 | if st.checkbox('Show chart'): 64 | chart_data = pandas.DataFrame( 65 | numpy.random.randn(20, 3), 66 | columns=['a', 'b', 'c']) 67 | st.line_chart(chart_data) 68 | """, language='python') 69 | 70 | if st.checkbox('Show chart'): 71 | chart_data = pandas.DataFrame( 72 | numpy.random.randn(20, 3), 73 | columns=['a', 'b', 'c']) 74 | 75 | st.line_chart(chart_data) 76 | 77 | 78 | st.subheader("Put widgets in a sidebar") 79 | 80 | st.code(""" 81 | if st.checkbox('Show in sidebar'): 82 | option = st.sidebar.selectbox( 83 | 'Which letter do you like best?', 84 | ["a", "b","c"]) 85 | 'You
selected:', option 86 | """, language='python') 87 | 88 | if st.checkbox('Show in sidebar'): 89 | option = st.sidebar.selectbox( 90 | 'Which letter do you like best?', 91 | ["a", "b","c"]) 92 | 93 | 'You selected:', option 94 | 95 | 96 | 97 | st.subheader("Show progress") 98 | 99 | st.code(""" 100 | import time 101 | # Add a placeholder 102 | latest_iteration = st.empty() 103 | bar = st.progress(0) 104 | for i in range(100): 105 | # Update the progress bar with each iteration. 106 | latest_iteration.text(f'Iteration {i+1}') 107 | bar.progress(i + 1) 108 | time.sleep(0.1) 109 | '...and now we\'re done!' 110 | """, language='python') 111 | 112 | import time 113 | 114 | # Add a placeholder 115 | latest_iteration = st.empty() 116 | bar = st.progress(0) 117 | 118 | for i in range(100): 119 | # Update the progress bar with each iteration. 120 | latest_iteration.text(f'Iteration {i+1}') 121 | bar.progress(i + 1) 122 | time.sleep(0.1) 123 | 124 | '...and now we\'re done!' -------------------------------------------------------------------------------- /src/neo4j_utils.py: -------------------------------------------------------------------------------- 1 | from neo4j import GraphDatabase 2 | 3 | class Neo4jConnection: 4 | 5 | def __init__(self, uri, user, pwd): 6 | self.__uri = uri 7 | self.__user = user 8 | self.__pwd = pwd 9 | self.__driver = None 10 | try: 11 | self.__driver = GraphDatabase.driver(self.__uri, auth=(self.__user, self.__pwd)) 12 | except Exception as e: 13 | print("Failed to create the driver:", e) 14 | 15 | def close(self): 16 | if self.__driver is not None: 17 | self.__driver.close() 18 | 19 | def query(self, query, parameters=None, db=None): 20 | assert self.__driver is not None, "Driver not initialized!"
21 | session = None 22 | response = None 23 | try: 24 | session = self.__driver.session(database=db) if db is not None else self.__driver.session() 25 | response = list(session.run(query, parameters)) 26 | except Exception as e: 27 | print("Query failed:", e) 28 | finally: 29 | if session is not None: 30 | session.close() 31 | return response 32 | 33 | 34 | --------------------------------------------------------------------------------
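`app.py` fills the FastRP and node2vec calls with `%`-formatting, so the graph name and every slider value become part of the query text. Because `Neo4jConnection.query` above already accepts a `parameters` dict, one option is to pass the whole FastRP configuration as a single map parameter. This is a minimal sketch, assuming the GDS 1.x `gds.fastRP.write(graphName, configuration)` signature that `app.py` already relies on; `build_fastrp_write` is a hypothetical helper, not part of this repo.

```python
# Hypothetical helper (not part of this repo): mirrors the FastRP sliders in
# app.py but sends the graph name and configuration as Cypher parameters.
from typing import Any, Dict, Sequence, Tuple

FASTRP_WRITE_QUERY = "CALL gds.fastRP.write($graph_name, $config)"


def build_fastrp_write(
    graph_name: str,
    dim: int,
    iteration_weights: Sequence[float],
    normalization_strength: float,
    seed: int,
    write_property: str = "frp_emb",
) -> Tuple[str, Dict[str, Any]]:
    """Return (cypher, parameters) for a FastRP write call."""
    if dim < 1:
        raise ValueError("embeddingDimension must be positive")
    config = {
        "embeddingDimension": int(dim),
        # Coerce to floats so slider ints and floats serialize the same way.
        "iterationWeights": [float(w) for w in iteration_weights],
        "normalizationStrength": float(normalization_strength),
        "randomSeed": int(seed),
        "writeProperty": write_property,
    }
    return FASTRP_WRITE_QUERY, {"graph_name": graph_name, "config": config}


if __name__ == "__main__":
    cypher, params = build_fastrp_write("got-interactions", 4, [0.0, 1.0, 1.0], 0.0, 42)
    print(cypher)
    print(params["config"])
    # With the Neo4jConnection class above this would run as:
    #   conn.query(cypher, parameters=params)
```

Passing one map keeps the Cypher text constant regardless of the slider settings; the same pattern should carry over to `gds.beta.node2vec.write` with its own configuration keys.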