├── cord_ann ├── __init__.py ├── clusters.py ├── embeddings.py ├── mapping.py └── index.py ├── frontend ├── .gitignore ├── .env.defaults ├── src │ ├── themes │ │ ├── index.js │ │ ├── dark.js │ │ ├── light.js │ │ └── colors.js │ ├── index.js │ ├── components │ │ ├── services │ │ │ └── getSearchResults.js │ │ ├── Progress.js │ │ ├── LandingPage.js │ │ ├── TitleAppBar.js │ │ ├── SearchBar.js │ │ ├── ResultShow.js │ │ └── ResultCard.js │ ├── styles │ │ └── d3-overrides.css │ ├── index.html │ └── App.js ├── server.js ├── docker-compose.yml ├── Dockerfile ├── .eslintrc.js ├── babel.config.js ├── package.json ├── webpack.config.js └── README.md ├── imgs └── cord_ann_example.gif ├── .gitmodules ├── requirements.txt ├── setup.py ├── LICENCE ├── generate_embeddings.py ├── Dockerfile ├── download_data.py ├── cluster_sentences.py ├── search_index.py ├── create_index.py ├── extract_sentences.py ├── index_server.py └── README.md /cord_ann/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /frontend/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules -------------------------------------------------------------------------------- /frontend/.env.defaults: -------------------------------------------------------------------------------- 1 | API_BASE_URL=localhost:5000 -------------------------------------------------------------------------------- /imgs/cord_ann_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeanNaren/CORD-19-ANN/HEAD/imgs/cord_ann_example.gif -------------------------------------------------------------------------------- /frontend/src/themes/index.js: -------------------------------------------------------------------------------- 1 | export { default as dark } from './dark' 2 | export { default as light } from './light' 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sentence-transformers"] 2 | path = sentence-transformers 3 | url = https://github.com/SeanNaren/sentence-transformers.git 4 | -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import { render } from 'react-dom' 3 | import App from './App' 4 | 5 | 6 | render(, document.getElementById('root')) 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pysbd 2 | sentencepiece 3 | scispacy 4 | https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_sm-0.2.4.tar.gz 5 | transformers 6 | ./sentence-transformers 7 | tornado 8 | pandas 9 | nmslib -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='cord_ann', 4 | version='0.1', 5 | description='CORD-19-ANN', 6 | author='Sean Narenthiran', 7 | packages=find_packages(), 8 | zip_safe=False) 9 | -------------------------------------------------------------------------------- 
/frontend/src/components/services/getSearchResults.js: -------------------------------------------------------------------------------- 1 | import axios from 'axios'; 2 | 3 | export function getSearchResults(sentence) { 4 | return axios({ 5 | url: process.env.API_BASE_URL, 6 | method: 'post', 7 | data: sentence 8 | }) 9 | } -------------------------------------------------------------------------------- /frontend/src/styles/d3-overrides.css: -------------------------------------------------------------------------------- 1 | .Axis path, 2 | .Axis line { 3 | stroke: rgba(0, 0, 0, 0.12); 4 | } 5 | 6 | /* .Axis text { 7 | font-size: 11px; 8 | fill: rgba(0, 0, 0, 0.87); 9 | } */ 10 | 11 | .Axis-Bottom text { 12 | transform: rotate(90deg); 13 | text-anchor: start; 14 | } -------------------------------------------------------------------------------- /frontend/src/themes/dark.js: -------------------------------------------------------------------------------- 1 | import colors from './colors' 2 | const { dark, purple } = colors; 3 | 4 | export default { 5 | palette: { 6 | type: "dark", 7 | primary: { 8 | main: dark, 9 | }, 10 | secondary: { 11 | main: purple, 12 | }, 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /frontend/src/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | CORD-19-ANN 8 | 9 | 10 | 11 |
12 | 13 | 14 | -------------------------------------------------------------------------------- /frontend/server.js: -------------------------------------------------------------------------------- 1 | var path = require('path'); 2 | var express = require('express'); 3 | 4 | var app = express(); 5 | 6 | app.use(express.static(path.join(__dirname, 'build'))); 7 | app.set('port', process.env.PORT || 8080); 8 | 9 | var server = app.listen(app.get('port'), function () { 10 | console.log('listening on port ', server.address().port); 11 | }); -------------------------------------------------------------------------------- /frontend/src/themes/light.js: -------------------------------------------------------------------------------- 1 | import colors from './colors' 2 | const { dark, purple } = colors; 3 | 4 | export default { 5 | palette: { 6 | type: "light", 7 | primary: { 8 | main: dark, 9 | }, 10 | secondary: { 11 | main: purple, 12 | }, 13 | }, 14 | typography: { 15 | body2: { 16 | fontSize: '0.75rem' 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /frontend/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.3' 2 | 3 | services: 4 | frontend: 5 | container_name: frontend 6 | build: 7 | context: . 8 | dockerfile: Dockerfile 9 | volumes: 10 | - '.:/usr/src/app' 11 | - '/usr/node_modules' 12 | environment: 13 | - NODE_ENV=development 14 | ports: 15 | - '8080:8080' 16 | command: run doc 17 | -------------------------------------------------------------------------------- /frontend/Dockerfile: -------------------------------------------------------------------------------- 1 | # base image 2 | FROM node:8.12.0 3 | 4 | # set working directory 5 | RUN mkdir /usr/src/app 6 | WORKDIR /usr/src/app 7 | 8 | # add `/usr/src/app/node_modules/.bin` to $PATH 9 | ENV PATH /usr/src/app/node_modules/.bin:$PATH 10 | 11 | # install and cache app dependencies 12 | COPY package.json /usr/src/app/package.json 13 | RUN npm install --silent 14 | 15 | # start app 16 | ENTRYPOINT ["npm"] 17 | -------------------------------------------------------------------------------- /cord_ann/clusters.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | 3 | 4 | def cluster_embeddings(embeddings, sentences, num_clusters): 5 | clustering_model = KMeans(n_clusters=num_clusters) 6 | clustering_model.fit(embeddings) 7 | cluster_assignment = clustering_model.labels_ 8 | 9 | clustered_sentences = [[] for i in range(num_clusters)] 10 | for sentence_id, cluster_id in enumerate(cluster_assignment): 11 | clustered_sentences[cluster_id].append(sentences[sentence_id]) 12 | return clustered_sentences 13 | -------------------------------------------------------------------------------- /frontend/src/App.js: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import { createMuiTheme, MuiThemeProvider } from '@material-ui/core/styles' 3 | import { CssBaseline } from '@material-ui/core' 4 | import * as Themes from './themes' 5 | import LandingPage from './components/LandingPage' 6 | import TitleAppBar from './components/TitleAppBar' 7 | function App() { 8 | return ( 9 | 10 | 11 | 12 | 13 | 14 | ) 15 | } 16 | 17 | export default App 18 | -------------------------------------------------------------------------------- /frontend/src/components/Progress.js: 
-------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { makeStyles } from '@material-ui/core/styles'; 3 | import LinearProgress from '@material-ui/core/LinearProgress'; 4 | 5 | const useStyles = makeStyles((theme) => ({ 6 | root: { 7 | width: '100%', 8 | '& > * + *': { 9 | marginLeft: theme.spacing(2), 10 | }, 11 | }, 12 | })); 13 | 14 | export default function SearchLinearProgress() { 15 | const classes = useStyles(); 16 | 17 | return ( 18 |
19 | 20 |
21 | ); 22 | } -------------------------------------------------------------------------------- /frontend/.eslintrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | "env": { 3 | "browser": true, 4 | "es6": true 5 | }, 6 | "extends": [ 7 | "eslint:recommended", 8 | "plugin:react/recommended" 9 | ], 10 | "globals": { 11 | "Atomics": "readonly", 12 | "SharedArrayBuffer": "readonly" 13 | }, 14 | "parserOptions": { 15 | "ecmaFeatures": { 16 | "jsx": true 17 | }, 18 | "ecmaVersion": 2018, 19 | "sourceType": "module" 20 | }, 21 | "plugins": [ 22 | "react" 23 | ], 24 | "rules": { 25 | } 26 | }; -------------------------------------------------------------------------------- /frontend/src/themes/colors.js: -------------------------------------------------------------------------------- 1 | const Colors = { 2 | //theme colors 3 | dark: "#424242", 4 | purple: "#71368A", 5 | 6 | //dataset markers colors 7 | golden: "#FFD700", 8 | red: "#cc2b1d", 9 | blue: "#6236FF", 10 | 11 | //AQ sensor icon 12 | white: "#FFFFFF", 13 | darkGreen: "#327d4a", 14 | gray: "#A8A8A8", 15 | 16 | //Drawer background 17 | grayLight: "#fafafa", 18 | 19 | // Choropleth color bounds 20 | navy: '#004c6d', 21 | green: '#42f54e', 22 | orange: '#F28F3B', 23 | black: '#000000', 24 | dark_navy: '#022F59', 25 | yellow: '#FFD000' 26 | }; 27 | 28 | export default Colors; 29 | -------------------------------------------------------------------------------- /cord_ann/embeddings.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | 3 | 4 | class EmbeddingModel: 5 | def __init__(self, model_name_or_path, device, batch_size, show_progress_bar=True): 6 | self.model = SentenceTransformer(model_name_or_path=model_name_or_path, 7 | device=device) 8 | self.batch_size = batch_size 9 | self.show_progress_bar = show_progress_bar 10 | 11 | def encode_sentences(self, sentences): 12 | sentence_embeddings = self.model.encode(sentences=sentences, 13 | batch_size=self.batch_size, 14 | show_progress_bar=self.show_progress_bar) 15 | return sentence_embeddings 16 | -------------------------------------------------------------------------------- /frontend/src/components/LandingPage.js: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import SearchBar from './SearchBar' 3 | import { makeStyles } from '@material-ui/core/styles'; 4 | 5 | const useStyles = makeStyles(() => ({ 6 | root: { 7 | padding: '32px 22px', 8 | display: 'flex', 9 | flexDirection: 'column', 10 | alignItems: 'center', 11 | } 12 | })); 13 | 14 | function LandingPage() { 15 | const classes = useStyles(); 16 | 17 | 18 | return ( 19 |
20 |

Enter a sentence to query the index. Results include the title of the journal and the sentence that was found ordered by similarity. More details are shown if expanded.

21 | 22 |
23 | ) 24 | } 25 | export default LandingPage; -------------------------------------------------------------------------------- /frontend/src/components/TitleAppBar.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { makeStyles } from '@material-ui/core/styles'; 3 | import AppBar from '@material-ui/core/AppBar'; 4 | import Toolbar from '@material-ui/core/Toolbar'; 5 | import Typography from '@material-ui/core/Typography'; 6 | 7 | const useStyles = makeStyles(() => ({ 8 | appRoot: { 9 | flexGrow: 1, 10 | }, 11 | appTitle: { 12 | flexGrow: 1, 13 | }, 14 | appBar: { 15 | alignItems: 'center', 16 | backgroundColor: '#c13164', 17 | } 18 | })); 19 | 20 | export default function TitleAppBar() { 21 | const classes = useStyles(); 22 | 23 | return ( 24 |
25 | 26 | 27 | 28 | CORD-19-ANN 29 | 30 | 31 | 32 |
33 | ); 34 | } -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sean Naren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /generate_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | from cord_ann.embeddings import EmbeddingModel 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--sentences_path', default="cord_19_dataset_formatted/cord_19_sentences.txt", 10 | help='Path to extracted sentences') 11 | parser.add_argument('--embedding_path', default="embeddings.npy", 12 | help='Output path of the generated embeddings') 13 | parser.add_argument('--model_name_or_path', default='bert-base-nli-mean-tokens') 14 | parser.add_argument('--batch_size', default=8, type=int) 15 | parser.add_argument('--device', default='cuda') 16 | args = parser.parse_args() 17 | 18 | with open(args.sentences_path) as f: 19 | sentences = f.read().split('\n') 20 | 21 | model = EmbeddingModel(model_name_or_path=args.model_name_or_path, 22 | device=args.device, 23 | batch_size=args.batch_size) 24 | sentence_embeddings = model.encode_sentences(sentences=sentences) 25 | 26 | np.save(args.embedding_path, sentence_embeddings) 27 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | ARG PYTHON_VERSION=3.7 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | build-essential \ 6 | cmake \ 7 | git \ 8 | wget \ 9 | ca-certificates \ 10 | libjpeg-dev \ 11 | libpng-dev && \ 12 | rm -rf /var/lib/apt/lists/* 13 | 14 | 15 | RUN wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 16 | chmod +x ~/miniconda.sh && \ 17 | ~/miniconda.sh -b -p /opt/conda && \ 18 | rm ~/miniconda.sh && \ 19 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scikit-learn scipy ipython mkl mkl-include ninja cython typing && \ 20 | /opt/conda/bin/conda clean -ya 21 | ENV PATH /opt/conda/bin:$PATH 22 | 23 | RUN conda install pytorch cpuonly faiss-cpu -c pytorch 24 | RUN conda install -c conda-forge 
spacy 25 | 26 | WORKDIR /workspace/ 27 | 28 | # Install CORD-19-ANN 29 | ADD . /workspace/CORD-19-ANN 30 | 31 | WORKDIR /workspace/CORD-19-ANN 32 | 33 | # Pre-requisities. Errors out if not installed before running requirments install 34 | RUN pip install pysbd sentencepiece transformers 35 | RUN pip install -r requirements.txt 36 | RUN pip install . 37 | 38 | ENTRYPOINT ["/opt/conda/bin/python"] -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import urllib.request 4 | import tarfile 5 | 6 | from tqdm import tqdm 7 | 8 | datasets = ['comm_use_subset.tar.gz', 'noncomm_use_subset.tar.gz', 'custom_license.tar.gz'] 9 | metadata = "metadata.csv" 10 | server = "https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-03-27/" 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--output_path', default="cord_19_dataset/") 14 | args = parser.parse_args() 15 | os.makedirs(args.output_path, exist_ok=True) 16 | 17 | print('Beginning download of datasets') 18 | for dataset in tqdm(datasets, total=len(datasets)): 19 | url = server + dataset 20 | dataset_path = os.path.join(args.output_path, dataset) 21 | urllib.request.urlretrieve(url, dataset_path) 22 | 23 | print("Extracting", dataset) 24 | tar = tarfile.open(dataset_path) 25 | tar.extractall(args.output_path) 26 | tar.close() 27 | os.remove(dataset_path) 28 | 29 | print('Downloading metadata') 30 | metadata_path = os.path.join(args.output_path, 'metadata.csv') 31 | urllib.request.urlretrieve(server + metadata, metadata_path) 32 | 33 | print("All datasets downloaded and extracted") 34 | -------------------------------------------------------------------------------- /cord_ann/mapping.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as df 3 | from tqdm import tqdm 4 | 5 | 6 | def create_sentence_to_article_mapping(articles): 7 | mapping = [] 8 | for article in tqdm(articles, total=len(articles)): 9 | for paragraph_idx, paragraph in enumerate(article['body_text']): 10 | sentence_mappings = [{ 11 | "sentence_idx": sentence_idx, 12 | "paper_id": article['paper_id'], 13 | "paragraph_idx": paragraph_idx 14 | } for sentence_idx in range(len(paragraph['sentences']))] 15 | mapping += sentence_mappings 16 | return mapping 17 | 18 | 19 | def load_sentence_to_article_mapping(mapping_path): 20 | with open(mapping_path) as f: 21 | mapping = json.load(f) 22 | return mapping 23 | 24 | 25 | def load_metadata(metadata_path): 26 | # TODO not efficient to search metadata like this. 
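# A possible follow-up to the TODO above (suggestion only): build a lookup keyed on the
# metadata 'sha' column here, e.g. a plain dict of row dicts, so that Index._find_metadata
# can fetch a paper's metadata directly instead of filtering the whole DataFrame for every hit.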
27 | metadata_df = df.read_csv(metadata_path) 28 | return metadata_df 29 | 30 | 31 | def flatten_sentences(articles): 32 | flattened_sentences = [] 33 | for article in tqdm(articles, total=len(articles)): 34 | for paragraph in article['body_text']: 35 | sentences = paragraph['sentences'] 36 | flattened_sentences += sentences 37 | return flattened_sentences 38 | -------------------------------------------------------------------------------- /cluster_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from cord_ann.clusters import cluster_embeddings 5 | from cord_ann.embeddings import load_embedding_model, encode_sentences 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser(description='Takes a text file of sentences and applies clustering ' 9 | 'based on a pre-trained sentence embedding model') 10 | parser.add_argument('--input_path', default='sentences.txt') 11 | parser.add_argument('--model_name_or_path', default='bert-base-nli-mean-tokens') 12 | parser.add_argument('--batch_size', default=8, type=int, 13 | help='Batch size for the transformer model encoding') 14 | parser.add_argument('--device', default='cpu', 15 | help='Set to cuda to use the GPU') 16 | parser.add_argument('--num_clusters', default=5, type=int, 17 | help='Number of clusters for Kmeans') 18 | args = parser.parse_args() 19 | sentences = Path(args.input_path).read_text().split('\n') 20 | model = load_embedding_model(model_name_or_path=args.model_name_or_path, 21 | device=args.device) 22 | 23 | embeddings = encode_sentences(model=model, 24 | batch_size=args.batch_size, 25 | sentences=sentences) 26 | clusters = cluster_embeddings(sentences=sentences, 27 | embeddings=embeddings, 28 | num_clusters=args.num_clusters) 29 | 30 | for i, cluster in enumerate(clusters): 31 | print("Cluster ", i + 1) 32 | print(cluster) 33 | print("") 34 | -------------------------------------------------------------------------------- /search_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | from cord_ann.mapping import load_sentence_to_article_mapping, load_metadata 6 | 7 | from cord_ann.embeddings import EmbeddingModel 8 | from cord_ann.index import search_args, Index, paths_from_dataset_path 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser = search_args(parser) 13 | parser.add_argument('--input_path', default="sentences.txt") 14 | parser.add_argument('--output_path', default="search.json") 15 | args = parser.parse_args() 16 | articles_path, _, mapping_path, metadata_path = paths_from_dataset_path(args.dataset_path) 17 | 18 | sentences = Path(args.input_path).read_text().strip().split('\n') 19 | sent_article_mapping = load_sentence_to_article_mapping(mapping_path) 20 | metadata = load_metadata(metadata_path) 21 | 22 | model = EmbeddingModel(model_name_or_path=args.model_name_or_path, 23 | device=args.device, 24 | batch_size=args.batch_size, 25 | show_progress_bar=not args.silent) 26 | 27 | index = Index(index_path=args.index_path, 28 | index_type=args.index_type, 29 | articles_path=articles_path, 30 | mapping=sent_article_mapping, 31 | metadata=metadata, 32 | k=args.k, 33 | num_workers=args.num_workers) 34 | search_embeddings = model.encode_sentences(sentences=sentences) 35 | results = index.search_index(sentences=sentences, 36 | search_embeddings=search_embeddings) 37 | with 
open(args.output_path, 'w') as f: 38 | json.dump(results, f) 39 | -------------------------------------------------------------------------------- /create_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--embedding_path', default="embeddings.npy", 7 | help='Path to generated embeddings.') 8 | parser.add_argument('--output_path', default="index", 9 | help='Path to save index') 10 | parser.add_argument('--index_type', default="nmslib", type=str, choices=["nmslib", "faiss"], 11 | help='Type of index you want like to create') 12 | parser.add_argument('--faiss_config', default='PCAR256,SQ8', type=str, 13 | help='FAISS offers a large selection of parameters that can be seen here:' 14 | 'https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index') 15 | if __name__ == "__main__": 16 | args = parser.parse_args() 17 | 18 | embeddings = numpy.load(args.embedding_path) 19 | 20 | if args.index_type == 'nmslib': 21 | import nmslib 22 | 23 | index = nmslib.init(method='hnsw', space='cosinesimil') 24 | index.addDataPointBatch(embeddings) 25 | index.createIndex({'post': 2}, print_progress=True) 26 | index.saveIndex(args.output_path, save_data=False) 27 | elif args.index_type == 'faiss': 28 | import faiss 29 | 30 | d = embeddings.shape[-1] 31 | index = faiss.index_factory(d, args.faiss_config) # build the index 32 | if not index.is_trained: 33 | print("Training index.") 34 | index.train(embeddings) 35 | print("Adding embeddings to index.") 36 | index.add(embeddings) # add vectors to the index 37 | print(index.ntotal) 38 | print("Saving index.") 39 | faiss.write_index(index, args.output_path) 40 | -------------------------------------------------------------------------------- /frontend/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = (api) => { 2 | api.cache(false); 3 | 4 | const conditionalPresets = []; 5 | 6 | const presets = [ 7 | [ 8 | "@babel/env", { 9 | "useBuiltIns": "usage", 10 | "corejs": { 11 | "version": "@3", 12 | "proposals": true, 13 | }, 14 | "targets": { 15 | "browsers": process.env.BABEL_SUPPORTED_BROWSERS || "", 16 | }, 17 | }, 18 | ], 19 | "@babel/react", 20 | ...conditionalPresets, 21 | ]; 22 | 23 | const conditionalPlugins = (process.env.BABEL_USE_MATERIAL_UI_ES_MODULES === 'true') 24 | ? 
[ 25 | [ 26 | "import", { 27 | "libraryName": "@material-ui/icons", 28 | "libraryDirectory": "", 29 | "camel2DashComponentName": false, 30 | }, 31 | ], 32 | ] 33 | : [ 34 | [ 35 | "import", { 36 | "libraryName": "@material-ui/core", 37 | "libraryDirectory": "", 38 | "camel2DashComponentName": false, 39 | }, 40 | "@material-ui/core", 41 | ], 42 | [ 43 | "import", { 44 | "libraryName": "@material-ui/core/colors", 45 | "libraryDirectory": "", 46 | "camel2DashComponentName": false, 47 | }, 48 | "@material-ui/core/colors" 49 | ], 50 | [ 51 | "import", { 52 | "libraryName": "@material-ui/core/styles", 53 | "libraryDirectory": "", 54 | "camel2DashComponentName": false, 55 | }, 56 | "@material-ui/core/styles" 57 | ], 58 | [ 59 | "import", { 60 | "libraryName": "@material-ui/icons", 61 | "libraryDirectory": "", 62 | "camel2DashComponentName": false, 63 | }, 64 | "@material-ui/icons" 65 | ], 66 | ] 67 | ; 68 | 69 | const plugins = [ 70 | "@babel/proposal-class-properties", 71 | "@babel/syntax-dynamic-import", 72 | "@babel/transform-runtime", 73 | ...conditionalPlugins, 74 | ]; 75 | 76 | return { 77 | presets, 78 | plugins, 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /extract_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import json 4 | from multiprocessing.pool import Pool 5 | from shutil import copyfile 6 | 7 | from cord_ann.index import paths_from_dataset_path 8 | from cord_ann.mapping import flatten_sentences, create_sentence_to_article_mapping 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | import spacy 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser( 15 | description="Formats the CORD19 data into appropriate formats for indexing/searching and generating embeddings") 16 | parser.add_argument('--input_path', default="cord_19_dataset/") 17 | parser.add_argument('--output_dir', default="cord_19_dataset_formatted/") 18 | parser.add_argument('--num_workers', default=8, type=int) 19 | args = parser.parse_args() 20 | 21 | output_dir = Path(args.output_dir) 22 | output_dir.mkdir(parents=True, exist_ok=True) 23 | 24 | articles_path, sentences_path, mapping_path, metadata_path = paths_from_dataset_path(output_dir) 25 | 26 | articles_path.mkdir(parents=True, exist_ok=True) 27 | 28 | nlp = spacy.load("en_core_sci_sm") # Use SciSpacy to tokenize into sentences 29 | articles_path = Path(args.input_path) 30 | articles = list(articles_path.rglob('*.json')) 31 | 32 | 33 | def _tokenize_paragraphs(journal): 34 | with journal.open('r') as f: 35 | journal = json.load(f) 36 | for x, paragraph in enumerate(journal['body_text']): 37 | paragraph = paragraph['text'] 38 | doc = nlp(paragraph) 39 | sentences = [x.text for x in doc.sents] 40 | journal['body_text'][x]['sentences'] = sentences 41 | paper_id = journal['paper_id'] 42 | output_path = articles_path / (paper_id + '.json') 43 | with output_path.open('w') as f: 44 | json.dump(journal, f) 45 | return journal 46 | 47 | 48 | with Pool(processes=args.num_workers) as pool: 49 | articles_tokenized = list(tqdm(pool.imap_unordered(_tokenize_paragraphs, articles), total=len(articles))) 50 | 51 | with open(sentences_path, 'w') as f: 52 | f.write('\n'.join(flatten_sentences(articles_tokenized))) 53 | with open(mapping_path, 'w') as f: 54 | sent_article_mapping = create_sentence_to_article_mapping(articles_tokenized) 55 | json.dump(sent_article_mapping, f) 56 | 57 | # Copy the metadata.csv file to the 
formatted directory 58 | copyfile(metadata_path, output_dir / 'metadata.csv') 59 | -------------------------------------------------------------------------------- /frontend/src/components/SearchBar.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { makeStyles } from '@material-ui/core/styles'; 3 | import Paper from '@material-ui/core/Paper'; 4 | import InputBase from '@material-ui/core/InputBase'; 5 | import IconButton from '@material-ui/core/IconButton'; 6 | import SearchIcon from '@material-ui/icons/Search'; 7 | import ResultShow from './ResultShow' 8 | import SearchLinearProgress from './Progress' 9 | import { getSearchResults } from './services/getSearchResults.js' 10 | 11 | const useStyles = makeStyles((theme) => ({ 12 | root: { 13 | padding: '2px 4px', 14 | width: '100%', 15 | }, 16 | input: { 17 | marginLeft: theme.spacing(1), 18 | flex: 1, 19 | }, 20 | iconButton: { 21 | padding: 10, 22 | }, 23 | divider: { 24 | height: 28, 25 | margin: 4, 26 | }, 27 | searchBar: { 28 | width: '100%', 29 | display: 'flex', 30 | alignItems: 'center', 31 | } 32 | })); 33 | 34 | export default function SearchBar() { 35 | const classes = useStyles(); 36 | const [response, setResponse] = React.useState(null); 37 | const [sentence, setSentence] = React.useState(null); 38 | const [loading, setLoading] = React.useState(false); 39 | 40 | const handleClick = ev => { 41 | setLoading(true); 42 | getSearchResults(sentence).then(res=>{ 43 | setResponse(res.data); 44 | setLoading(false); 45 | }); 46 | } 47 | 48 | const handleSearch = ev => { 49 | if (ev.key === 'Enter') { 50 | setLoading(true); 51 | getSearchResults(sentence).then(res=>{ 52 | setResponse(res.data); 53 | setLoading(false); 54 | }); 55 | ev.preventDefault(); 56 | } 57 | } 58 | 59 | return ( 60 |
61 | 62 | setSentence(e.target.value)} 66 | /> 67 | 68 | 69 | 70 | 71 | {loading == true && } 72 | {response !== null && } 73 |
74 | ); 75 | } 76 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "homepage": "http://seannaren.github.io/CORD-19-ANN", 3 | "name": "template-frontend-react-redux-mui", 4 | "version": "1.0.0", 5 | "description": "A template for quickly setting up a react/redux/Material UI frontend", 6 | "engines": { 7 | "node": "12.12.0", 8 | "npm": "6.11.3" 9 | }, 10 | "keywords": [ 11 | "Material UI", 12 | "React", 13 | "Redux" 14 | ], 15 | "main": "src/index.js", 16 | "scripts": { 17 | "start": "./node_modules/webpack/bin/webpack.js -p --progress && node server.js", 18 | "build": "npm run clean && webpack --mode production --progress", 19 | "clean": "rm -rf build && mkdir build", 20 | "dev": "webpack-dev-server --mode development --progress", 21 | "doc": "webpack-dev-server --mode development --progress --host 0.0.0.0", 22 | "dev:hot": "webpack-dev-server --mode development --progress --hot", 23 | "serve": "serve build", 24 | "setup": "touch .env && npm install && npm audit fix", 25 | "test": "jest", 26 | "predeploy": "npm run build", 27 | "deploy": "gh-pages -d build" 28 | }, 29 | "author": "Maher Atashfaraz", 30 | "license": "ISC", 31 | "private": true, 32 | "dependencies": { 33 | "@date-io/moment": "^1.3.9", 34 | "@material-ui/core": "^4.2.0", 35 | "@material-ui/icons": "^4.5.1", 36 | "@material-ui/pickers": "^3.2.2", 37 | "axios": "^0.19.0", 38 | "clsx": "^1.0.4", 39 | "core-js": "^3.1.4", 40 | "d3": "^5.9.7", 41 | "dotenv-defaults": "^1.0.2", 42 | "dotenv-webpack": "^1.7.0", 43 | "moment": "^2.24.0", 44 | "prop-types": "^15.7.2", 45 | "react": "^16.8.6", 46 | "react-dom": "^16.8.6" 47 | }, 48 | "devDependencies": { 49 | "@babel/core": "^7.5.4", 50 | "@babel/plugin-proposal-class-properties": "^7.5.0", 51 | "@babel/plugin-syntax-dynamic-import": "^7.2.0", 52 | "@babel/plugin-transform-runtime": "^7.5.0", 53 | "@babel/preset-env": "^7.5.4", 54 | "@babel/preset-react": "^7.0.0", 55 | "babel-eslint": "^10.0.2", 56 | "babel-loader": "^8.0.6", 57 | "babel-plugin-import": "^1.12.0", 58 | "brotli-webpack-plugin": "^1.1.0", 59 | "copy-webpack-plugin": "^5.0.3", 60 | "css-loader": "^2.1.1", 61 | "eslint": "^5.16.0", 62 | "eslint-loader": "^2.2.1", 63 | "eslint-plugin-react": "^7.19.0", 64 | "eslint-plugin-react-hooks": "^1.6.1", 65 | "file-loader": "^3.0.1", 66 | "html-loader": "^0.5.5", 67 | "html-webpack-plugin": "^3.2.0", 68 | "jest": "^24.9.0", 69 | "json-loader": "^0.5.7", 70 | "less": "^3.9.0", 71 | "less-loader": "^5.0.0", 72 | "node-sass": "^4.12.0", 73 | "sass-loader": "^7.1.0", 74 | "serve": "^11.2.0", 75 | "style-loader": "^0.23.1", 76 | "webpack": "^4.35.3", 77 | "webpack-cli": "^3.3.6", 78 | "webpack-dev-server": "^3.7.2", 79 | "gh-pages": "^2.0.1" 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /index_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import tornado.ioloop 5 | import tornado.web 6 | 7 | from cord_ann.embeddings import EmbeddingModel 8 | from cord_ann.index import search_args, Index, paths_from_dataset_path 9 | from cord_ann.mapping import load_sentence_to_article_mapping, load_metadata 10 | 11 | 12 | class QueryHandler(tornado.web.RequestHandler): 13 | def set_default_headers(self): 14 | self.set_header("Access-Control-Allow-Origin", "*") 15 | self.set_header("Access-Control-Allow-Headers", 
"x-requested-with") 16 | self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') 17 | 18 | def data_received(self, chunk): 19 | pass 20 | 21 | def initialize(self, args, model, index): 22 | self.args = args 23 | self.index = index 24 | self.model = model 25 | 26 | def options(self): 27 | # no body 28 | self.set_status(204) 29 | self.finish() 30 | 31 | def post(self): 32 | if self.request.headers.get("Content-Type", "").startswith("application/json"): 33 | sentences = json.loads(self.request.body) 34 | is_json = True 35 | else: 36 | sentences = [self.request.body.decode("utf-8")] 37 | is_json = False 38 | search_embeddings = self.model.encode_sentences(sentences=sentences) 39 | results = self.index.search_index(sentences=sentences, 40 | search_embeddings=search_embeddings) 41 | results = results if is_json else results[0] # Assume if not json, it was a single sentence 42 | self.write(json.dumps(results)) 43 | self.finish() 44 | 45 | 46 | def make_app(args): 47 | return tornado.web.Application([ 48 | ('/query', QueryHandler, args), 49 | ], debug=True, autoreload=False) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser = search_args(parser) 55 | parser.add_argument('--port', default=8888, type=int) 56 | parser.add_argument('--address', default="") 57 | args = parser.parse_args() 58 | articles_path, _, mapping_path, metadata_path = paths_from_dataset_path(args.dataset_path) 59 | sent_article_mapping = load_sentence_to_article_mapping(mapping_path) 60 | metadata = load_metadata(metadata_path) 61 | 62 | model = EmbeddingModel(model_name_or_path=args.model_name_or_path, 63 | device=args.device, 64 | batch_size=args.batch_size, 65 | show_progress_bar=not args.silent) 66 | 67 | index = Index(index_path=args.index_path, 68 | index_type=args.index_type, 69 | articles_path=articles_path, 70 | mapping=sent_article_mapping, 71 | metadata=metadata, 72 | k=args.k, 73 | num_workers=args.num_workers) 74 | 75 | app_arguments = { 76 | 'args': args, 77 | 'model': model, 78 | 'index': index, 79 | } 80 | app = make_app(args=app_arguments) 81 | print("Index Server is listening...") 82 | app.listen(port=args.port, 83 | address=args.address) 84 | tornado.ioloop.IOLoop.current().start() 85 | -------------------------------------------------------------------------------- /frontend/src/components/ResultShow.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { withStyles } from '@material-ui/core/styles'; 3 | import PropTypes from 'prop-types'; 4 | import { makeStyles } from '@material-ui/core/styles'; 5 | import MuiExpansionPanelDetails from '@material-ui/core/ExpansionPanelDetails'; 6 | import MuiExpansionPanel from '@material-ui/core/ExpansionPanel'; 7 | import MuiExpansionPanelSummary from '@material-ui/core/ExpansionPanelSummary'; 8 | import Typography from '@material-ui/core/Typography'; 9 | import ExpandMoreIcon from '@material-ui/icons/ExpandMore'; 10 | import CustomizedTabs from './ResultCard' 11 | import MuiLink from '@material-ui/core/Link'; 12 | 13 | const useStyles = makeStyles((theme) => ({ 14 | root: { 15 | width: '100%', 16 | padding: '10px 5px' 17 | }, 18 | heading: { 19 | fontSize: theme.typography.pxToRem(18), 20 | fontWeight: 500, 21 | flexBasis: '33.33%', 22 | flexShrink: 0, 23 | }, 24 | secondaryHeading: { 25 | fontSize: theme.typography.pxToRem(13), 26 | color: theme.palette.text.secondary, 27 | }, 28 | searchTitle: { 29 | padding: '28px 5px' 30 | }, 31 | authors: 
{ 32 | display: 'flex', 33 | flexDirection: 'row', 34 | flexWrap: 'wrap', 35 | paddingBottom: '16px' 36 | }, 37 | expansionDetails: { 38 | display: 'flex', 39 | flexDirection: 'column', 40 | } 41 | })); 42 | 43 | const ExpansionPanel = withStyles({ 44 | root: { 45 | marginBottom: '10px', 46 | }, 47 | expanded: {}, 48 | })(MuiExpansionPanel); 49 | 50 | const ExpansionPanelSummary = withStyles({ 51 | content: { 52 | 'display': 'initial', 53 | 'padding': '0px 24px 0px 24px' 54 | } 55 | })(MuiExpansionPanelSummary); 56 | 57 | const ExpansionPanelDetails = withStyles({ 58 | root: { 59 | 'padding': '8px 49px 24px', 60 | } 61 | })(MuiExpansionPanelDetails); 62 | 63 | const Link = withStyles({ 64 | root: { 65 | color: 'blue' 66 | } 67 | })(MuiLink); 68 | 69 | export default function ResultShow(props) { 70 | const classes = useStyles(); 71 | const [expanded, setExpanded] = React.useState(false); 72 | // const index = 0; 73 | 74 | const handleChange = (panel) => (event, isExpanded) => { 75 | setExpanded(isExpanded ? panel : false); 76 | }; 77 | 78 | return ( 79 |
80 | {props.response.hits.map((item, index) => { 81 | return ( 82 | } 84 | aria-controls="panel1bh-content" 85 | id="panel1bh-header" 86 | > 87 | {item.title === "" ? 'No Title' : item.title} 88 | Matching Sentence: {item.sentence} 89 | 90 | 91 |
92 | {item.authors.map((author, i) => { 93 | return ( 94 | 95 | {author.first} {author.last}{i === item.authors.length - 1 ? null : ','}  ) 96 | })} 97 |
98 | {item.metadata === undefined ? 'N/A' : item.metadata.url} 99 |
100 | {CustomizedTabs(item)} 101 |
102 |
103 |
) 104 | })} 105 |
106 | ); 107 | } 108 | 109 | 110 | 111 | ResultShow.propTypes = { 112 | response: PropTypes.object, 113 | } -------------------------------------------------------------------------------- /frontend/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const BrotliPlugin = require('brotli-webpack-plugin'); 3 | const Dotenv = require('dotenv-webpack'); 4 | const HtmlWebPackPlugin = require('html-webpack-plugin'); 5 | const webpack = require('webpack'); 6 | 7 | require('dotenv-defaults').config({ 8 | path: __dirname + '/.env', 9 | encoding: 'utf8', 10 | defaults: __dirname + '/.env.defaults', 11 | }); 12 | 13 | const commonConfig = { 14 | entry: './src/index.js', 15 | module: { 16 | rules: [ 17 | { 18 | test: /\.(js|jsx)$/, 19 | exclude: /node_modules/, 20 | use: ['babel-loader'], 21 | }, 22 | { 23 | test: /\.js$/, 24 | exclude: /node_modules/, 25 | use: ['babel-loader', 26 | { 27 | 'loader': 'eslint-loader', options: { 28 | emitWarning: true 29 | }, 30 | }], 31 | }, 32 | { 33 | test: /\.html$/, 34 | exclude: /template\.html$/, 35 | use: { 36 | loader: 'html-loader', 37 | options: { 38 | minimize: true, 39 | removeComments: false, 40 | collapseWhitespace: true, 41 | }, 42 | } 43 | }, 44 | { 45 | test: /\.(png|jpg|gif)$/, 46 | use: 'file-loader', 47 | }, 48 | { 49 | test: /\.less$/, 50 | use: ['style-loader', 'css-loader', 'less-loader'], 51 | }, 52 | { 53 | test: /\.scss$/, 54 | use: ['style-loader', 'css-loader', 'sass-loader'], 55 | }, 56 | { 57 | test: /\.css$/, 58 | use: ['style-loader', 'css-loader'], 59 | }, 60 | { 61 | test: /\.(woff|woff2|ttf|eot|svg|otf)(\?v=\d+\.\d+\.\d+)?$/, 62 | use: { 63 | loader: 'file-loader', 64 | options: { 65 | name: '[name].[ext]', 66 | outputPath: 'fonts/', 67 | } 68 | } 69 | }, 70 | { 71 | test: /\.geojson$/, 72 | loader: 'json-loader', 73 | }, 74 | ] 75 | }, 76 | plugins: [ 77 | new Dotenv({ 78 | defaults: __dirname + '/.env.defaults', 79 | path: __dirname + '/.env', 80 | }), 81 | new HtmlWebPackPlugin({ 82 | template: './src/index.html', 83 | title: process.env.APP_TITLE, 84 | filename: 'index.html', 85 | }), 86 | ], 87 | }; 88 | 89 | if (process.env.BABEL_USE_MATERIAL_UI_ES_MODULES) { 90 | commonConfig.resolve = { 91 | alias: { 92 | '@material-ui/core': '@material-ui/core/es', 93 | }, 94 | } 95 | } 96 | 97 | module.exports = (env, argv = { mode: 'development' }) => { 98 | switch (argv.mode) { 99 | default: 100 | case 'development': { 101 | return { 102 | ...commonConfig, 103 | devServer: { 104 | compress: process.env.WEBPACK_DEV_SERVER_COMPRESS === 'true', 105 | historyApiFallback: true, 106 | host: process.env.WEBPACK_DEV_SERVER_HOST, 107 | open: process.env.WEBPACK_DEV_SERVER_OPEN === 'true', 108 | port: process.env.WEBPACK_DEV_SERVER_PORT, 109 | }, 110 | devtool: 'eval-source-map', 111 | } 112 | } 113 | 114 | case 'production': { 115 | return { 116 | ...commonConfig, 117 | output: { 118 | path: path.resolve(__dirname, 'build'), 119 | filename: '[name].[contenthash].bundle.js', 120 | }, 121 | optimization: { 122 | splitChunks: { 123 | cacheGroups: { 124 | commons: { 125 | test: /[\\/]node_modules[\\/]/, 126 | name: 'vendors', 127 | chunks: 'all', 128 | }, 129 | }, 130 | }, 131 | }, 132 | plugins: [ 133 | new HtmlWebPackPlugin({ 134 | template: './src/index.html', 135 | title: process.env.APP_TITLE, 136 | filename: 'index.html', 137 | }), 138 | new webpack.DefinePlugin({ 139 | 'process.env': { 140 | 'API_BASE_URL': JSON.stringify(process.env.API_BASE_URL) 141 | } 142 
| }), 143 | new BrotliPlugin({ 144 | asset: '[path].br[query]', 145 | test: /\.(js|css|html|svg)$/, 146 | threshold: 10240, 147 | minRatio: 0.8, 148 | }), 149 | ], 150 | } 151 | } 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /frontend/src/components/ResultCard.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import PropTypes from 'prop-types'; 3 | import { makeStyles, withStyles } from '@material-ui/core/styles'; 4 | import Tabs from '@material-ui/core/Tabs'; 5 | import Tab from '@material-ui/core/Tab'; 6 | import Typography from '@material-ui/core/Typography'; 7 | import Box from '@material-ui/core/Box'; 8 | 9 | const AntTabs = withStyles({ 10 | root: { 11 | borderBottom: '1px solid #e8e8e8', 12 | }, 13 | indicator: { 14 | backgroundColor: '#1890ff', 15 | }, 16 | })(Tabs); 17 | 18 | const AntTab = withStyles((theme) => ({ 19 | root: { 20 | variant: 'button', 21 | textTransform: 'none', 22 | minWidth: 72, 23 | fontWeight: theme.typography.fontWeightRegular, 24 | marginRight: theme.spacing(4), 25 | '&:hover': { 26 | color: '#40a9ff', 27 | opacity: 1, 28 | }, 29 | '&$selected': { 30 | color: '#1890ff', 31 | fontWeight: theme.typography.fontWeightMedium, 32 | }, 33 | '&:focus': { 34 | color: '#40a9ff', 35 | }, 36 | }, 37 | selected: {}, 38 | }))((props) => ); 39 | 40 | const useStyles = makeStyles((theme) => ({ 41 | root: { 42 | flexGrow: 1, 43 | }, 44 | padding: { 45 | padding: theme.spacing(3), 46 | }, 47 | demo1: { 48 | backgroundColor: theme.palette.background.paper, 49 | }, 50 | demo2: { 51 | backgroundColor: '#2e1534', 52 | }, 53 | })); 54 | 55 | function TabPanel(props) { 56 | const { children, value, index, ...other } = props; 57 | 58 | return ( 59 | 69 | ); 70 | } 71 | 72 | TabPanel.propTypes = { 73 | children: PropTypes.node, 74 | index: PropTypes.any.isRequired, 75 | value: PropTypes.any.isRequired, 76 | }; 77 | 78 | function a11yProps(index) { 79 | return { 80 | id: `wrapped-tab-${index}`, 81 | 'aria-controls': `wrapped-tabpanel-${index}`, 82 | }; 83 | } 84 | 85 | function splitParagraph(paragraph, sentence) { 86 | /** 87 | Splits the paragraph into three sections to allow the matching sentence to be highlighted. 88 | Note that this means multiple occurrences of the sentence are not highlighted. 89 | */ 90 | var index = paragraph.indexOf(sentence); 91 | return [paragraph.substring(0, index), 92 | paragraph.substring(index, index + sentence.length), 93 | paragraph.substring(index + sentence.length)]; 94 | } 95 | 96 | export default function CustomizedTabs(item) { 97 | const classes = useStyles(); 98 | const [value, setValue] = React.useState(0); 99 | 100 | const handleChange = (event, newValue) => { 101 | setValue(newValue); 102 | }; 103 | 104 | return ( 105 |
106 |
107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | {splitParagraph(item.paragraph.text, item.sentence)[0]} 115 | 116 | 117 | {splitParagraph(item.paragraph.text, item.sentence)[1]} 118 | 119 | 120 | {splitParagraph(item.paragraph.text, item.sentence)[2]} 121 | 122 | 123 | 124 | 125 | {item.abstract.length === 0 ? 'No Abstract Available' : item.abstract[0].text} 126 | 127 | 128 | 129 | 130 | Distance: {Number((item.distance).toFixed(4))} 131 | 132 | 133 | 134 |
135 |
136 | ); 137 | } -------------------------------------------------------------------------------- /frontend/README.md: -------------------------------------------------------------------------------- 1 | # CORD-19-ANN 2 | 3 | Below are instructions on how to set up the front-end for the search capabilities. 4 | 5 | ## Table of Contents: 6 | - [Table of Contents:](#Table-of-Contents) 7 | - [Prerequisites](#Prerequisites) 8 | - [Setup](#Setup) 9 | - [Babel & Webpack Configuration](#Babel--Webpack-Configuration) 10 | - [Configuration using Environment Variables](#Configuration-using-Environment-Variables) 11 | - [Babel Configuration](#Babel-Configuration) 12 | - [Development and Production Build Modes](#Development-and-Production-Build-Modes) 13 | - [Development Builds](#Development-Builds) 14 | - [Production Build](#Production-Build) 15 | - [Docker Deployment](#Docker-Deployment) 16 | 17 | ## Prerequisites 18 | The following tools are required to set up or run this template: 19 | - [node](https://nodejs.org/) v12.4.0 20 | - [npm](https://www.npmjs.com/) v6.9.0 **or** [Yarn](https://yarnpkg.com/) v1.16.0 21 | 22 | ## Setup 23 | 1. Clone the repo 24 | 2. Navigate to the root directory of this new repo and run the command below: 25 | ```shell 26 | npm install 27 | ``` 28 | 3. Modify the `.env.defaults` file to point to the URL of the search index. We assume you have run `index_server.py` on the appropriate node as explained in the README. 29 | 30 | A blank *.env* file is also created in the root directory (more on environment variables [here](#configuration-using-environment-variables)). 31 | 32 | ## Babel & Webpack Configuration 33 | 34 | ### Configuration using Environment Variables 35 | The *webpack.config.js* uses the [dotenv-webpack](https://www.npmjs.com/package/dotenv-webpack) plugin alongside [dotenv-defaults](https://www.npmjs.com/package/dotenv-defaults) to expose any environment variables set in the *.env* or *.env.defaults* file in the root directory. These variables are available within the webpack configuration itself and also anywhere within the application in the format `process.env.[VARIABLE]`. 36 | 37 | The root *.env.defaults* file must only contain non-sensitive configuration variables and should be considered safe to commit to any version control system. 38 | 39 | Any sensitive details, such as passwords or private keys, should be stored in the root *.env* file. This file should **never** be committed and accordingly is already listed within the root *.gitignore* file. The *.env* file also serves to override any non-sensitive variables defined within the root *.env.defaults* file. 40 | 41 | ### Babel Configuration 42 | This project uses [Babel](https://babeljs.io/) to convert, transform and polyfill ECMAScript 2015+ code into a backwards-compatible version of JavaScript. 43 | 44 | As with *webpack.config.js*, the environment variables defined in *.env.defaults* and *.env* are available within *babel.config.js*, where Babel's configuration is programmatically created. 45 | 46 | ### Development and Production Build Modes 47 | In *webpack.config.js* there is a common configuration object for both `development` and `production` builds called `commonConfig`, which mainly handles loading for various file types. Extend this object with any modules or plugins that apply to both build modes.
48 | 49 | Within a switch statement after the `commonConfig` object, individual properties for both the `development` and `production` builds can be defined separately as needed. 50 | 51 | Webpack will use the `--mode` flag it receives when run to determine which build to bundle. This flag defaults to `development`. 52 | 53 | #### Development Builds 54 | A `development` build can be run in the following ways: 55 | ```shell 56 | // with npm 57 | npm run dev 58 | // or 59 | npm run dev:hot 60 | 61 | // with yarn 62 | yarn run dev 63 | // or 64 | yarn run dev:hot 65 | ``` 66 | Both the `dev` and `dev:hot` scripts use [webpack-dev-server](https://webpack.js.org/configuration/dev-server/) to serve a `development` build locally. Some options are already configured in the root *.env.defaults* and can be overridden in the root *.env* file or within the root *webpack.config.js* itself as required. 67 | 68 | #### Production Build 69 | A `production` build can be run with the following command: 70 | ``` 71 | // with npm 72 | npm run build 73 | 74 | // with yarn 75 | yarn run build 76 | ``` 77 | The `build` script will write an optimized and compressed build to the 'build' directory. If a different directory is required, you will need to change the *package.json* `clean` script as well as the `output.path` property of *webpack.config.js* accordingly. 78 | 79 | 80 | ## Docker Deployment 81 | In order to run the project locally on Docker, you first need to run the backend; follow the instructions in the root README to do so. Once your backend is up and running, open a new tab in your terminal, navigate to the root of the project and run `docker-compose up` (you need to have Docker installed on your machine before running this command). If the command executes successfully, you should be able to reach the project in your browser at `localhost:8080`.
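For reference, a minimal root *.env* override pointing the frontend at a running index server could look like the sketch below; the host, port and `/query` path are placeholders and should match wherever `index_server.py` is listening (it defaults to port 8888 and serves the `/query` route):

```shell
# .env - local override of .env.defaults, never committed
API_BASE_URL=http://your-index-server-host:8888/query
```

`getSearchResults.js` posts directly to `process.env.API_BASE_URL`, so the value must be the full URL reachable from the browser.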
83 | -------------------------------------------------------------------------------- /cord_ann/index.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | class Index: 9 | def __init__(self, index_path, index_type, articles_path, mapping, metadata, k, num_workers): 10 | self.index = self.load_index(index_path, index_type) 11 | self.index_type = index_type 12 | self.articles_path = articles_path 13 | self.mapping = mapping 14 | self.metadata = metadata 15 | self.k = k 16 | self.num_workers = num_workers 17 | 18 | def load_index(self, index_path, index_type): 19 | if index_type == 'nmslib': 20 | import nmslib 21 | index = nmslib.init(method='hnsw', space='cosinesimil') 22 | index.loadIndex(index_path) 23 | elif index_type == 'faiss': 24 | import faiss 25 | index = faiss.read_index(index_path) 26 | else: 27 | raise TypeError('Index type can only be faiss or nmslib.') 28 | return index 29 | 30 | def search_index(self, sentences, search_embeddings, return_batch_ids=False): 31 | if self.index_type == 'nmslib': 32 | batch = self.index.knnQueryBatch(search_embeddings, 33 | k=self.k, 34 | num_threads=self.num_workers) 35 | batch = np.array(batch) 36 | batch_ids = batch[:, 0].astype(np.int) 37 | batch_distances = batch[:, 1].astype(np.float32) 38 | elif self.index_type == 'faiss': 39 | batch_distances, batch_ids = self.index.search(np.array(search_embeddings), k=self.k) 40 | else: 41 | raise TypeError('Index type can only be faiss or nmslib.') 42 | 43 | results = self._format_results(batch_ids=batch_ids, 44 | batch_distances=batch_distances, 45 | sentences=sentences, 46 | articles_path=self.articles_path, 47 | mapping=self.mapping) 48 | if return_batch_ids: 49 | return results, batch_ids 50 | return results 51 | 52 | def _load_article(self, articles_path, paper_id): 53 | json_path = Path(articles_path) / (paper_id + '.json') 54 | with json_path.open() as f: 55 | article = json.load(f) 56 | return article 57 | 58 | def _find_metadata(self, paper_id): 59 | metadata = self.metadata[self.metadata['sha'] == paper_id] 60 | if len(metadata) == 1: 61 | metadata = metadata.iloc[0].to_dict() 62 | return { 63 | 'doi': metadata['doi'] if not pd.isna(metadata['doi']) else 'N/A', 64 | 'url': metadata['url'] if not pd.isna(metadata['url']) else 'N/A', 65 | 'journal': metadata['journal'] if not pd.isna(metadata['journal']) else 'N/A', 66 | 'publish_time': metadata['publish_time'] if not pd.isna(metadata['publish_time']) else 'N/A', 67 | } 68 | else: 69 | return None # No metadata was found 70 | 71 | def _extract_k_hits(self, ids, distances, sentence, articles_path, sent_article_mapping): 72 | extracted = { 73 | "query": sentence, 74 | "hits": [] 75 | } 76 | 77 | for id, distance in zip(ids, distances): 78 | mapping = sent_article_mapping[id] 79 | paragraph_idx = mapping["paragraph_idx"] 80 | sentence_idx = mapping["sentence_idx"] 81 | paper_id = mapping["paper_id"] 82 | article = self._load_article(articles_path=articles_path, 83 | paper_id=paper_id) 84 | hit = { 85 | 'title': article['metadata']['title'], 86 | 'authors': article['metadata']['authors'], 87 | 'paragraph': article['body_text'][paragraph_idx], 88 | 'sentence': article['body_text'][paragraph_idx]["sentences"][sentence_idx], 89 | 'abstract': article['abstract'], 90 | 'distance': float(distance), 91 | } 92 | metadata = self._find_metadata(paper_id) 93 | if metadata: 94 | hit['metadata'] = metadata 95 | 
extracted["hits"].append(hit) 96 | return extracted 97 | 98 | def _format_results(self, batch_ids, batch_distances, sentences, articles_path, mapping): 99 | return [self._extract_k_hits(ids=batch_ids[x], 100 | distances=batch_distances[x], 101 | sentence=query_sentence, 102 | articles_path=articles_path, 103 | sent_article_mapping=mapping) for x, query_sentence in enumerate(sentences)] 104 | 105 | 106 | def search_args(parser): 107 | parser.add_argument('--index_path', default="index", 108 | help='Path to the created index') 109 | parser.add_argument('--index_type', default="nmslib", type=str, choices=["nmslib", "faiss"], 110 | help='Type of index') 111 | parser.add_argument('--dataset_path', default="cord_19_dataset_formatted/", 112 | help='Path to the extracted dataset') 113 | parser.add_argument('--model_name_or_path', default='bert-base-nli-mean-tokens') 114 | parser.add_argument('--batch_size', default=8, type=int, 115 | help='Batch size for the transformer model encoding') 116 | parser.add_argument('--num_workers', default=8, type=int, 117 | help='Number of workers to use when parallelizing the index search') 118 | parser.add_argument('--k', default=10, type=int, 119 | help='The top K hits to return from the index') 120 | parser.add_argument('--device', default='cpu', 121 | help='Set to cuda to use the GPU') 122 | parser.add_argument('--silent', action="store_true", 123 | help='Turn off progress bar when searching') 124 | return parser 125 | 126 | 127 | def paths_from_dataset_path(dataset_path): 128 | """ 129 | Creates paths to the files required for searching the index. 130 | :param dataset_path: The path to the extracted dataset. 131 | :return: Paths to various important files/folders for searching the index. 132 | """ 133 | dataset_path = Path(dataset_path) 134 | articles_path = dataset_path / 'articles/' 135 | sentences_path = dataset_path / 'cord_19_sentences.txt' 136 | mapping_path = dataset_path / 'cord_19_sent_to_article_mapping.json' 137 | metadata_path = dataset_path / 'metadata.csv' 138 | return articles_path, sentences_path, mapping_path, metadata_path 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CORD-19-ANN 2 | 3 | ![cord_website](imgs/cord_ann_example.gif) 4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/137jbpY3yQJGSzlLHZGUYBk5F78bwuqKJ) [![Open In Colab](https://github.com/aleen42/badges/raw/master/src/medium.svg)](https://medium.com/@seannaren/cord-19-ann-semantic-search-engine-using-s-bert-aebc5bcc5442?sk=92ea4a22df3cd1343c86a1e880b78f6f) [GitHub Pages](https://seannaren.github.io/CORD-19-ANN/) 6 | 7 | This repo contains the scripts and models to search [CORD-19](https://pages.semanticscholar.org/coronavirus-research) using [S-BERT](https://github.com/UKPLab/sentence-transformers) embeddings via [nmslib](https://github.com/nmslib/nmslib/blob/master/python_bindings/README.md) or [faiss](https://github.com/facebookresearch/faiss). 8 | 9 | Sentence embeddings are not perfect for searching (see [this issue](https://github.com/UKPLab/sentence-transformers/issues/174)) however can provide insight into the data that basic search functionality cannot. There is still room to improve the retrieval of relevant documents. 10 | 11 | We're not versed in the medical field, so any feedback or improvements we deeply encourage in the form of issues/PRs! 
12 | 
13 | We've included pre-trained models and the FAISS index to start your own server with instructions below.
14 | 
15 | Finally, we provide a front-end that can be used to search through the dataset and extract information via a UI. Instructions and installation for the front-end can be found [here](frontend/README.md).
16 | 
17 | We are currently hosting the server on a GCP instance; if anyone can contribute a more permanent hosting solution it would be appreciated.
18 | 
19 | ## Installation
20 | 
21 | ### Source
22 | We assume you have installed PyTorch and the necessary CUDA packages from [here](https://pytorch.org/). We suggest using Conda to make installation easier.
23 | ```
24 | # Install FAISS
25 | conda install faiss-cpu -c pytorch # Other instructions can be found at https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
26 | 
27 | git clone https://github.com/SeanNaren/CORD-19-ANN.git --recursive
28 | cd CORD-19-ANN/
29 | pip install -r requirements.txt
30 | pip install .
31 | ```
32 | 
33 | ### Docker
34 | 
35 | We also provide a Docker container:
36 | 
37 | ```
38 | docker pull seannaren/cord-19-ann
39 | sudo docker run -it --net=host --ipc=host --entrypoint=/bin/bash --rm seannaren/cord-19-ann
40 | ```
41 | 
42 | ## Download Models
43 | 
44 | We currently offer sentence models trained on [BlueBERT](https://github.com/ncbi-nlp/bluebert) (base uncased model) and [BioBERT](https://github.com/naver/biobert-pretrained) (base cased model) with the appropriate metadata/index. We currently serve S-BlueBERT, however the models are interchangeable.
45 | 
46 | 
47 | ### Download S-BERT Models and Search Index
48 | 
49 | Download the corresponding model and index file. We suggest using S-BioBERT and assume you have done so for the subsequent commands; the models are interchangeable, however.
50 | 
51 | | Model | Index | Test MedNLI Accuracy | Test STS Benchmark Cosine Pearson |
52 | |-----------------------------|--------------------------------|-----------------|------------------------------|
53 | | [S-BioBERT Base Cased](https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/s-biobert_base_cased_mli.tar.gz) | [BioBERT_faiss_PCAR128_SQ8](https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/biobert_mli_faiss_PCAR128_SQ8) | 0.7482 | 0.7122 |
54 | | [S-BlueBERT Base Uncased](https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/s-bluebert_base_uncased_mli.tar.gz) | [BlueBERT_faiss_PCAR128_SQ8](https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/bluebert_mli_faiss_PCAR128_SQ8) | 0.7525 | 0.6923 |
55 | | S-BERT Base Cased | | 0.5689 | 0.7265 |
56 | 
57 | 
58 | ### Download Metadata
59 | ```
60 | wget https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/cord_19_dataset_formatted_2020_03_27.tar.gz
61 | tar -xzvf cord_19_dataset_formatted_2020_03_27.tar.gz cord_19_dataset_formatted/
62 | ```
63 | 
64 | ## Searching the Index
65 | 
66 | We assume you've chosen the S-BioBERT model; it should be straightforward to swap in any of the other pre-trained models offered in this repo by modifying the paths below.
67 | 
68 | We recommend using the server, but we also offer a simple script to search given a text file of sentences:
69 | 
70 | ```
71 | echo "These RNA transcripts may be spliced to give rise to mRNAs encoding the envelope (Env) glycoproteins (Fig. 1a)" > sentences.txt
1a)" > sentences.txt 72 | python search_index.py --index_path biobert_mli_faiss_PCAR128_SQ8 --index_type faiss --model_name_or_path s-biobert_base_cased_mli/ --dataset_path cord_19_dataset_formatted/ --input_path sentences.txt --output_path output.json 73 | ``` 74 | 75 | #### Using the server 76 | 77 | To start the server: 78 | ``` 79 | YOUR_IP=0.0.0.0 80 | YOUR_PORT=1337 81 | python index_server.py --index_path biobert_mli_faiss_PCAR128_SQ8 --index_type faiss --model_name_or_path s-biobert_base_cased_mli/ --dataset_path cord_19_dataset_formatted/ --address $YOUR_IP --port $YOUR_PORT --silent 82 | ``` 83 | 84 | To test the server: 85 | ``` 86 | curl --header "Content-Type: application/json" \ 87 | --request POST \ 88 | --data '["These RNA transcripts may be spliced to give rise to mRNAs encoding the envelope (Env) glycoproteins (Fig. 1a)"]' \ 89 | http://$YOUR_IP:$YOUR_PORT/query 90 | ``` 91 | 92 | ### Output Format 93 | 94 | The output from the index is a JSON object containing the top K hits from the index, an example of the API is given below: 95 | 96 | ``` 97 | [ 98 | { 99 | "query": "These RNA transcripts may be spliced to give rise to mRNAs encoding the envelope (Env) glycoproteins (Fig. 1a)", 100 | "hits": [ 101 | { 102 | "title": "Title", 103 | "authors": [ 104 | "..." 105 | ], 106 | "abstract": [ 107 | "..." 108 | ], 109 | "paragraph": "Paragraph that included the hit", 110 | "sentence": "The semantically similar sentence", 111 | "distance": 42, 112 | } 113 | ] 114 | } 115 | ] 116 | ``` 117 | 118 | ## Creating the Index from scratch 119 | 120 | The process requires a GPU enabled node such as a GCP n8 node with a nvidia-tesla-v100 to generate the embeddings, with at-least 20GB RAM. 121 | 122 | ### Preparing the dataset 123 | 124 | Currently we tokenize at the sentence level using SciSpacy, however future work may look into using paragraph level tokenization. 125 | 126 | ``` 127 | mkdir datasets/ 128 | python download_data.py 129 | python extract_sentences.py --num_workers 16 130 | ``` 131 | 132 | ### Generating embeddings 133 | 134 | #### Using fine-tuned BioBERT/BlueBERT 135 | 136 | Using sentence-transformers we can fine-tune either model. BlueBERT offers only uncased models whereas BioBERT offer a cased model. We've converted them into PyTorch format and included them in releases, to download: 137 | 138 | ``` 139 | wget https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/s-biobert_base_cased_mli.tar.gz 140 | wget https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/s-bluebert_base_uncased_mli.tar.gz 141 | tar -xzvf s-biobert_base_cased_mli.tar.gz 142 | tar -xzvf s-bluebert_base_uncased_mli.tar.gz 143 | ``` 144 | 145 | ##### Using Pre-trained BioBERT/BlueBERT 146 | 147 | ``` 148 | python generate_embeddings.py --model_name_or_path s-biobert_base_cased_mli/ --embedding_path biobert_embeddings.npy --device cuda --batch_size 256 # If you want to use biobert 149 | python generate_embeddings.py --model_name_or_path s-bluebert_base_uncased_mli/ --embedding_path bluebert_embeddings.npy --device cuda --batch_size 256 # If you want to use bluebert 150 | ``` 151 | 152 | #### Using pre-trained S-BERT models 153 | 154 | You can also use the standard pre-trained model from the S-BERT repo like below, however we suggest using the fine-tuned models offered in this repo. 
155 | 
156 | ```
157 | python generate_embeddings.py --model_name_or_path bert-base-nli-mean-tokens --embedding_path pretrained_embeddings.npy --device cuda --batch_size 256
158 | ```
159 | 
160 | ##### Training the model from scratch
161 | 
162 | This takes a few hours on a V100 GPU.
163 | 
164 | If you'd like to include the MedNLI dataset during training, you'll need to download the dataset from [here](https://physionet.org/content/mednli/1.0.0/). Getting access requires credentialing, which takes some effort and a waiting period of up to two weeks.
165 | 
166 | Once trained, the model is saved to the `output/` folder by default. Inside you'll find checkpoints such as `output/training_nli/biobert-2020-03-30_10-51-49/` after training has finished. Use this as the model path when generating your embeddings.
167 | 
168 | ```
169 | wget https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/biobert_cased_v1.1.tar.gz
170 | wget https://github.com/SeanNaren/CORD-19-ANN/releases/download/V1.0/bluebert_base_uncased.tar.gz
171 | tar -xzvf biobert_cased_v1.1.tar.gz
172 | tar -xzvf bluebert_base_uncased.tar.gz
173 | 
174 | mkdir datasets/
175 | python sentence-transformers/examples/datasets/get_data.py --output_path datasets/
176 | python sentence-transformers/examples/training_nli_transformers.py --model_name_or_path biobert_cased_v1.1/
177 | python sentence-transformers/examples/training_nli_transformers.py --model_name_or_path bluebert_base_uncased/ --do_lower_case
178 | 
179 | # Training with MedNLI
180 | python sentence-transformers/examples/training_nli_transformers.py --model_name_or_path biobert_cased_v1.1/ --mli_dataset_path path/to/mednli/
181 | python sentence-transformers/examples/training_nli_transformers.py --model_name_or_path bluebert_base_uncased/ --mli_dataset_path path/to/mednli/ --do_lower_case
182 | ```
183 | 
184 | To exclude MedNLI from training but still evaluate on it (this still requires the MedNLI dataset), use the `--exclude_mli` flag.
185 | 
186 | ### Create the Index
187 | 
188 | Either faiss or nmslib can be used via the `--index_type` parameter below. We've exposed the FAISS config string for modifying the index. More details about selecting an index can be found [here](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index).
189 | 
190 | ```
191 | python create_index.py --output_path index --embedding_path pretrained_embeddings.npy --index_type faiss # Swap to biobert_embeddings.npy or bluebert_embeddings.npy if using the fine-tuned embeddings
192 | ```
193 | 
194 | ### Clustering
195 | 
196 | We also took the example clustering script from sentence-transformers and added it to this repository for use with the pre-trained models. An example is below:
197 | 
198 | ```
199 | python cluster_sentences.py --input_path sentences.txt --model_name_or_path biobert_cased_v1.1/ --device cpu
200 | ```
201 | 
202 | There is also a more interactive version available using the Google Colab demo: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/137jbpY3yQJGSzlLHZGUYBk5F78bwuqKJ)
203 | 
204 | ## Acknowledgements
205 | 
206 | Thanks to the authors of the various libraries that made this possible!
207 | 
208 | - [sentence-transformers](https://github.com/UKPLab/sentence-transformers)
209 | - [cord-19](https://pages.semanticscholar.org/coronavirus-research)
210 | - [scibert](https://github.com/allenai/scibert)
211 | - [nmslib](https://github.com/nmslib/nmslib)
212 | - [FAISS](https://github.com/facebookresearch/faiss)
--------------------------------------------------------------------------------