├── orion ├── core │ ├── README.md │ ├── orms │ │ ├── __init__.py │ │ ├── tests │ │ │ └── test_mag_orm.py │ │ └── es_mapping.py │ ├── operators │ │ ├── faiss_index_task.py │ │ ├── open_access_journals_task.py │ │ ├── affiliation_type_task.py │ │ ├── mag_geocode_task.py │ │ ├── collect_wb_indicators_task.py │ │ ├── topic_filtering_task.py │ │ ├── infer_gender_task.py │ │ ├── country_details_task.py │ │ ├── postgresql2es_task.py │ │ ├── mag_parse_task.py │ │ ├── text2vec_task.py │ │ ├── dim_reduction_task.py │ │ ├── mag_collect_task.py │ │ └── draw_collaboration_graph_task.py │ └── dags │ │ ├── tutorial.py │ │ └── orion.py ├── __init__.py └── packages │ ├── utils │ ├── tests │ │ ├── test_s3_utils.py │ │ ├── test_nlp_utils.py │ │ ├── test_batches.py │ │ └── test_utils.py │ ├── nlp_utils.py │ ├── batches.py │ ├── s3_utils.py │ └── utils.py │ ├── projection │ ├── faiss_index.py │ ├── dim_reduction.py │ └── tests │ │ └── test_projections.py │ ├── mag │ ├── create_tables.py │ ├── tests │ │ ├── test_query_mag.py │ │ └── test_parsing_mag_data.py │ ├── parsing_mag_data.py │ └── query_mag_api.py │ ├── geo │ ├── enrich_countries.py │ ├── geocode.py │ └── tests │ │ ├── test_geocode.py │ │ └── test_enrich_countries.py │ ├── nlp │ ├── tests │ │ └── test_text2vec.py │ └── text2vec.py │ ├── gender │ ├── query_gender_api.py │ └── tests │ │ └── test_query_gender_api.py │ ├── README.md │ └── metrics │ ├── tests │ └── test_metrics.py │ └── metrics.py ├── schema ├── schema.png ├── country_collaboration.yaml ├── mag_field_of_study_metadata.yaml ├── mag_author_affiliation.yaml ├── mag_paper_fields_of_study.yaml ├── rca_country.yaml ├── mag_authors.yaml ├── rca_affiliation.yaml ├── mag_affiliation.yaml ├── mag_field_of_study_hierarchy.yaml ├── mag_fields_of_study.yaml ├── author_gender.yaml ├── doc_vectors.yaml ├── mag_paper_authors.yaml ├── mag_paper_journal.yaml ├── README.md ├── gender_diversity_country.yaml ├── mag_paper_conferences.yaml ├── research_diversity_country.yaml ├── geocoded_places.yaml └── mag_papers.yaml ├── boto.cfg.example ├── entrypoint ├── .travis.yml ├── requirements.txt ├── .env.example ├── docker-compose.yml ├── LICENSE ├── setup.py ├── Dockerfile ├── README.md ├── .gitignore └── CONTRIBUTING.md /orion/core/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /orion/core/orms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /schema/schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orion-search/orion/HEAD/schema/schema.png -------------------------------------------------------------------------------- /boto.cfg.example: -------------------------------------------------------------------------------- 1 | # Example ~/.boto file 2 | [Credentials] 3 | aws_access_key_id = foo 4 | aws_secret_access_key = bar -------------------------------------------------------------------------------- /schema/country_collaboration.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | country_a: 5 | type: string 6 | country_b: 7 | type: string 8 | weight: 9 | type: integer 10 | -------------------------------------------------------------------------------- /schema/mag_field_of_study_metadata.yaml: 
-------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | level: 5 | type: integer 6 | frequency: 7 | type: integer 8 | description: Raw frequency of a field of study. 9 | -------------------------------------------------------------------------------- /schema/mag_author_affiliation.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | affiliation_id: 5 | type: integer 6 | author_id: 7 | type: integer 8 | -------------------------------------------------------------------------------- /schema/mag_paper_fields_of_study.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | paper_id: 5 | type: integer 6 | field_of_study_id: 7 | type: integer 8 | -------------------------------------------------------------------------------- /schema/rca_country.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | rca_sum: 5 | type: float 6 | entity: 7 | type: integer 8 | year: 9 | type: string 10 | field_of_study_id: 11 | type: integer 12 | -------------------------------------------------------------------------------- /schema/mag_authors.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | name: 7 | type: string 8 | description: Author full name. 9 | -------------------------------------------------------------------------------- /schema/rca_affiliation.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | rca_sum: 5 | type: float 6 | entity: 7 | type: integer 8 | year: 9 | type: string 10 | field_of_study_id: 11 | type: integer 12 | -------------------------------------------------------------------------------- /schema/mag_affiliation.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | affiliation: 7 | type: string 8 | description: Affiliation name. 9 | -------------------------------------------------------------------------------- /entrypoint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Load DAGs examples (default: Yes) 4 | if [[ -z "$AIRFLOW__CORE__LOAD_EXAMPLES" && "${LOAD_EX:=n}" == n ]]; then 5 | AIRFLOW__CORE__LOAD_EXAMPLES=False 6 | fi 7 | 8 | airflow initdb 9 | airflow scheduler & 10 | exec airflow webserver -------------------------------------------------------------------------------- /schema/mag_field_of_study_hierarchy.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | parent_id: 5 | type: integer[] 6 | description: Parent IDs of a FoS. 7 | child_id: 8 | type: integer[] 9 | description: Child IDs of a FoS. 
10 | -------------------------------------------------------------------------------- /schema/mag_fields_of_study.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | name: 7 | type: string 8 | description: Name of a field of study (aka paper keyword). 9 | -------------------------------------------------------------------------------- /schema/author_gender.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | full_name: 5 | type: float 6 | first_name: 7 | type: string 8 | gender: 9 | type: string 10 | samples: 11 | type: integer 12 | probability: 13 | type: float 14 | -------------------------------------------------------------------------------- /orion/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import yaml 4 | from pathlib import Path 5 | 6 | 7 | # Define project base directory 8 | project_dir = Path(__file__).resolve().parents[1] 9 | 10 | # Model config 11 | with open(project_dir / "model_config.yaml", "rt") as f: 12 | config = yaml.safe_load(f.read()) 13 | -------------------------------------------------------------------------------- /schema/doc_vectors.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | doi: 5 | type: string 6 | vector_2d: 7 | type: float[] 8 | description: 2D vector representation of an abstract. 9 | vector_3d: 10 | type: float[] 11 | description: 3D vector representation of an abstract. 12 | -------------------------------------------------------------------------------- /schema/mag_paper_authors.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | author_id: 5 | type: integer 6 | paper_id: 7 | type: integer 8 | order: 9 | type: integer 10 | description: Order of the authors in a paper. 11 | -------------------------------------------------------------------------------- /schema/mag_paper_journal.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | description: Journal ID. 7 | paper_id: 8 | type: integer 9 | journal_name: 10 | type: string 11 | description: Journal name. 12 | -------------------------------------------------------------------------------- /schema/README.md: -------------------------------------------------------------------------------- 1 | # Schema # 2 | There is a separate `yaml` file with the schema for each PostgreSQL table. 3 | 4 | ## Entity relation diagram ## 5 | 6 | Colour codes: 7 | * **Blue**: MAG tables. 8 | * **Red**: GenderAPI table. 9 | * **Orange**: Google Places API table. 10 | * **Black**: Generated features. 
11 | 12 | ![ER Diagram](schema.png?raw=true) 13 | -------------------------------------------------------------------------------- /schema/gender_diversity_country.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | female_share: 5 | type: float 6 | description: Proportion of female co-authors. 7 | year: 8 | type: string 9 | entity: 10 | type: string 11 | description: Country name. 12 | field_of_study_id: 13 | type: integer 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: python 3 | python: 4 | - '3.6' 5 | - '3.7' 6 | # command to install dependencies 7 | install: 8 | - pip install -r requirements.txt 9 | - pip install -e . 10 | services: 11 | - postgresql 12 | env: 13 | - postgres=postgres+psycopg2://postgres@localhost/postgres 14 | # command to run tests 15 | script: 16 | - pytest 17 | -------------------------------------------------------------------------------- /schema/mag_paper_conferences.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | description: Conference ID. 7 | paper_id: 8 | type: integer 9 | conference_name: 10 | type: string 11 | description: Conference name. 12 | -------------------------------------------------------------------------------- /schema/research_diversity_country.yaml: -------------------------------------------------------------------------------- 1 | schema: 2 | id: 3 | type: integer 4 | shannon_diversity: 5 | type: float 6 | description: Shannon diversity index. 7 | simpson_e_diversity: 8 | type: float 9 | description: Simpson E diversity index. 10 | simpson_diversity: 11 | type: float 12 | description: Simpson diversity index. 13 | year: 14 | type: string 15 | entity: 16 | type: string 17 | description: Country name. 
18 | field_of_study_id: 19 | type: integer 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.8.0 2 | torch==1.3.1 3 | numpy==1.18.2 4 | elasticsearch==7.6.0 5 | elasticsearch-dsl==7.0.0 6 | requests-aws4auth==0.9 7 | toolz==0.10.0 8 | WTForms==2.2.1 9 | apache-airflow==1.10.14 10 | boto3==1.10.27 11 | setuptools==41.4.0 12 | SQLAlchemy==1.3.9 13 | retrying==1.3.3 14 | alphabet_detector==0.0.7 15 | requests==2.22.0 16 | pytest==5.2.2 17 | psycopg2-binary==2.8.3 18 | umap-learn==0.3.10 19 | scikit-learn==0.22.1 20 | faiss-cpu==1.6.1 21 | sentencepiece==0.1.85 22 | pandas-datareader==0.8.1 23 | werkzeug==0.16.0 24 | python-dotenv==0.10.3 25 | sentence-transformers==0.2.6.1 26 | pyarrow==0.17.0 27 | -------------------------------------------------------------------------------- /orion/packages/utils/tests/test_s3_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from unittest.mock import patch 4 | from orion.packages.utils.s3_utils import store_on_s3 5 | from orion.packages.utils.s3_utils import create_s3_bucket 6 | 7 | 8 | @patch("orion.packages.utils.s3_utils.boto3") 9 | def test_store_on_s3(boto3): 10 | """Data is archived, uploaded, and the floor is swept""" 11 | store_on_s3("test_data", "bucket", "prefix") 12 | 13 | boto3.resource.assert_called_with("s3") 14 | boto3.resource().Object.assert_called_with("bucket", "prefix.pickle") 15 | 16 | 17 | @patch("orion.packages.utils.s3_utils.boto3") 18 | def test_create_s3_bucket(boto3): 19 | create_s3_bucket("test_bucket") 20 | 21 | boto3.resource.assert_called_with("s3") 22 | -------------------------------------------------------------------------------- /orion/packages/projection/faiss_index.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | 4 | 5 | def faiss_index(vectors, ids=None): 6 | """Create a brute-force FAISS index. 7 | 8 | Args: 9 | vectors (:obj:`numpy.array` of `float`): Usually document vectors 10 | ids (:obj:`list` of `int`, None): FAISS creates a numerical index which 11 | can be substituted by a list of ids. Here, it can be paper IDs. 
12 | 13 | Returns: 14 | index (`faiss.swigfaiss.IndexIDMap`) 15 | 16 | """ 17 | index = faiss.IndexFlatL2(vectors.shape[1]) 18 | if ids: 19 | index = faiss.IndexIDMap(index) 20 | index.add_with_ids(vectors, np.array([i for i in ids])) 21 | else: 22 | index.add(vectors) 23 | 24 | return index 25 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # postgresdb 2 | = 3 | postgres= 4 | 5 | # mag 6 | mag_api_key= 7 | 8 | # google 9 | google_api_key= 10 | 11 | # gender api 12 | gender_api= 13 | 14 | # elasticsearch 15 | es_host= 16 | es_port= 17 | es_index= 18 | region= 19 | 20 | # docker-compose.yml ENV variables 21 | DB_HOST= 22 | DB_PORT= 23 | DB_USER= 24 | DB_PASS= 25 | MAIN_DB= 26 | # This should be left as airflow 27 | DB_NAME=airflow 28 | 29 | AWS_ACCESS_KEY_ID= 30 | AWS_SECRET_ACCESS_KEY= 31 | AWS_DEFAULT_REGION= 32 | -------------------------------------------------------------------------------- /orion/core/orms/tests/test_mag_orm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import unittest 3 | from sqlalchemy.orm import sessionmaker 4 | from sqlalchemy import create_engine 5 | from orion.core.orms.mag_orm import Base 6 | from dotenv import load_dotenv, find_dotenv 7 | import os 8 | 9 | load_dotenv(find_dotenv()) 10 | 11 | 12 | class TestMag(unittest.TestCase): 13 | """Check that the MAG ORM works as expected""" 14 | 15 | db_config = os.getenv("postgres") 16 | engine = create_engine(db_config) 17 | Session = sessionmaker(engine) 18 | 19 | def setUp(self): 20 | """Create the temporary table""" 21 | Base.metadata.create_all(self.engine) 22 | 23 | def tearDown(self): 24 | """Drop the temporary table""" 25 | Base.metadata.drop_all(self.engine) 26 | 27 | def test_build(self): 28 | pass 29 | 30 | 31 | if __name__ == "__main__": 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /orion/core/orms/es_mapping.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from elasticsearch_dsl import Document, Date, Integer, Keyword, Text, Nested, Long 3 | 4 | 5 | class PaperES(Document): 6 | # ES mappings 7 | original_title = Text(analyzer="standard") 8 | abstract = Text(analyzer="standard") 9 | year = Keyword() 10 | publication_date = Date() 11 | citations = Integer() 12 | fields_of_study = Nested( 13 | properties={"name": Text(fields={"raw": Keyword()}), "id": Long()}, 14 | include_in_parent=True, 15 | ) 16 | authors = Nested( 17 | properties={ 18 | "name": Text(fields={"raw": Keyword()}), 19 | "affiliation": Text(fields={"raw": Keyword()}), 20 | } 21 | ) 22 | 23 | class Index: 24 | # Index name 25 | name = "mag_papers" 26 | 27 | settings = {"number_of_shards": 2, "number_of_replicas": 0} 28 | 29 | def save(self, **kwargs): 30 | return super(PaperES, self).save(**kwargs) 31 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # This is for running the Orion ETL locally 2 | version: "3" 3 | services: 4 | postgres: 5 | image: postgres:12.2-alpine 6 | container_name: postgres 7 | environment: 8 | POSTGRES_PASSWORD: ${DB_PASS} 9 | POSTGRES_USER: ${DB_USER} 10 | POSTGRES_DB: ${DB_NAME} 11 | ports: 12 | - "5432:5432" 13 | airflow: 14 | container_name: airflow 15 
| depends_on: 16 | - postgres 17 | build: 18 | context: . 19 | dockerfile: Dockerfile 20 | args: 21 | DB_HOST: ${DB_HOST} 22 | DB_PORT: ${DB_PORT} 23 | DB_NAME: ${DB_NAME} 24 | DB_USER: ${DB_USER} 25 | DB_PASS: ${DB_PASS} 26 | environment: 27 | - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} 28 | - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} 29 | - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION} 30 | # AIRFLOW__CORE__DAGS_FOLDER: "/airflow/orion/core/dags" 31 | env_file: 32 | - .env 33 | ports: 34 | - "8080:8080" 35 | restart: always 36 | -------------------------------------------------------------------------------- /orion/packages/mag/create_tables.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import psycopg2 3 | from sqlalchemy import create_engine, exc 4 | from orion.core.orms.mag_orm import Base 5 | from dotenv import load_dotenv, find_dotenv 6 | import os 7 | 8 | load_dotenv(find_dotenv()) 9 | 10 | 11 | def create_db_and_tables(db): 12 | # Try to create the database if it doesn't already exist. 13 | try: 14 | db_config = os.getenv("postgres") 15 | engine = create_engine(db_config) 16 | conn = engine.connect() 17 | conn.execute("commit") 18 | conn.execute(f"create database {db}") 19 | conn.close() 20 | except exc.DBAPIError as e: 21 | if isinstance(e.orig, psycopg2.errors.DuplicateDatabase): 22 | logging.info(e) 23 | else: 24 | logging.error(e) 25 | raise 26 | 27 | db_config = os.getenv(db) 28 | engine = create_engine(db_config) 29 | Base.metadata.create_all(engine) 30 | 31 | 32 | if __name__ == "__main__": 33 | create_db_and_tables("test_db") 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 orion-search 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /orion/packages/utils/nlp_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def clean_name(name): 5 | """Find the full fornames and assign a NaN value to the rest. 6 | Args: 7 | name (:obj:`str`): Forename of a person. 8 | 9 | Returns: 10 | (:obj:`str`) or (:obj:`np.nan`), depending on the string. 
11 | 12 | """ 13 | name = re.sub('[)(#"!:]|-{2,}', "", name) 14 | first_name = " ".join(name.split(" ")[:-1]) 15 | last_name = name.split(" ")[-1] 16 | 17 | # Remove initials 18 | first_name = re.sub("(.?)\.", "", first_name).strip() 19 | first_name = re.sub("[A-Z]*-[A-Z]\\b", "", first_name) 20 | first_name = " ".join( 21 | [string for string in first_name.split(" ") if len(string) > 1] 22 | ) 23 | if len(first_name) > 1: 24 | return " ".join([first_name, last_name]).strip() 25 | else: 26 | return None 27 | 28 | 29 | def identity_tokenizer(tokens): 30 | """Passes tokens without processing. Used in a CountVectorizer. 31 | 32 | Args: 33 | tokens (:obj:`list`) 34 | 35 | Returns: 36 | tokens (:obj:`list`) 37 | 38 | """ 39 | return tokens 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_namespace_packages 3 | 4 | exclude = ["tests*"] 5 | common_kwargs = dict( 6 | version="0.1.0", 7 | license="MIT", 8 | # install_requires=required, 9 | long_description=open("README.md").read(), 10 | url="https://github.com/kstathou/orion", 11 | author="Kostas Stathoulopoulos", 12 | author_email="k.stathou@gmail.com", 13 | maintainer="Kostas Stathoulopoulos", 14 | maintainer_email="k.stathou@gmail.com", 15 | classifiers=[ 16 | "Development Status :: 1 - Planning", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: Science/Research", 19 | "License :: OSI Approved :: MIT License", 20 | "Natural Language :: English", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3.7", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | ], 25 | python_requires=">3.6", 26 | include_package_data=False, 27 | ) 28 | 29 | setup( 30 | name="orion", 31 | packages=find_namespace_packages(where="orion.*", exclude=exclude), 32 | **common_kwargs 33 | ) 34 | -------------------------------------------------------------------------------- /orion/packages/geo/enrich_countries.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def parse_country_details(response, pos=0): 5 | """Parses the country details from the restcountries API. 6 | More info here: https://github.com/apilayer/restcountries 7 | 8 | Args: 9 | response (:obj:`list` of `dict`): API response. 10 | 11 | Returns: 12 | d (dict): Parsed API response. 13 | 14 | """ 15 | remappings = { 16 | "alpha2Code": "alpha2Code", 17 | "alpha3Code": "alpha3Code", 18 | "name": "name", 19 | "capital": "capital", 20 | "region": "region", 21 | "subregion": "subregion", 22 | "population": "population", 23 | } 24 | 25 | d = {k: response[pos][v] for k, v in remappings.items()} 26 | 27 | return d 28 | 29 | 30 | def get_country_details(country_name): 31 | """Fetches country details from restcountries API. 32 | 33 | Args: 34 | country_name (str): Country name. 35 | 36 | Returns: 37 | (dict) API response. 
38 | 39 | """ 40 | r = requests.get(f"https://restcountries.eu/rest/v2/name/{country_name}") 41 | r.raise_for_status() 42 | 43 | return r.json() 44 | -------------------------------------------------------------------------------- /orion/packages/projection/dim_reduction.py: -------------------------------------------------------------------------------- 1 | import umap 2 | 3 | 4 | def umap_embeddings( 5 | data, n_neighbors=15, min_dist=0.1, n_components=2, metric="cosine" 6 | ): 7 | """Finds a low dimensional representation of the input embeddings. 8 | More info: https://umap-learn.readthedocs.io/en/latest/api.html#umap 9 | 10 | Args: 11 | data (:obj:`numpy.array` of :obj:`float`): Input vectors. 12 | n_neighbors (int): The size of local neighborhood (in terms of number 13 | of neighboring sample points) used for manifold approximation. 14 | min_dist (float): The effective minimum distance between embedded points. 15 | n_components (int): The dimension of the space to embed into. 16 | metric (str): The metric to use to compute distances in high dimensional space. 17 | 18 | Returns: 19 | (numpy.ndarray) 20 | 21 | """ 22 | reducer = umap.UMAP( 23 | n_neighbors=n_neighbors, 24 | min_dist=min_dist, 25 | n_components=n_components, 26 | metric=metric, 27 | random_state=42, 28 | ) 29 | 30 | fitted_reducer = reducer.fit(data) 31 | return fitted_reducer, fitted_reducer.transform(data) 32 | -------------------------------------------------------------------------------- /orion/packages/nlp/tests/test_text2vec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from orion.packages.nlp.text2vec import Text2Vector 9 | 10 | 11 | def test_text_encoder(): 12 | text = "Hello world. foo bar!" 13 | tv = Text2Vector() 14 | result = tv.encode_text(text) 15 | 16 | expected_shape = torch.Size([1, 8]) 17 | expected_values = torch.tensor([[101, 7592, 2088, 1012, 29379, 3347, 999, 102]]) 18 | 19 | assert result.shape == expected_shape 20 | assert torch.all(result.eq(expected_values)) 21 | # Ensure we're using the right model 22 | assert result.numpy()[0][0] == expected_values.numpy()[0][0] 23 | assert result.numpy()[0][1] == expected_values.numpy()[0][1] 24 | 25 | 26 | def test_feature_extraction(): 27 | input_ids = torch.tensor([[2, 10975, 126, 9, 4310, 111, 748, 187, 3]]) 28 | 29 | tv = Text2Vector() 30 | result = tv.feature_extraction(input_ids) 31 | expected_shape = torch.Size([1, 9, 768]) 32 | 33 | assert result.shape == expected_shape 34 | 35 | 36 | def test_average_vectors(): 37 | vectors = torch.tensor([[[2, 3, 4], [7, 8, 9]]]) 38 | 39 | tv = Text2Vector() 40 | result = tv.average_vectors(vectors) 41 | expected_result = np.array([4.5, 5.5, 6.5]) 42 | 43 | assert all(result == expected_result) 44 | -------------------------------------------------------------------------------- /schema/geocoded_places.yaml: -------------------------------------------------------------------------------- 1 | API: Google Places API 2 | endpoint: https://maps.googleapis.com/maps/api/place/details/json? 3 | schema: 4 | id: 5 | type: string 6 | description: Unique ID. Google Places API calls this Place ID 7 | affiliation_id: 8 | type: integer 9 | lat: 10 | type: float 11 | description: Latitude. 12 | lng: 13 | type: float 14 | description: Longitude. 15 | address: 16 | type: string 17 | description: Formatted (full) address of a place. 
18 | name: 19 | type: string 20 | description: Place name. 21 | types: 22 | type: string[] 23 | description: Type of the place. You can find all the possible types here: https://developers.google.com/places/web-service/details#PlaceDetailsResults 24 | website: 25 | type: string 26 | description: Website URL of the place. 27 | postal_town: 28 | type: string 29 | description: Town of the place. For example, London. Do not use this. ~80% of it is missing. 30 | administrative_area_level_2: 31 | type: string 32 | description: City of the place. For example, Greater London. 33 | administrative_area_level_1: 34 | type: string 35 | description: Region of the place. For example, England. 36 | country: 37 | type: string 38 | description: Country of the place. 39 | -------------------------------------------------------------------------------- /orion/packages/utils/tests/test_nlp_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from orion.packages.utils.nlp_utils import clean_name 4 | from orion.packages.utils.nlp_utils import identity_tokenizer 5 | 6 | 7 | def test_clean_name_from_double_initials(): 8 | name = "A. B. FooBar" 9 | result = clean_name(name) 10 | 11 | expected_result = None 12 | 13 | assert result == expected_result 14 | 15 | 16 | def test_clean_name_from_single_initial(): 17 | name = "A. FooBar" 18 | result = clean_name(name) 19 | 20 | expected_result = None 21 | 22 | assert result == expected_result 23 | 24 | 25 | def test_clean_name_from_single_initial_variation(): 26 | name = "Foo A. FooBar" 27 | result = clean_name(name) 28 | 29 | expected_result = "Foo FooBar" 30 | 31 | assert result == expected_result 32 | 33 | 34 | def test_clean_name_symbols(): 35 | name = "허준 ( Joon Hur ) 이용구 ( Yong Goo Lee )" 36 | result = clean_name(name) 37 | 38 | expected_result = "허준 Joon Hur 이용구 Yong Goo Lee" 39 | 40 | assert result == expected_result 41 | 42 | 43 | def test_clean_name(): 44 | name = "Foo FooBar" 45 | result = clean_name(name) 46 | 47 | expected_result = "Foo FooBar" 48 | 49 | assert result == expected_result 50 | 51 | 52 | def test_identity_tokenizer(): 53 | data = [1, 2, 3] 54 | expected_result = [1, 2, 3] 55 | result = identity_tokenizer(data) 56 | 57 | assert result == expected_result 58 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Start with a base image 2 | FROM python:3-onbuild as base 3 | 4 | ARG DB_HOST=${DB_HOST} 5 | ARG DB_PORT=5432 6 | ARG MAIN_DB=${MAIN_DB} 7 | ARG DB_USER=${DB_USER} 8 | ARG DB_PASS=${DB_PASS} 9 | 10 | ENV orion_db=postgres+psycopg2://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${MAIN_DB} 11 | 12 | # Used for unit tests 13 | ENV orion_test=postgres+psycopg2://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/postgres 14 | 15 | # Stores Airflow task run metadata 16 | ENV airflow_db=postgres+psycopg2://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/airflow 17 | 18 | # Airflow setup 19 | ENV AIRFLOW_HOME /airflow/orion/core 20 | ENV AIRFLOW__CORE__SQL_ALCHEMY_CONN ${airflow_db} 21 | ENV AIRFLOW__CORE__EXECUTOR LocalExecutor 22 | ENV AIRFLOW__CORE__DAGS_FOLDER /airflow/orion/core/dags 23 | ENV AIRFLOW__CORE__LOAD_EXAMPLES=False 24 | ENV AIRFLOW_PORT=8080 25 | 26 | ENV AWS_ACCESS_KEY=${AWS_ACCESS_KEY} 27 | ENV AWS_ACCESS_SECRET=${AWS_ACCESS_SECRET} 28 | ENV AWS_REGION=${AWS_REGION} 29 | 30 | FROM base as builder 31 | 32 | RUN mkdir /install 33 | WORKDIR /install 34 | 
35 | COPY requirements.txt /requirements.txt 36 | 37 | RUN pip install --upgrade pip \ 38 | && pip install --install-option=“--prefix=/install” -r /requirements.txt 39 | 40 | FROM base 41 | 42 | WORKDIR /airflow 43 | 44 | COPY --from=builder /install /usr/local 45 | COPY . ./ 46 | COPY entrypoint /entrypoint 47 | # COPY boto.cfg /etc/boto.cfg 48 | 49 | 50 | RUN pip install -e . 51 | 52 | EXPOSE 8080 53 | 54 | ENTRYPOINT ["/entrypoint"] 55 | -------------------------------------------------------------------------------- /orion/packages/utils/batches.py: -------------------------------------------------------------------------------- 1 | """Utilties for working with batches.""" 2 | import boto3 3 | 4 | # import time 5 | import pickle 6 | 7 | 8 | def split_batches(data, batch_size): 9 | """Breaks batches down into chunks consumable by the database. 10 | 11 | Args: 12 | data (:obj:`iterable`): Iterable containing data items 13 | batch_size (int): number of items per batch. 14 | 15 | Returns: 16 | (:obj:`list` of :obj:`pickle`): Yields a batch at a time. 17 | 18 | """ 19 | batch = [] 20 | for row in data: 21 | batch.append(row) 22 | if len(batch) == batch_size: 23 | yield batch 24 | batch.clear() 25 | if len(batch) > 0: 26 | yield batch 27 | 28 | 29 | def put_s3_batch(data, bucket, prefix): 30 | """Writes out a batch of data to s3 as pickle, so it can be picked up by the 31 | batchable task. 32 | 33 | Args: 34 | data (:obj:`list` of :obj:`str`): A batch of records. 35 | bucket (str): Name of the s3 bucket. 36 | prefix (str): Identifier for the batched object. 37 | 38 | Returns: 39 | (str): name of the file in the s3 bucket (key). 40 | 41 | """ 42 | # Pickle data 43 | data = pickle.dumps(data) 44 | 45 | # s3 setup 46 | s3 = boto3.resource("s3") 47 | 48 | # timestamp = str(time.time()).replace('.', '') 49 | filename = f"{prefix}.pickle" 50 | obj = s3.Object(bucket, filename) 51 | obj.put(Body=data) 52 | 53 | return filename 54 | -------------------------------------------------------------------------------- /schema/mag_papers.yaml: -------------------------------------------------------------------------------- 1 | API: Academic Knowledge API 2 | endpoint: https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate 3 | schema: 4 | id: 5 | type: integer 6 | prob: 7 | type: float 8 | description: Probability of this being the right response for a given query. Not being used. 9 | title: 10 | type: string 11 | description: Paper title. 12 | publication_type: 13 | type: string 14 | description: Publication type. It can be one of the following: (0:Unknown, 1:Journal article, 2:Patent, 3:Conference paper, 4:Book chapter, 5:Book, 6:Book reference entry, 7:Dataset, 8:Repository) 15 | year: 16 | type: string 17 | description: Publication year of the paper. 18 | date: 19 | type: string # YYYY-MM-DD 20 | description: Publication date. 21 | citations: 22 | type: integer 23 | description: Number of paper citations. 24 | references: 25 | type: integer[] 26 | description: List of references in the paper. 27 | doi: 28 | type: string 29 | description: Paper DOI (this is a unique identifier). 30 | publisher: 31 | type: string 32 | description: Publisher name. 33 | bibtext_doc_type: 34 | type: string 35 | description: BibTex document type. It can be one of the following: ('a':Journal article, 'b':Book, 'c':Book chapter, 'p':Conference paper) 36 | inverted_abstract: 37 | type: string 38 | description: Inverted abstract. 
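# Note (illustrative, not part of the schema): MAG returns abstracts as an inverted index
# (token -> list of positions) rather than plain text. A minimal Python sketch for restoring
# readable text, assuming the usual {"IndexLength": ..., "InvertedIndex": {...}} payload shape:
#   def restore_abstract(inverted):
#       tokens = [""] * inverted["IndexLength"]
#       for token, positions in inverted["InvertedIndex"].items():
#           for pos in positions:
#               tokens[pos] = token
#       return " ".join(tokens)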
39 | -------------------------------------------------------------------------------- /orion/packages/gender/query_gender_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import logging 3 | from retrying import retry 4 | from requests.exceptions import HTTPError 5 | 6 | ENDPOINT = "https://gender-api.com/v2/gender" 7 | 8 | 9 | @retry(stop_max_attempt_number=2) 10 | def query_gender_api(full_names, auth_token): 11 | """Infers the gender by querying a full name to the GenderAPI. 12 | 13 | Args: 14 | full_names (:obj:`list` of str): Full names. 15 | auth_token (str): Authorization token. 16 | 17 | Returns: 18 | (json) Person's gender. 19 | 20 | """ 21 | headers = { 22 | "Authorization": f"Bearer {auth_token}", 23 | "Content-Type": "application/json", 24 | } 25 | 26 | data = [{"full_name": fn} for fn in full_names] 27 | 28 | r = requests.post(ENDPOINT, json=data, headers=headers) 29 | try: 30 | r.raise_for_status() 31 | return r.json() 32 | except HTTPError as h: 33 | logging.info(full_names, h) 34 | return None 35 | 36 | 37 | def parse_response(response): 38 | """Parses the GenderAPI response. 39 | 40 | Args: 41 | id_ (int): Author MAG ID. 42 | name (str): Full or first name used to query the GenderAPI. 43 | response (dict): GenderAPI response. 44 | 45 | Returns: 46 | (dict) Parsed response. 47 | 48 | """ 49 | d = {} 50 | d["full_name"] = response["input"]["full_name"] 51 | d["samples"] = response["details"]["samples"] 52 | d["first_name"] = response["first_name"] 53 | d["probability"] = response["probability"] 54 | d["gender"] = response["gender"] 55 | 56 | return d 57 | -------------------------------------------------------------------------------- /orion/packages/gender/tests/test_query_gender_api.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | from orion.packages.gender.query_gender_api import query_gender_api 5 | from orion.packages.gender.query_gender_api import parse_response 6 | 7 | 8 | @mock.patch("orion.packages.gender.query_gender_api.requests.post", autospec=True) 9 | def test_query_gender_api_sends_correct_request(mocked_requests): 10 | auth_token = 123 11 | full_name = ["foo bar"] 12 | query_gender_api(full_name, auth_token) 13 | 14 | expected_call_args = mock.call( 15 | "https://gender-api.com/v2/gender", 16 | headers={ 17 | "Authorization": f"Bearer {auth_token}", 18 | "Content-Type": "application/json", 19 | }, 20 | json=[{"full_name": "foo bar"}], 21 | ) 22 | assert mocked_requests.call_args == expected_call_args 23 | 24 | 25 | def test_parse_response_from_gender_api(): 26 | response = { 27 | "input": {"full_name": "Foo Bar"}, 28 | "details": { 29 | "credits_used": 1, 30 | "duration": "71ms", 31 | "samples": 106011, 32 | "country": None, 33 | "first_name_sanitized": "foo", 34 | }, 35 | "result_found": True, 36 | "first_name": "Foo", 37 | "probability": 0.98, 38 | "gender": "female", 39 | } 40 | 41 | result = parse_response(response) 42 | 43 | expected_result = { 44 | "full_name": "Foo Bar", 45 | "samples": 106011, 46 | "first_name": "Foo", 47 | "probability": 0.98, 48 | "gender": "female", 49 | } 50 | 51 | assert result == expected_result 52 | -------------------------------------------------------------------------------- /orion/packages/utils/tests/test_batches.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | from 
orion.packages.utils.batches import split_batches 5 | 6 | 7 | @pytest.fixture 8 | def generate_test_data(): 9 | def _generate_test_data(n): 10 | return [{"data": "foo", "other": "bar"} for i in range(n)] 11 | 12 | return _generate_test_data 13 | 14 | 15 | @pytest.fixture 16 | def generate_test_set_data(): 17 | def _generate_test_set_data(n): 18 | return {f"data{i}" for i in range(n)} 19 | 20 | return _generate_test_set_data 21 | 22 | 23 | def test_split_batches_when_data_is_smaller_than_batch_size(generate_test_data): 24 | yielded_batches = [] 25 | for batch in split_batches(generate_test_data(200), batch_size=1000): 26 | yielded_batches.append(batch) 27 | 28 | assert len(yielded_batches) == 1 29 | 30 | 31 | def test_split_batches_yields_multiple_batches_with_exact_fit(generate_test_data): 32 | yielded_batches = [] 33 | for batch in split_batches(generate_test_data(2000), batch_size=1000): 34 | yielded_batches.append(batch) 35 | 36 | assert len(yielded_batches) == 2 37 | 38 | 39 | def test_split_batches_yields_multiple_batches_with_remainder(generate_test_data): 40 | yielded_batches = [] 41 | for batch in split_batches(generate_test_data(2400), batch_size=1000): 42 | yielded_batches.append(batch) 43 | 44 | assert len(yielded_batches) == 3 45 | 46 | 47 | def test_split_batches_with_set(generate_test_set_data): 48 | yielded_batches = [] 49 | for batch in split_batches(generate_test_set_data(2400), batch_size=1000): 50 | yielded_batches.append(batch) 51 | 52 | assert len(yielded_batches) == 3 53 | -------------------------------------------------------------------------------- /orion/core/operators/faiss_index_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | FaissIndexOperator: Creates a FAISS index. It fetches vectors and paper IDs from PostgreSQL. 3 | It serialises and stores the index as a pickle in S3. 
4 | 5 | """ 6 | import logging 7 | import numpy as np 8 | from airflow.models import BaseOperator 9 | from sqlalchemy.orm import sessionmaker 10 | from sqlalchemy import create_engine 11 | from orion.core.orms.mag_orm import HighDimDocVector 12 | from airflow.utils.decorators import apply_defaults 13 | from orion.packages.utils.s3_utils import store_on_s3 14 | import faiss 15 | from orion.packages.projection.faiss_index import faiss_index 16 | 17 | 18 | class FaissIndexOperator(BaseOperator): 19 | @apply_defaults 20 | def __init__(self, db_config, bucket, *args, **kwargs): 21 | super().__init__(**kwargs) 22 | self.db_config = db_config 23 | self.bucket = bucket 24 | 25 | def execute(self, context): 26 | # Connect to postgresql 27 | engine = create_engine(self.db_config) 28 | Session = sessionmaker(bind=engine) 29 | s = Session() 30 | 31 | vectors = s.query(HighDimDocVector.vector, HighDimDocVector.id) 32 | 33 | # Load vectors 34 | vectors, ids = zip(*vectors) 35 | logging.info("Loaded document vectors") 36 | 37 | # Store vectors in an array 38 | vectors = np.array([vector for vector in vectors]).astype("float32") 39 | 40 | # Build the FAISS index with custom IDs 41 | index = faiss_index(vectors, ids) 42 | logging.info(f"Created index with {index.ntotal} elements.") 43 | 44 | # Serialise index and store it on S3 45 | store_on_s3(faiss.serialize_index(index), self.bucket, "faiss_index") 46 | -------------------------------------------------------------------------------- /orion/core/operators/open_access_journals_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAccessJournalOperator: Splits the journals to non-open access (= 0) and open access (= 1) 3 | using a seed list of tokens. The seed list can be found in `model_config.yaml`. 4 | 5 | """ 6 | import logging 7 | from sqlalchemy import create_engine 8 | from sqlalchemy.orm import sessionmaker 9 | from sqlalchemy.sql import exists 10 | from airflow.models import BaseOperator 11 | from airflow.utils.decorators import apply_defaults 12 | from orion.core.orms.mag_orm import Journal, OpenAccess 13 | import orion 14 | 15 | seed_list = orion.config["open_access"] 16 | 17 | 18 | class OpenAccessJournalOperator(BaseOperator): 19 | """Flag a journal as open access or not.""" 20 | 21 | @apply_defaults 22 | def __init__(self, db_config, seed_list=seed_list, *args, **kwargs): 23 | super().__init__(**kwargs) 24 | self.db_config = db_config 25 | self.seed_list = seed_list 26 | 27 | def _is_open_access(self, name): 28 | if name in set(seed_list): 29 | return 1 30 | else: 31 | return 0 32 | 33 | def execute(self, context): 34 | # Connect to postgresql db 35 | engine = create_engine(self.db_config) 36 | Session = sessionmaker(engine) 37 | s = Session() 38 | 39 | s.query(OpenAccess).delete() 40 | s.commit() 41 | 42 | # Get journal names and IDs 43 | journal_access = [ 44 | {"id": id, "open_access": self._is_open_access(journal_name)} 45 | for (id, journal_name) in s.query(Journal.id, Journal.journal_name) 46 | .distinct() 47 | .all() 48 | ] 49 | 50 | logging.info(f"{len(journal_access)}") 51 | 52 | # Store journal types 53 | s.bulk_insert_mappings(OpenAccess, journal_access) 54 | s.commit() 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | :exclamation: We are sunsetting Orion. You can watch a [video of the platform on YouTube](https://youtu.be/OkfkdaVINSI). 
2 | 3 | # Orion 4 | 5 | [![Build Status](https://travis-ci.org/orion-search/orion.svg?branch=dev)](https://travis-ci.org/kstathou/orion) 6 | 7 | Orion is a research measurement and knowledge discovery tool that enables you to monitor progress in science, visually explore the scientific landscape and search for relevant publications. 8 | 9 | This repo contains Orion's data collection, enrichment and analysis pipeline for scientific documents from Microsoft Academic Graph. You can find the rest of our work in the following repositories: 10 | 11 | - [Search engine](https://github.com/orion-search/search-engine) 12 | - [Web-interface](https://github.com/orion-search/orion-search.org) 13 | - [Talks, demos, papers and tutorials on Orion](https://github.com/orion-search/tutorials). Most of the content was made for presentations at venues such as the SciNLP, NetSci, IC2S2 and WOOC. 14 | - [Micro-service deployment [WIP]](https://github.com/orion-search/universe) 15 | 16 | To learn more about Orion, check out the **[documentation website](https://docs.orion-search.org/)**. 17 | 18 | Orion is open-source. If you want to use our work or parts of it, be a good citizen of the Internet and drop us an acknowledgement. We would also love to know what you are developing so get in touch! 19 | 20 | ## Installation ## 21 | 22 | 1. Clone Orion's ETL 23 | 24 | ``` bash 25 | git clone https://github.com/orion-search/orion 26 | ``` 27 | 28 | 2. Modify Orion using the `model_config.yaml` and the `.env` files as shown in [this tutorial](https://docs.orion-search.org/docs/running_etl). 29 | 3. Run Orion's ETL in docker 30 | 31 | ``` bash 32 | docker-compose up 33 | ``` 34 | 35 | 4. Access and run Orion's DAG at 36 | 37 | ``` bash 38 | http://localhost:8080/admin/ 39 | ``` 40 | 41 | ## TODO ## 42 | 43 | - Update data schema. 44 | - Change Airflow operators to kubernetes. 45 | -------------------------------------------------------------------------------- /orion/core/operators/affiliation_type_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | AffiliationTypeOperator: Splits the affiliations to industry (= 0) and non-industry (= 1) 3 | using a seed list of tokens. The seed list can be found in `model_config.yaml`. 
4 | 5 | """ 6 | import logging 7 | from sqlalchemy import create_engine, and_ 8 | from sqlalchemy.orm import sessionmaker 9 | from sqlalchemy.sql import exists 10 | from airflow.models import BaseOperator 11 | from airflow.utils.decorators import apply_defaults 12 | from orion.core.orms.mag_orm import Affiliation, AffiliationType 13 | import orion 14 | 15 | seed_list = orion.config["affiliations"]["non_profit"] 16 | 17 | 18 | class AffiliationTypeOperator(BaseOperator): 19 | """Find the type (industry, non-industry) of an affiliation.""" 20 | 21 | @apply_defaults 22 | def __init__(self, db_config, seed_list=seed_list, *args, **kwargs): 23 | super().__init__(**kwargs) 24 | self.db_config = db_config 25 | self.seed_list = seed_list 26 | 27 | def _find_academic_affiliations(self, name): 28 | if any(val in name for val in self.seed_list): 29 | return 1 30 | else: 31 | return 0 32 | 33 | def execute(self, context): 34 | # Connect to postgresql db 35 | engine = create_engine(self.db_config) 36 | Session = sessionmaker(engine) 37 | s = Session() 38 | 39 | # Get affiliation names and IDs 40 | aff_types = [ 41 | { 42 | "id": aff.id, 43 | "non_industry": self._find_academic_affiliations(aff.affiliation), 44 | } 45 | for aff in s.query(Affiliation) 46 | .filter(and_(~exists().where(Affiliation.id == AffiliationType.id))) 47 | .all() 48 | ] 49 | logging.info(f"Mapped {len(aff_types)} affiliations.") 50 | 51 | # Store affiliation types 52 | s.bulk_insert_mappings(AffiliationType, aff_types) 53 | s.commit() 54 | -------------------------------------------------------------------------------- /orion/core/operators/mag_geocode_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | GeocodingOperator: Fetches affiliation names from PostgreSQL, geocodes them and 3 | collects additional details using Google Places API. It parses the response and stores 4 | it in PostgreSQL. 5 | 6 | """ 7 | import logging 8 | from sqlalchemy import create_engine 9 | from sqlalchemy.sql import exists 10 | from sqlalchemy.orm import sessionmaker 11 | from airflow.models import BaseOperator 12 | from airflow.utils.decorators import apply_defaults 13 | from orion.packages.geo.geocode import place_by_id, place_by_name, parse_response 14 | from orion.core.orms.mag_orm import Affiliation, AffiliationLocation 15 | 16 | 17 | class GeocodingOperator(BaseOperator): 18 | """Find a place's details given its name.""" 19 | 20 | # template_fields = [''] 21 | @apply_defaults 22 | def __init__(self, db_config, subscription_key, *args, **kwargs): 23 | super().__init__(**kwargs) 24 | self.db_config = db_config 25 | self.subscription_key = subscription_key 26 | 27 | def execute(self, context): 28 | # Connect to postgresql db 29 | engine = create_engine(self.db_config) 30 | Session = sessionmaker(engine) 31 | s = Session() 32 | 33 | # Fetch affiliations that have not been geocoded yet. 
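# Note: the ~exists().where(...) filter below acts as an anti-join, keeping only affiliations
# without a matching AffiliationLocation row, so re-runs only geocode records that are still missing.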
34 | queries = s.query(Affiliation.id, Affiliation.affiliation).filter( 35 | ~exists().where(Affiliation.id == AffiliationLocation.affiliation_id) 36 | ) 37 | logging.info(f"Number of queries: {queries.count()}") 38 | 39 | for id, name in queries: 40 | r = place_by_name(name, self.subscription_key) 41 | if r is not None: 42 | response = place_by_id(r, self.subscription_key) 43 | place_details = parse_response(response) 44 | place_details.update({"affiliation_id": id}) 45 | s.add(AffiliationLocation(**place_details)) 46 | s.commit() 47 | else: 48 | continue 49 | -------------------------------------------------------------------------------- /orion/packages/projection/tests/test_projections.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from orion.packages.projection.dim_reduction import umap_embeddings 5 | from orion.packages.projection.faiss_index import faiss_index 6 | 7 | 8 | def test_umap_embeddings(): 9 | arr = np.random.uniform(0.1, 10, [10, 4]) 10 | 11 | _, result = umap_embeddings(arr) 12 | expected_shape = (10, 2) 13 | 14 | assert result.shape == expected_shape 15 | 16 | 17 | def test_faiss_index_construction_with_ids(): 18 | vectors = np.array( 19 | [ 20 | [0.1243, 1.124, 0.21456, 0.123], 21 | [1.24, 0.765, 0.987, 0.433], 22 | [0.123, 0.6543, 0.734, 0.235], 23 | ], 24 | dtype="float32", 25 | ) 26 | ids = [11, 22, 33] 27 | 28 | index = faiss_index(vectors, ids) 29 | 30 | D, I = index.search(np.array([vectors[0]]), k=3) 31 | expected_D = np.array([[0.0, 0.5029817, 2.066431]], dtype="float32") 32 | expected_I = np.array([[11, 33, 22]]) 33 | 34 | expected_index_total = 3 35 | index_total = index.ntotal 36 | 37 | assert index_total == expected_index_total 38 | np.testing.assert_almost_equal(D, expected_D) 39 | np.testing.assert_almost_equal(I, expected_I) 40 | 41 | 42 | def test_faiss_index_construction_without_ids(): 43 | vectors = np.array( 44 | [ 45 | [0.1243, 1.124, 0.21456, 0.123], 46 | [1.24, 0.765, 0.987, 0.433], 47 | [0.123, 0.6543, 0.734, 0.235], 48 | ], 49 | dtype="float32", 50 | ) 51 | 52 | index = faiss_index(vectors) 53 | 54 | D, I = index.search(np.array([vectors[0]]), k=3) 55 | expected_D = np.array([[0.0, 0.5029817, 2.066431]], dtype="float32") 56 | expected_I = np.array([[0, 2, 1]]) 57 | 58 | expected_index_total = 3 59 | index_total = index.ntotal 60 | 61 | assert index_total == expected_index_total 62 | np.testing.assert_almost_equal(D, expected_D) 63 | np.testing.assert_almost_equal(I, expected_I) 64 | -------------------------------------------------------------------------------- /orion/packages/utils/s3_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import boto3 3 | import logging 4 | 5 | 6 | def store_on_s3(data, bucket, prefix): 7 | """Writes out data to s3 as pickle, so it can be picked up by a task. 8 | 9 | Args: 10 | data (:obj:`list` of :obj:`str`): A batch of records. 11 | bucket (str): Name of the s3 bucket. 12 | prefix (str): Identifier for the batched object. 13 | 14 | Returns: 15 | (str): name of the file in the s3 bucket (key). 16 | 17 | """ 18 | # Pickle data 19 | data = pickle.dumps(data) 20 | 21 | # s3 setup 22 | s3 = boto3.resource("s3") 23 | # timestamp = str(time.time()).replace('.', '') 24 | filename = f"{prefix}.pickle" 25 | obj = s3.Object(bucket, filename) 26 | obj.put(Body=data) 27 | 28 | return filename 29 | 30 | 31 | def load_from_s3(bucket, prefix): 32 | """Loads a pickled file from s3. 
33 | 34 | Args: 35 | bucket (str): Name of the s3 bucket. 36 | prefix (str): Name of the pickled file. 37 | 38 | """ 39 | s3 = boto3.resource("s3") 40 | obj = s3.Object(bucket, f"{prefix}.pickle") 41 | return pickle.loads(obj.get()["Body"].read()) 42 | 43 | 44 | def s3_bucket_obj(bucket): 45 | """Get all objects of an S3 bucket. 46 | 47 | Args: 48 | bucket (str): Name of the s3 bucket. 49 | 50 | Returns: 51 | (`boto3.resources.collection.s3.Bucket.objectsCollection`) 52 | 53 | """ 54 | s3 = boto3.resource("s3") 55 | return list(s3.Bucket(bucket).objects.all()) 56 | 57 | 58 | def create_s3_bucket(bucket, location="eu-west-2"): 59 | """Create an s3 bucket on a given location.""" 60 | s3 = boto3.resource("s3") 61 | # Check if the bucket already exists 62 | if not s3.Bucket(bucket).creation_date: 63 | s3.create_bucket( 64 | Bucket=bucket, CreateBucketConfiguration={"LocationConstraint": location} 65 | ) 66 | else: 67 | logging.info(f"Bucket {bucket} already exists. Skipped creation.") 68 | -------------------------------------------------------------------------------- /orion/packages/mag/tests/test_query_mag.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | from orion.packages.mag.query_mag_api import build_expr 5 | from orion.packages.mag.query_mag_api import query_mag_api 6 | from orion.packages.mag.query_mag_api import dedupe_entities 7 | from orion.packages.mag.query_mag_api import build_composite_expr 8 | 9 | 10 | class TestBuildExpr: 11 | def test_build_expr_correctly_forms_query(self): 12 | assert list(build_expr([1, 2], "Id", 1000)) == ["expr=OR(Id=1,Id=2)"] 13 | assert list(build_expr(["cat", "dog"], "Ti", 1000)) == [ 14 | "expr=OR(Ti='cat',Ti='dog')" 15 | ] 16 | 17 | def test_build_expr_respects_query_limit_and_returns_remainder(self): 18 | assert list(build_expr([1, 2, 3], "Id", 21)) == [ 19 | "expr=OR(Id=1,Id=2)", 20 | "expr=OR(Id=3)", 21 | ] 22 | 23 | 24 | @mock.patch("orion.packages.mag.query_mag_api.requests.post", autospec=True) 25 | def test_query_mag_api_sends_correct_request(mocked_requests): 26 | sub_key = 123 27 | fields = ["Id", "Ti"] 28 | expr = "expr=OR(Id=1,Id=2)" 29 | query_mag_api(expr, fields, sub_key, query_count=10, offset=0) 30 | expected_call_args = mock.call( 31 | "https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate", 32 | data=b"expr=OR(Id=1,Id=2)&count=10&offset=0&attributes=Id,Ti", 33 | headers={ 34 | "Ocp-Apim-Subscription-Key": 123, 35 | "Content-Type": "application/x-www-form-urlencoded", 36 | }, 37 | ) 38 | assert mocked_requests.call_args == expected_call_args 39 | 40 | 41 | def test_dedupe_entities_picks_highest_for_each_title(): 42 | entities = [ 43 | {"Id": 1, "Ti": "test title", "logprob": 44}, 44 | {"Id": 2, "Ti": "test title", "logprob": 10}, 45 | {"Id": 3, "Ti": "another title", "logprob": 5}, 46 | {"Id": 4, "Ti": "another title", "logprob": 10}, 47 | ] 48 | 49 | assert dedupe_entities(entities) == {1, 4} 50 | 51 | 52 | def test_build_composite_queries_correctly(): 53 | assert ( 54 | build_composite_expr(["bar", "foo"], "F.FN", ("2019-01-01", "2019-02-22")) 55 | == "expr=OR(And(Composite(F.FN='bar'), D=['2019-01-01', '2019-02-22']), And(Composite(F.FN='foo'), D=['2019-01-01', '2019-02-22']))" 56 | ) 57 | -------------------------------------------------------------------------------- /orion/packages/README.md: -------------------------------------------------------------------------------- 1 | # Packages # 2 | A collection of modules used in 
Orion. 3 | 4 | ## Microsoft Academic Knowledge API ## 5 | 6 | ### Getting an API key ### 7 | * Sign up for an API Management account with [Microsoft Research](https://msr-apis.portal.azure-api.net/signup). 8 | * To activate your account, log into the email you used during the registration, open the _Please confirm your new Microsoft Research APIs account_ email and click on the activation link. 9 | * Click on the **Subscribe** button and choose **Project Academic Knowledge**. 10 | * Click again on the **Subscribe** button and then **Confirm** your choice. 11 | * You can now use the **Primary key** to query the API. 12 | 13 | ### Using your API key in this project ### 14 | The Microsoft Academic API key is stored in the `orion/orion/core/config/orion_config.config` file with the following format: 15 | 16 | ``` 17 | [mag] 18 | MAG_API_KEY=MY_MAG_API_KEY 19 | ``` 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | To learn how to use the API, check the [official documentation](https://docs.microsoft.com/en-us/azure/cognitive-services/academic-knowledge/home). 28 | 29 | ## Google Places API ## 30 | 31 | ### Getting an API key ### 32 | * Sign in with your Google account to [Google Cloud Platform (GCP)](https://console.cloud.google.com/). 33 | * Set up a project and enable billing. 34 | * Find the **Places API** in the **Marketplace** and enable it. 35 | * Click on the **CREDENTIALS** tab and generate an API key. 36 | 37 | ### Using your API key in this project ### 38 | The Google API key is stored in the `orion/orion/core/config/orion_config.config` file with the following format: 39 | 40 | ``` 41 | [google] 42 | GOOGLE_KEY=MY_GOOGLE_API_KEY 43 | ``` 44 | 45 | 46 | 47 | 48 | 49 | 50 | To learn how to use the API, check the [official documentation](https://developers.google.com/places/web-service/details). 51 | 52 | ## Gender API ## 53 | 54 | ### Getting an API key ### 55 | Sign up on [Gender API](https://gender-api.com/en/) in order to get an API key. 56 | 57 | ### Using your API key in this project ### 58 | The Gender API key is stored in the `orion/orion/core/config/orion_config.config` file with the following format: 59 | 60 | ``` 61 | [genderapi] 62 | GENDER_API_KEY=MY_GENDER_API_KEY 63 | ``` 64 | 65 | **Note**: All of the above are paid services. 66 | -------------------------------------------------------------------------------- /orion/packages/nlp/text2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import DistilBertModel, DistilBertTokenizer 3 | import numpy as np 4 | import sentencepiece as spm 5 | 6 | np.random.seed(42) 7 | 8 | 9 | class Text2Vector: 10 | """Transform text to vector using DistilBert from HuggingFace.""" 11 | 12 | def __init__( 13 | self, 14 | tokenizer_model=DistilBertTokenizer, 15 | transformer=DistilBertModel, 16 | pretrained_weights="distilbert-base-uncased", 17 | ): 18 | self.tokenizer_model = tokenizer_model 19 | self.transformer = transformer 20 | self.pretrained_weights = pretrained_weights 21 | self.tokenizer = self.tokenizer_model.from_pretrained(self.pretrained_weights) 22 | self.model = self.transformer.from_pretrained(self.pretrained_weights) 23 | 24 | def encode_text(self, text): 25 | """Tokenizes text and finds its indices. 26 | 27 | Args: 28 | text (str): Text input. 29 | tokenizer_model (transformers.tokenization_distilbert.DistilBertTokenizer): Pretrained tokenizer. 30 | The tokenizer must match the transformer architecture that will be used. 
31 | pretrained_weights (str): Pretrained weights shortcut. 32 | 33 | Returns: 34 | (torch.Tensor) Indices of input sequence tokens in the vocabulary of the transformer. 35 | 36 | """ 37 | 38 | # max_length is equal to 512 because that's the longest input sequence the model takes. 39 | return torch.tensor( 40 | [self.tokenizer.encode(text, add_special_tokens=True, max_length=512)] 41 | ) 42 | 43 | def feature_extraction(self, input_ids): 44 | """Extracts word embeddings. 45 | 46 | Args: 47 | input_ids (torch.Tensor) Indices of input sequence tokens in the vocabulary of the transformer. 48 | model_class (transformers.modeling_distilbert.DistilBertModel): Pretrained transformer. 49 | pretrained_weights (str): Pretrained weights shortcut. 50 | 51 | Returns: 52 | (torch.Tensor) Tensor of shape (batch_size, sequence_length, hidden_size). 53 | 54 | """ 55 | with torch.no_grad(): 56 | # Keep only the sequence of hidden-states at the output of the last layer of the model. 57 | last_hidden_states = self.model(input_ids)[0] 58 | 59 | return last_hidden_states 60 | 61 | def average_vectors(self, vectors): 62 | """Averages a Tensor with hidden states. 63 | 64 | Args: 65 | vectors (torch.Tensor) Tensor of shape (batch_size, sequence_length, hidden_size). 66 | 67 | Returns: 68 | (numpy.ndarray) Average of the vectors of the shape (hidden_size,). 69 | 70 | """ 71 | return np.mean([l for l in vectors.numpy()[0]], axis=0) 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # exclude data/{raw,external,interim} from source control by default 132 | # data/aux can be tracked by github as this will contain hand-written content 133 | # data/processed is not ignored as outputs should be tracked by DVC 134 | /data/raw/ 135 | /data/external/ 136 | /data/interim/ 137 | /data 138 | 139 | *~ 140 | #* 141 | *# 142 | *bin 143 | *npy 144 | *pickle 145 | 146 | # OSX 147 | .DS_Store 148 | 149 | # config 150 | /orion/core/config/ 151 | orion/core/airflow-scheduler.pid 152 | orion/core/airflow.cfg 153 | orion/core/logs/ 154 | orion/core/unittests.cfg 155 | orion/core/airflow-webserver.pid 156 | boto.cfg 157 | -------------------------------------------------------------------------------- /orion/core/operators/collect_wb_indicators_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collects indicators from the World Bank. Currently, we collect indicators from the following URLs: 3 | - http://datatopics.worldbank.org/world-development-indicators/themes/economy.html#featured-indicators_1 4 | - http://datatopics.worldbank.org/world-development-indicators/themes/states-and-markets.html#featured-indicators_1 5 | - http://datatopics.worldbank.org/world-development-indicators/themes/global-links.html#featured-indicators_1 6 | - http://datatopics.worldbank.org/world-development-indicators/themes/people.html#featured-indicators_1 7 | 8 | We use pandas-datareader, a Python package that provides access to economic databases, 9 | because it makes it straightforward to collect indicators by querying their unique codes. 10 | 11 | Orion currently collects the following country-level indicators: 12 | * GDP (current US$) 13 | * Research and development expenditure (% of GDP) 14 | * Government expenditure on education, total (% of GDP) 15 | * Ratio of female to male labour force participation rate (%) (modelled ILO estimate) 16 | 17 | Users can filter indicators by start and end year as well as country.
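For orientation, fetching a single indicator outside Airflow boils down to one pandas-datareader call. A minimal sketch (the indicator code, countries and year range here are illustrative, not Orion's configured values):

    from pandas_datareader import wb

    # "NY.GDP.MKTP.CD" is GDP in current US$; the result is a DataFrame indexed by
    # (country, year), which is exactly how the operator below iterates over it.
    gdp = wb.download(indicator="NY.GDP.MKTP.CD", country=["GB", "US"], start=2015, end=2019)
    print(gdp.head())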
18 | 19 | """ 20 | from pandas_datareader import wb 21 | from sqlalchemy import create_engine 22 | from sqlalchemy.orm import sessionmaker 23 | from airflow.models import BaseOperator 24 | from airflow.utils.decorators import apply_defaults 25 | from orion.core.orms.mag_orm import ( 26 | WorldBankFemaleLaborForce, 27 | WorldBankGovEducation, 28 | WorldBankResearchDevelopment, 29 | WorldBankGDP, 30 | ) 31 | 32 | 33 | class WBIndicatorOperator(BaseOperator): 34 | """Fetches indicators from the World Bank.""" 35 | 36 | @apply_defaults 37 | def __init__( 38 | self, 39 | db_config, 40 | table_name, 41 | indicator, 42 | start_year, 43 | end_year, 44 | country, 45 | *args, 46 | **kwargs 47 | ): 48 | super().__init__(**kwargs) 49 | self.db_config = db_config 50 | self.indicator = indicator 51 | self.start_year = start_year 52 | self.end_year = end_year 53 | self.country = country 54 | self.table_name = table_name 55 | self.tables = { 56 | "wb_gdp": WorldBankGDP, 57 | "wb_edu_expenditure": WorldBankGovEducation, 58 | "wb_rnd_expenditure": WorldBankResearchDevelopment, 59 | "wb_female_workforce": WorldBankFemaleLaborForce, 60 | } 61 | 62 | def execute(self, context): 63 | # Connect to postgresql db 64 | engine = create_engine(self.db_config) 65 | Session = sessionmaker(engine) 66 | s = Session() 67 | 68 | # Fetch WB Indicator 69 | ind = wb.download( 70 | indicator=self.indicator, 71 | country=self.country, 72 | start=self.start_year, 73 | end=self.end_year, 74 | ) 75 | 76 | # Store in DB 77 | for (area, year), row in ind.iterrows(): 78 | s.add( 79 | self.tables[self.table_name]( 80 | country=area, year=year, indicator=row[self.indicator] 81 | ) 82 | ) 83 | s.commit() 84 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to Orion? # 2 | 3 | Everyone is welcome to contribute to Orion! You can support our work by contributing code, improving documentation and testing as well as sharing with us how you are using Orion. Moreover, spreading the word and connecting us with potential users would be immensly valuable. 4 | 5 | ## Ways to contribute ## 6 | We have listed four ways to contribute to Orion but get in touch if you have other ideas: 7 | - Work on outstanding [issues](https://github.com/orion-search/orion/issues). 8 | - Implement new components. For example, add a new metric or data source. 9 | - Submit issues related to bugs or desired new features. 10 | - Connect us with potential users. 11 | 12 | ### Working on an outstanding issue ### 13 | Would you like to work on an existing issue? That's great! Before writing any code, we strongly advise you to reply to the issue you plan to work on and provide the following details: 14 | - Describe your solution to the issue. 15 | - Add any references that will help us learn more about your proposed solution. 16 | 17 | We aim to reply as soon as possible and guide you as best as we can. After agreeing on the scope of your contribution, you can start working on it: 18 | - Fork the [repository](https://github.com/orion-search/orion) and clone it on your local disk. 19 | - Create a new branch for your changes. **Do not** work directly on the `master` or `dev` branches. 20 | - Develop and test your solution. Make sure it also passes existing tests. Document your contribution and use [`black`](https://github.com/psf/black) for code formatting. 
21 | - When you are happy with your changes, push your code and open a Pull Request (PR) describing your solution. If it is still work in progress, create a draft PR and add **[WIP]** in front of its title.
22 | - Add me (@kstathou) as a reviewer. If I suggest some changes, work on your local branch and push them to your fork. Let me know when I should review the new additions, and after my approval you can merge your contribution into Orion's `dev` branch!
23 | 
24 | ### Suggesting new features ###
25 | We are actively developing Orion and your ideas can help us make it better! If you want to suggest a new feature, open a new issue and provide the following details:
26 | - Motivation: Why do you think this feature is important and should be developed in Orion?
27 | - Is it related to a problem (not a bug) with Orion? For example, is there a better way to do X, or a data source with better coverage than the one we use?
28 | - Is it a feature you saw somewhere else that would be a valuable addition to Orion? Let us know where you found it!
29 | - Describe the new feature.
30 | - Add any references that might help us learn more about it.
31 | 
32 | ### Reporting bugs ###
33 | Orion is in beta, so we expect users to come across some bugs. If you find one, let us know by submitting a new issue with the title `Bug: descriptive_title_for_the_bug` and providing the following details:
34 | - Your Python version.
35 | - The task that failed or didn't work as expected. Giving us a data sample to rerun the task would be very helpful.
36 | - The full error message.
37 | 
38 | ### Connecting us with users ###
39 | We think Orion can help policymakers and researchers parse academic knowledge. Do you know an individual or organisation who might benefit from using Orion? Send me an email at kostas@mozillafoundation.org!
40 | 
41 | This guide was inspired by the [transformers](https://github.com/huggingface/transformers) contributing guidelines.
42 | -------------------------------------------------------------------------------- /orion/core/operators/topic_filtering_task.py: --------------------------------------------------------------------------------
1 | """
2 | Filter topics so that they can be used in downstream tasks.
3 | 
4 | FilterTopicsByDistributionOperator: Filter topics by level and frequency.
5 | 
6 | FilteredTopicsMetadataOperator: Creates a table with the filtered Fields of Study, their children,
7 | annual citation sum and paper count.
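The level/percentile filtering in FilterTopicsByDistributionOperator reduces to a NumPy percentile cut per level. A toy illustration with made-up frequencies:

    import numpy as np
    import pandas as pd

    frame = pd.DataFrame({"id": [1, 2, 3, 4], "level": 2, "frequency": [5, 40, 80, 200]})
    cutoff = int(np.percentile(frame.frequency, 50))  # 50th percentile of the frequencies -> 60
    kept = list(frame[frame.frequency > cutoff]["id"].values)  # fields of study kept: [3, 4]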
8 | 9 | """ 10 | import logging 11 | import pandas as pd 12 | from sqlalchemy import create_engine 13 | from sqlalchemy.orm import sessionmaker 14 | from airflow.models import BaseOperator 15 | from airflow.utils.decorators import apply_defaults 16 | import numpy as np 17 | from orion.core.orms.mag_orm import ( 18 | FosMetadata, 19 | FilteredFos, 20 | Paper, 21 | PaperFieldsOfStudy, 22 | FosHierarchy, 23 | ) 24 | from orion.packages.utils.s3_utils import store_on_s3, load_from_s3 25 | from orion.packages.utils.utils import flatten_lists, get_all_children 26 | 27 | 28 | class FilterTopicsByDistributionOperator(BaseOperator): 29 | """Filter topics by level and frequency.""" 30 | 31 | @apply_defaults 32 | def __init__( 33 | self, db_config, s3_bucket, prefix, levels, percentiles, *args, **kwargs 34 | ): 35 | super().__init__(**kwargs) 36 | self.db_config = db_config 37 | self.s3_bucket = s3_bucket 38 | self.prefix = prefix 39 | self.levels = levels 40 | self.percentiles = percentiles 41 | 42 | def execute(self, context): 43 | # Connect to postgresql db 44 | engine = create_engine(self.db_config) 45 | Session = sessionmaker(engine) 46 | s = Session() 47 | 48 | # Fetch tables 49 | metadata = pd.read_sql(s.query(FosMetadata).statement, s.bind) 50 | logging.info(f"Number of FoS: {metadata.id.shape[0]}") 51 | 52 | d = {} 53 | for lvl, perc in zip(self.levels, self.percentiles): 54 | # Filter by level 55 | frame = metadata[metadata.level == lvl] 56 | # Find the percentile 57 | num = int(np.percentile(frame.frequency, perc)) 58 | d[lvl] = list(frame[frame.frequency > num]["id"].values) 59 | 60 | # Store pickle on s3 61 | store_on_s3(d, self.s3_bucket, self.prefix) 62 | logging.info("Done :)") 63 | 64 | 65 | class FilteredTopicsMetadataOperator(BaseOperator): 66 | """Creates a table with the filtered Fields of Study, their children, 67 | annual citation sum and paper count.""" 68 | 69 | @apply_defaults 70 | def __init__(self, db_config, s3_bucket, prefix, *args, **kwargs): 71 | super().__init__(**kwargs) 72 | self.db_config = db_config 73 | self.s3_bucket = s3_bucket 74 | self.prefix = prefix 75 | 76 | def execute(self, context): 77 | # Load topics 78 | topics = flatten_lists(list(load_from_s3(self.s3_bucket, self.prefix).values())) 79 | logging.info(f"Number of topics: {len(topics)}") 80 | 81 | # Connect to postgresql db 82 | engine = create_engine(self.db_config) 83 | FilteredFos.__table__.drop(engine, checkfirst=True) 84 | FilteredFos.__table__.create(engine, checkfirst=True) 85 | Session = sessionmaker(engine) 86 | s = Session() 87 | 88 | # Load all the tables needed for the metrics 89 | papers = pd.read_sql(s.query(Paper).statement, s.bind) 90 | paper_fos = pd.read_sql(s.query(PaperFieldsOfStudy).statement, s.bind) 91 | hierarchy = pd.read_sql(s.query(FosHierarchy).statement, s.bind) 92 | 93 | # Merge papers with fields of study, citations and publication year. 
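# Note: the merged frame has one row per (paper, field of study) pair, so the
# drop_duplicates("paper_id") call further down prevents papers tagged with several
# children of the same topic from being counted more than once.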
94 | papers = ( 95 | papers[["id", "citations", "year"]] 96 | .merge(paper_fos, left_on="id", right_on="paper_id") 97 | .drop("id", axis=1) 98 | ) 99 | logging.info("Merged tables.") 100 | 101 | # Traverse the FoS hierarchy tree and get all children 102 | d = {topic: get_all_children(hierarchy, topic) for topic in topics} 103 | logging.info(f"Got children of {len(d)} topics.") 104 | 105 | for fos_ids in d.values(): 106 | logging.info(f"fos id: {fos_ids[0]}") 107 | g = ( 108 | papers[papers.field_of_study_id.isin(fos_ids)] 109 | .drop_duplicates("paper_id") 110 | .groupby("year") 111 | ) 112 | for year, paper_count, total_citations in zip( 113 | g.groups.keys(), g["paper_id"].count(), g["citations"].sum() 114 | ): 115 | s.add( 116 | FilteredFos( 117 | field_of_study_id=int(fos_ids[0]), 118 | all_children=[int(f) for f in fos_ids], 119 | year=year, 120 | paper_count=int(paper_count), 121 | total_citations=int(total_citations), 122 | ) 123 | ) 124 | s.commit() 125 | -------------------------------------------------------------------------------- /orion/packages/geo/geocode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import requests 3 | import numpy as np 4 | 5 | GEOCODE = "https://maps.googleapis.com/maps/api/geocode/json?" 6 | FIND_PLACE = "https://maps.googleapis.com/maps/api/place/findplacefromtext/json?" 7 | PLACE_DETAILS = "https://maps.googleapis.com/maps/api/place/details/json?" 8 | 9 | 10 | def geocoding(address, key, GEOCODE=GEOCODE): 11 | """Geocodes a human-readable address. 12 | 13 | Args: 14 | address (str): Address of a location. For example, it can be country, city, 15 | postcode or street name. 16 | key (str): Key for the Google API. 17 | GEOCODE (str): Endpoint for the Google Geocode API. Note that the service 18 | must be enabled in order to use it. 19 | 20 | Returns: 21 | r (:obj:`dict` of :obj:`list`): API response with geocoded information. The returned 22 | results are ordered by relevance. 23 | 24 | """ 25 | params = {"address": "{}".format(address), "key": key} 26 | 27 | r = requests.get(GEOCODE, params=params) 28 | r.raise_for_status() 29 | 30 | return r.json() 31 | 32 | 33 | def reverse_geocoding(lat, lng, key, GEOCODE=GEOCODE): 34 | """Reverse geocodes geographic coordinates into a human-readable address. 35 | 36 | Args: 37 | lat (float): Latitude. 38 | lng (float): Longitude. 39 | key (str): Key for the Google API. 40 | GEOCODE (str): Endpoint for the Google Geocode API. Note that the service 41 | must be enabled in order to use it. 42 | 43 | Returns: 44 | (:obj:`dict` of :obj:`list`): API response with the most relevant results 45 | to the given longitude and latitude. 46 | 47 | """ 48 | params = {"latlng": "{},{}".format(lat, lng), "key": key} 49 | 50 | r = requests.get(GEOCODE, params=params) 51 | r.raise_for_status() 52 | return r.json() 53 | 54 | 55 | def place_by_name(place, key, FIND_PLACE=FIND_PLACE): 56 | """Finds a Google Place ID by searching with its name. 57 | 58 | Args: 59 | place (str): Name of the place. It can be a restaurant, bar, monument, 60 | whatever you would normally search in Google Maps. 61 | key (str): Key for the Google API. 62 | FIND_PLACE (str): Endpoint for the Google Places API. Note that the 63 | service must be enabled in order to use it. 64 | 65 | Returns: 66 | (str) Place ID. 
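A typical lookup chains this with place_by_id and parse_response, defined further down in this module (illustrative only: "Mozilla London" and MY_KEY are placeholders, and the Places API must be enabled for the key):

        place_id = place_by_name("Mozilla London", key="MY_KEY")
        details = place_by_id(place_id, key="MY_KEY")
        geocoded = parse_response(details)  # dict with lat, lng, address, country, ...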
67 | 68 | """ 69 | params = { 70 | "input": "{}".format(place), 71 | "fields": "place_id", 72 | "inputtype": "textquery", 73 | "key": key, 74 | } 75 | 76 | r = requests.get(FIND_PLACE, params=params) 77 | r.raise_for_status() 78 | 79 | try: 80 | return r.json()["candidates"][0]["place_id"] 81 | except IndexError as e: 82 | logging.info(f"Failed to find a match for {place}") 83 | return None 84 | 85 | 86 | def place_by_id(id, key, PLACE_DETAILS=PLACE_DETAILS): 87 | """Finds details about a place given its Google Place ID. 88 | 89 | Args: 90 | id (str): Place ID. 91 | key (str): Key for the Google API. 92 | FIND_PLACE_DETAILS (str): Endpoint for the Google Places API. Note that the 93 | service must be enabled in order to use it. 94 | 95 | Returns: 96 | (dict): Details of a place. See the `fields` parameters to find what's 97 | being returned in the response. 98 | 99 | """ 100 | params = { 101 | "place_id": id, 102 | "key": key, 103 | "fields": "address_components,formatted_address,geometry,name,place_id,type,website", 104 | } 105 | 106 | r = requests.get(PLACE_DETAILS, params=params) 107 | r.raise_for_status() 108 | 109 | return r.json() 110 | 111 | 112 | def parse_response(response): 113 | """Parses details from a Google Place Details API endpoint response. 114 | 115 | Args: 116 | response (dict): Response of a request. 117 | 118 | Returns: 119 | d (dict): Geocoded information for a given Place ID. 120 | 121 | """ 122 | result = response["result"] 123 | 124 | # Store attributes 125 | d = dict() 126 | d["lat"] = result["geometry"]["location"]["lat"] 127 | d["lng"] = result["geometry"]["location"]["lng"] 128 | d["address"] = result["formatted_address"] 129 | d["name"] = result["name"] 130 | d["id"] = result["place_id"] 131 | d["types"] = result["types"] 132 | try: 133 | d["website"] = result["website"] 134 | except KeyError as e: 135 | logging.info(f"{d['name']}: {e}") 136 | d["website"] = np.nan 137 | 138 | for r in result["address_components"]: 139 | if "postal_town" in r["types"]: 140 | d["postal_town"] = r["long_name"] 141 | elif "administrative_area_level_2" in r["types"]: 142 | d["administrative_area_level_2"] = r["long_name"] 143 | elif "administrative_area_level_1" in r["types"]: 144 | d["administrative_area_level_1"] = r["long_name"] 145 | elif "country" in r["types"]: 146 | d["country"] = r["long_name"] 147 | else: 148 | continue 149 | 150 | return d 151 | -------------------------------------------------------------------------------- /orion/packages/mag/tests/test_parsing_mag_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from orion.packages.mag.parsing_mag_data import parse_papers 4 | from orion.packages.mag.parsing_mag_data import parse_affiliations 5 | from orion.packages.mag.parsing_mag_data import parse_authors 6 | from orion.packages.mag.parsing_mag_data import parse_fos 7 | from orion.packages.mag.parsing_mag_data import parse_journal 8 | from orion.packages.mag.parsing_mag_data import parse_conference 9 | 10 | test_example = { 11 | "logprob": -17.825, 12 | "prob": 1.81426557, 13 | "Id": 2592122940, 14 | "Ti": "dna fountain enables a robust and efficient storage architecture", 15 | "Pt": "1", 16 | "DN": "this is an original title", 17 | "Y": 2017, 18 | "D": "2017-03-03", 19 | "CC": 109, 20 | "RId": [2293000460, 2296125569], 21 | "DOI": "10.1126/science.aaj2038", 22 | "PB": "American Association for the Advancement of Science", 23 | "BT": "a", 24 | "AA": [ 25 | { 26 | "DAuN": "Foo", 27 | "AuId": 
2780121452, 28 | "AfN": "columbia university", 29 | "AfId": 78577930, 30 | "S": 1, 31 | }, 32 | {"DAuN": "Bar", "AuId": 2159352281, "AfId": None, "S": 2}, 33 | ], 34 | "F": [ 35 | {"DFN": "Petabyte", "FId": 13600138}, 36 | {"DFN": "Oligonucleotide", "FId": 129312508}, 37 | ], 38 | "J": {"JN": "science", "JId": 3880285}, 39 | "C": {"CN": "foo bar", "CId": 3880285}, 40 | "IA": { 41 | "IndexLength": 111, 42 | "InvertedIndex": { 43 | "In": [0], 44 | "comparative": [1], 45 | "high-throughput": [2], 46 | "sequencing": [3], 47 | "assays,": [4], 48 | "a": [5], 49 | }, 50 | }, 51 | "S": [ 52 | {"U": "https://www.biorxiv.org/content/early/2018/04/02/292706.full.pdf"}, 53 | { 54 | "Ty": 3, 55 | "U": "https://www.biorxiv.org/content/biorxiv/early/2018/05/08/292706.full.pdf", 56 | }, 57 | {"Ty": 1, "U": "https://www.biorxiv.org/content/10.1101/292706v4"}, 58 | {"U": "https://www.biorxiv.org/content/early/2018/05/08/292706"}, 59 | ], 60 | } 61 | 62 | 63 | def test_parse_papers(): 64 | expected_result = { 65 | "id": 2592122940, 66 | "title": "dna fountain enables a robust and efficient storage architecture", 67 | "doi": "10.1126/science.aaj2038", 68 | "prob": 1.81426557, 69 | "publication_type": "1", 70 | "year": 2017, 71 | "date": "2017-03-03", 72 | "original_title": "this is an original title", 73 | "citations": 109, 74 | "bibtex_doc_type": "a", 75 | "references": "[2293000460, 2296125569]", 76 | "publisher": "American Association for the Advancement of Science", 77 | "abstract": "In comparative high-throughput sequencing assays, a", 78 | "source": "https://www.biorxiv.org/content/early/2018/04/02/292706.full.pdf", 79 | } 80 | result = parse_papers(test_example) 81 | 82 | assert result == expected_result 83 | 84 | 85 | def test_parse_journal(): 86 | expected_result = {"id": 3880285, "journal_name": "science", "paper_id": 2592122940} 87 | result = parse_journal(test_example, 2592122940) 88 | 89 | assert result == expected_result 90 | 91 | 92 | def test_parse_conference(): 93 | expected_result = { 94 | "id": 3880285, 95 | "conference_name": "foo bar", 96 | "paper_id": 2592122940, 97 | } 98 | result = parse_conference(test_example, 2592122940) 99 | 100 | assert result == expected_result 101 | 102 | 103 | def test_parse_authors(): 104 | expected_result_authors = [ 105 | {"id": 2780121452, "name": "Foo"}, 106 | {"id": 2159352281, "name": "Bar"}, 107 | ] 108 | expected_result_paper_with_authors = [ 109 | {"paper_id": 2592122940, "author_id": 2780121452, "order": 1}, 110 | {"paper_id": 2592122940, "author_id": 2159352281, "order": 2}, 111 | ] 112 | result_authors, result_paper_with_authors = parse_authors(test_example, 2592122940) 113 | 114 | assert result_authors == expected_result_authors 115 | assert result_paper_with_authors == expected_result_paper_with_authors 116 | 117 | 118 | def test_parse_fields_of_study(): 119 | expected_result_paper_with_fos = [ 120 | {"field_of_study_id": 13600138, "paper_id": 2592122940}, 121 | {"field_of_study_id": 129312508, "paper_id": 2592122940}, 122 | ] 123 | expected_result_fields_of_study = [ 124 | {"id": 13600138, "name": "Petabyte"}, 125 | {"id": 129312508, "name": "Oligonucleotide"}, 126 | ] 127 | result_paper_with_fos, result_fields_of_study = parse_fos(test_example, 2592122940) 128 | 129 | assert expected_result_paper_with_fos == result_paper_with_fos 130 | assert expected_result_fields_of_study == result_fields_of_study 131 | 132 | 133 | def test_parse_affiliations(): 134 | expected_result_affiliations = [ 135 | {"id": 78577930, "affiliation": "columbia 
university"} 136 | ] 137 | expected_result_author_with_aff = [ 138 | {"affiliation_id": 78577930, "author_id": 2780121452, "paper_id": 2592122940}, 139 | {"affiliation_id": None, "author_id": 2159352281, "paper_id": 2592122940}, 140 | ] 141 | 142 | affiliations, paper_author_aff = parse_affiliations(test_example, 2592122940) 143 | 144 | assert affiliations == expected_result_affiliations 145 | assert paper_author_aff == expected_result_author_with_aff 146 | -------------------------------------------------------------------------------- /orion/core/operators/infer_gender_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | NamesBatchesOperator: Fetches full names from PostgreSQL, removes those with just an initial 3 | and stores them as batches on S3. 4 | 5 | GenderInferenceOperator: Infers the gender of a person by using their name. I am using GenderAPI for this, 6 | which is supposed to be one of the most reliable name to gender inference services. 7 | 8 | """ 9 | import logging 10 | import pandas as pd 11 | from sqlalchemy.sql import exists 12 | from sqlalchemy import create_engine 13 | from sqlalchemy.orm import sessionmaker 14 | from airflow.models import BaseOperator 15 | from airflow.utils.decorators import apply_defaults 16 | from orion.packages.utils.batches import split_batches, put_s3_batch 17 | from orion.packages.utils.s3_utils import load_from_s3 18 | from orion.packages.gender.query_gender_api import query_gender_api, parse_response 19 | from orion.core.orms.mag_orm import Author, AuthorGender 20 | from orion.packages.utils.nlp_utils import clean_name 21 | from botocore.exceptions import ClientError 22 | import toolz 23 | 24 | 25 | class NamesBatchesOperator(BaseOperator): 26 | """Creates batches for parallel processing.""" 27 | 28 | @apply_defaults 29 | def __init__(self, db_config, s3_bucket, prefix, batch_size, *args, **kwargs): 30 | super().__init__(**kwargs) 31 | self.db_config = db_config 32 | self.s3_bucket = s3_bucket 33 | self.prefix = prefix 34 | self.batch_size = batch_size 35 | 36 | def execute(self, context): 37 | # Connect to PostgreSQL DB 38 | engine = create_engine(self.db_config) 39 | Session = sessionmaker(bind=engine) 40 | s = Session() 41 | 42 | # Get author names if they haven't had their name inferred yet. 
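# The ~exists() filter below is an anti-join: it keeps only authors whose id has no
# matching row in AuthorGender, so DAG reruns skip names whose gender was already inferred.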
43 | authors = pd.read_sql( 44 | s.query(Author.id, Author.name) 45 | .filter(~exists().where(Author.id == AuthorGender.id)) 46 | .statement, 47 | s.bind, 48 | ) 49 | 50 | # Process their name and drop missing values 51 | authors["proc_name"] = authors.name.apply(clean_name) 52 | authors = authors.dropna() 53 | 54 | # Group author IDs by full name 55 | grouped_ids = authors.groupby("proc_name")["id"].apply(list) 56 | logging.info(f"Authors passed to GenderAPI: {grouped_ids.shape[0]}") 57 | 58 | # Store (full names, IDs[]) batches on S3 59 | for i, batch in enumerate( 60 | split_batches(grouped_ids.to_dict().items(), self.batch_size) 61 | ): 62 | put_s3_batch(batch, self.s3_bucket, "_".join([self.prefix, str(i)])) 63 | 64 | 65 | class GenderInferenceOperator(BaseOperator): 66 | """Infers gender by name using GenderAPI.""" 67 | 68 | @apply_defaults 69 | def __init__(self, db_config, s3_bucket, prefix, auth_token, *args, **kwargs): 70 | super().__init__(**kwargs) 71 | self.db_config = db_config 72 | self.s3_bucket = s3_bucket 73 | self.prefix = prefix 74 | self.auth_token = auth_token 75 | 76 | def execute(self, context): 77 | # Connect to PostgreSQL DB 78 | engine = create_engine(self.db_config) 79 | Session = sessionmaker(bind=engine) 80 | s = Session() 81 | 82 | # # Fetch all collected author names 83 | collected_full_names = set( 84 | [full_name[0] for full_name in s.query(AuthorGender.full_name)] 85 | ) 86 | try: 87 | # Load queries from S3 88 | queries = load_from_s3(self.s3_bucket, self.prefix) 89 | 90 | # Convert queries to dict 91 | queries = {tup[0]: tup[1] for tup in queries} 92 | 93 | # Filter authors that already exist in the DB 94 | # This is mainly to catch existing keys after task failures or re-runs 95 | queries = { 96 | k: v for k, v in queries.items() if k not in collected_full_names 97 | } 98 | logging.info(f"Total number of queries: {len(queries.keys())}") 99 | 100 | # i = 1 101 | # Bulk query GenderAPI 102 | for i, chunk in enumerate( 103 | toolz.partition_all(100, queries.keys()), start=1 104 | ): 105 | logging.info(f"Chunk: {i}, count: {len(chunk)}") 106 | results = query_gender_api(chunk, self.auth_token) 107 | 108 | if results: 109 | # Parse response 110 | parsed_responses = [] 111 | for result in [result for result in results if result]: 112 | if result["result_found"]: 113 | parsed_response = parse_response(result) 114 | for id_ in queries[parsed_response["full_name"]]: 115 | # Add author id in the response object 116 | parsed_response_copy = parsed_response.copy() 117 | parsed_response_copy.update({"id": id_}) 118 | parsed_responses.append(parsed_response_copy) 119 | 120 | # Insert bulk 121 | s.bulk_insert_mappings(AuthorGender, parsed_responses) 122 | s.commit() 123 | logging.info(f"Committed {len(parsed_responses)} results") 124 | # i += 1 125 | else: 126 | logging.info(f"Chunk {i} failed.") 127 | # i += 1 128 | continue 129 | 130 | logging.info("Done! 
:)") 131 | except ClientError as err: 132 | logging.info(err) 133 | pass 134 | -------------------------------------------------------------------------------- /orion/core/dags/tutorial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime, timedelta 3 | from airflow import DAG 4 | from airflow.operators.dummy_operator import DummyOperator 5 | from airflow.operators.python_operator import PythonOperator 6 | from orion.packages.utils.s3_utils import create_s3_bucket 7 | from orion.core.operators.mag_parse_task import MagParserOperator, FosFrequencyOperator 8 | from orion.core.operators.mag_collect_task import ( 9 | MagCollectionOperator, 10 | MagFosCollectionOperator, 11 | ) 12 | import orion 13 | from orion.core.operators.topic_filtering_task import ( 14 | FilterTopicsByDistributionOperator, 15 | FilteredTopicsMetadataOperator, 16 | ) 17 | from orion.core.operators.affiliation_type_task import AffiliationTypeOperator 18 | from orion.core.operators.collect_wb_indicators_task import WBIndicatorOperator 19 | from orion.packages.mag.create_tables import create_db_and_tables 20 | from orion.core.operators.open_access_journals_task import OpenAccessJournalOperator 21 | from dotenv import load_dotenv, find_dotenv 22 | import os 23 | 24 | load_dotenv(find_dotenv()) 25 | 26 | default_args = { 27 | "owner": "Orion", 28 | "start_date": datetime(2020, 2, 2), 29 | "depends_on_past": False, 30 | "retries": 0, 31 | } 32 | 33 | DAG_ID = "tutorial" 34 | db_name = orion.config["data"]["db_name"] 35 | DB_CONFIG = os.getenv(db_name) 36 | MAG_API_KEY = os.getenv("mag_api_key") 37 | MAG_OUTPUT_BUCKET = orion.config["s3_buckets"]["mag"] 38 | mag_config = orion.config["data"]["mag"] 39 | query_values = mag_config["query_values"] 40 | entity_name = mag_config["entity_name"] 41 | metadata = mag_config["metadata"] 42 | with_doi = mag_config["with_doi"] 43 | mag_start_date = mag_config["mag_start_date"] 44 | mag_end_date = mag_config["mag_end_date"] 45 | intervals_in_a_year = mag_config["intervals_in_a_year"] 46 | 47 | # topic_filtering 48 | topic_prefix = orion.config["prefix"]["topic"] 49 | topic_bucket = orion.config["s3_buckets"]["topic"] 50 | topic_config = orion.config["topic_filter"] 51 | levels = topic_config["levels"] 52 | percentiles = topic_config["percentiles"] 53 | 54 | # wb indicators 55 | wb_country = orion.config["data"]["wb"]["country"] 56 | wb_end_year = orion.config["data"]["wb"]["end_year"] 57 | wb_indicators = orion.config["data"]["wb"]["indicators"] 58 | wb_table_names = orion.config["data"]["wb"]["table_names"] 59 | year_thresh = orion.config["metrics"]["year"] 60 | 61 | with DAG( 62 | dag_id=DAG_ID, default_args=default_args, schedule_interval=timedelta(days=365) 63 | ) as dag: 64 | 65 | dummy_task = DummyOperator(task_id="start") 66 | 67 | dummy_task_3 = DummyOperator(task_id="world_bank_indicators") 68 | 69 | dummy_task_4 = DummyOperator(task_id="create_s3_buckets") 70 | 71 | dummy_task_5 = DummyOperator(task_id="s3_buckets") 72 | 73 | create_tables = PythonOperator( 74 | task_id="create_tables", 75 | python_callable=create_db_and_tables, 76 | op_kwargs={"db": db_name}, 77 | ) 78 | 79 | create_buckets = [ 80 | PythonOperator( 81 | task_id=bucket, 82 | python_callable=create_s3_bucket, 83 | op_kwargs={"bucket": bucket}, 84 | ) 85 | for bucket in [MAG_OUTPUT_BUCKET, topic_bucket] 86 | ] 87 | 88 | query_mag = MagCollectionOperator( 89 | task_id="query_mag", 90 | output_bucket=MAG_OUTPUT_BUCKET, 91 | 
subscription_key=MAG_API_KEY, 92 | query_values=query_values, 93 | entity_name=entity_name, 94 | metadata=metadata, 95 | with_doi=with_doi, 96 | mag_start_date=mag_start_date, 97 | mag_end_date=mag_end_date, 98 | intervals_in_a_year=intervals_in_a_year, 99 | ) 100 | 101 | parse_mag = MagParserOperator( 102 | task_id="parse_mag", s3_bucket=MAG_OUTPUT_BUCKET, db_config=DB_CONFIG 103 | ) 104 | 105 | collect_fos = MagFosCollectionOperator( 106 | task_id="collect_fos_metadata", 107 | db_config=DB_CONFIG, 108 | subscription_key=MAG_API_KEY, 109 | ) 110 | 111 | fos_frequency = FosFrequencyOperator(task_id="fos_frequency", db_config=DB_CONFIG) 112 | 113 | topic_filtering = FilterTopicsByDistributionOperator( 114 | task_id="filter_topics", 115 | db_config=DB_CONFIG, 116 | s3_bucket=topic_bucket, 117 | prefix=topic_prefix, 118 | levels=levels, 119 | percentiles=percentiles, 120 | ) 121 | 122 | filtered_topic_metadata = FilteredTopicsMetadataOperator( 123 | task_id="topic_metadata", 124 | db_config=DB_CONFIG, 125 | s3_bucket=topic_bucket, 126 | prefix=topic_prefix, 127 | ) 128 | 129 | aff_types = AffiliationTypeOperator(task_id="affiliation_type", db_config=DB_CONFIG) 130 | 131 | batch_task_wb = [] 132 | for wb_indicator, wb_table_name in zip(wb_indicators, wb_table_names): 133 | task_id = f"{wb_table_name}" 134 | batch_task_wb.append( 135 | WBIndicatorOperator( 136 | task_id=task_id, 137 | db_config=DB_CONFIG, 138 | indicator=wb_indicator, 139 | start_year=year_thresh, 140 | end_year=wb_end_year, 141 | country=wb_country, 142 | table_name=wb_table_name, 143 | ) 144 | ) 145 | 146 | open_access = OpenAccessJournalOperator(task_id="open_access", db_config=DB_CONFIG) 147 | 148 | dummy_task >> create_tables >> query_mag >> parse_mag 149 | dummy_task >> dummy_task_4 >> create_buckets >> dummy_task_5 >> query_mag 150 | parse_mag >> collect_fos >> fos_frequency >> topic_filtering >> filtered_topic_metadata 151 | parse_mag >> aff_types 152 | parse_mag >> open_access 153 | dummy_task >> create_tables >> dummy_task_3 >> batch_task_wb 154 | -------------------------------------------------------------------------------- /orion/packages/geo/tests/test_geocode.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | from orion.packages.geo.geocode import parse_response 5 | from orion.packages.geo.geocode import geocoding 6 | from orion.packages.geo.geocode import reverse_geocoding 7 | from orion.packages.geo.geocode import place_by_id 8 | from orion.packages.geo.geocode import place_by_name 9 | 10 | GEOCODE = "https://maps.googleapis.com/maps/api/geocode/json?" 11 | FIND_PLACE = "https://maps.googleapis.com/maps/api/place/findplacefromtext/json?" 12 | PLACE_DETAILS = "https://maps.googleapis.com/maps/api/place/details/json?" 
13 | 14 | 15 | @mock.patch("orion.packages.geo.geocode.requests.get", autospec=True) 16 | def test_google_geocoding_api_queries_correctly(mocked_requests): 17 | key = "123" 18 | address = "foo bar" 19 | geocoding(address, key, GEOCODE) 20 | 21 | expected_call_args = mock.call(GEOCODE, params={"address": "foo bar", "key": "123"}) 22 | assert mocked_requests.call_args == expected_call_args 23 | 24 | 25 | @mock.patch("orion.packages.geo.geocode.requests.get", autospec=True) 26 | def test_google_reverse_geocoding_api_queries_correctly(mocked_requests): 27 | key = "123" 28 | lat = 1.2 29 | lng = 2.3 30 | reverse_geocoding(lat, lng, key, GEOCODE) 31 | 32 | expected_call_args = mock.call(GEOCODE, params={"latlng": "1.2,2.3", "key": "123"}) 33 | assert mocked_requests.call_args == expected_call_args 34 | 35 | 36 | @mock.patch("orion.packages.geo.geocode.requests.get", autospec=True) 37 | def test_google_places_api_queries_correctly(mocked_requests): 38 | place = "foo bar" 39 | key = "123" 40 | place_by_name(place, key, FIND_PLACE) 41 | expected_call_args = mock.call( 42 | FIND_PLACE, 43 | params={ 44 | "input": "foo bar", 45 | "fields": "place_id", 46 | "inputtype": "textquery", 47 | "key": "123", 48 | }, 49 | ) 50 | assert mocked_requests.call_args == expected_call_args 51 | 52 | 53 | @mock.patch("orion.packages.geo.geocode.requests.get", autospec=True) 54 | def test_google_places_api_queries_correctly_with_place_ids(mocked_requests): 55 | id = "abc123" 56 | key = "123" 57 | place_by_id(id, key, PLACE_DETAILS) 58 | expected_call_args = mock.call( 59 | PLACE_DETAILS, 60 | params={ 61 | "place_id": "abc123", 62 | "key": "123", 63 | "fields": "address_components,formatted_address,geometry,name,place_id,type,website", 64 | }, 65 | ) 66 | assert mocked_requests.call_args == expected_call_args 67 | 68 | 69 | def test_parse_google_place_api_response(): 70 | api_response = { 71 | "html_attributions": [], 72 | "result": { 73 | "address_components": [ 74 | {"long_name": "441", "short_name": "441", "types": ["subpremise"]}, 75 | { 76 | "long_name": "Metal Box Factory", 77 | "short_name": "Metal Box Factory", 78 | "types": ["premise"], 79 | }, 80 | {"long_name": "30", "short_name": "30", "types": ["street_number"]}, 81 | { 82 | "long_name": "Great Guildford Street", 83 | "short_name": "Great Guildford St", 84 | "types": ["route"], 85 | }, 86 | { 87 | "long_name": "London", 88 | "short_name": "London", 89 | "types": ["postal_town"], 90 | }, 91 | { 92 | "long_name": "Greater London", 93 | "short_name": "Greater London", 94 | "types": ["administrative_area_level_2", "political"], 95 | }, 96 | { 97 | "long_name": "England", 98 | "short_name": "England", 99 | "types": ["administrative_area_level_1", "political"], 100 | }, 101 | { 102 | "long_name": "United Kingdom", 103 | "short_name": "GB", 104 | "types": ["country", "political"], 105 | }, 106 | { 107 | "long_name": "SE1 0HS", 108 | "short_name": "SE1 0HS", 109 | "types": ["postal_code"], 110 | }, 111 | ], 112 | "formatted_address": "441, Metal Box Factory, 30 Great Guildford St, London SE1 0HS, UK", 113 | "geometry": { 114 | "location": {"lat": 51.504589, "lng": -0.09708649999999999}, 115 | "viewport": { 116 | "northeast": {"lat": 51.5059537802915, "lng": -0.09565766970849796}, 117 | "southwest": {"lat": 51.5032558197085, "lng": -0.09835563029150202}, 118 | }, 119 | }, 120 | "name": "Mozilla", 121 | "place_id": "ChIJd7gxxc0EdkgRsxXmeQyR44A", 122 | "types": ["point_of_interest", "establishment"], 123 | "website": "https://www.mozilla.org/contact/spaces/london/", 
124 | }, 125 | "status": "OK", 126 | } 127 | 128 | expected_response = { 129 | "lat": 51.504589, 130 | "lng": -0.09708649999999999, 131 | "address": "441, Metal Box Factory, 30 Great Guildford St, London SE1 0HS, UK", 132 | "name": "Mozilla", 133 | "id": "ChIJd7gxxc0EdkgRsxXmeQyR44A", 134 | "types": ["point_of_interest", "establishment"], 135 | "website": "https://www.mozilla.org/contact/spaces/london/", 136 | "postal_town": "London", 137 | "administrative_area_level_2": "Greater London", 138 | "administrative_area_level_1": "England", 139 | "country": "United Kingdom", 140 | } 141 | 142 | assert parse_response(api_response) == expected_response 143 | -------------------------------------------------------------------------------- /orion/core/operators/country_details_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | HomogeniseCountryNamesOperator: Homogenises country names from Google Places API 3 | and the World Bank. It uses a country mapping dictionary from `model_config.yaml`. 4 | 5 | CountryDetailsOperator: Fetches additional country details from the restcountries API. 6 | This includes things such as population, capital, region (continent), sub-region and 7 | country codes. 8 | 9 | """ 10 | import logging 11 | from sqlalchemy import create_engine, distinct 12 | from sqlalchemy.orm import sessionmaker 13 | from airflow.models import BaseOperator 14 | from airflow.utils.decorators import apply_defaults 15 | from orion.core.orms.mag_orm import ( 16 | WorldBankFemaleLaborForce, 17 | WorldBankGovEducation, 18 | WorldBankResearchDevelopment, 19 | WorldBankGDP, 20 | CountryAssociation, 21 | AffiliationLocation, 22 | CountryDetails, 23 | ) 24 | import orion 25 | import requests 26 | from orion.packages.geo.enrich_countries import ( 27 | parse_country_details, 28 | get_country_details, 29 | ) 30 | 31 | google2wb = orion.config["country_name_mapping"]["google2wb"] 32 | google2restcountries = orion.config["country_name_mapping"]["google2restcountries"] 33 | 34 | 35 | class HomogeniseCountryNamesOperator(BaseOperator): 36 | """Homogenises country names from Google Places API and the World Bank.""" 37 | 38 | @apply_defaults 39 | def __init__(self, db_config, country_name_mapping=google2wb, *args, **kwargs): 40 | super().__init__(**kwargs) 41 | self.db_config = db_config 42 | self.country_name_mapping = google2wb 43 | 44 | def execute(self, context): 45 | # Connect to postgresdb 46 | engine = create_engine(self.db_config) 47 | Session = sessionmaker(engine) 48 | s = Session() 49 | 50 | s.query(CountryAssociation).delete() 51 | s.commit() 52 | 53 | # Country names from Google Places API 54 | country_names = [ 55 | country_name[0] 56 | for country_name in s.query(distinct(AffiliationLocation.country)).filter( 57 | (AffiliationLocation.country != None), 58 | (AffiliationLocation.country != ""), 59 | ) 60 | ] 61 | logging.info(f"Countries from Google: {len(country_names)}") 62 | 63 | # Country names from the World Bank 64 | wb_countries = set() 65 | for c1, c2, c3, c4 in zip( 66 | s.query(WorldBankGDP.country), 67 | s.query(WorldBankGovEducation.country), 68 | s.query(WorldBankFemaleLaborForce.country), 69 | s.query(WorldBankResearchDevelopment.country), 70 | ): 71 | wb_countries.update(c1, c2, c3, c4) 72 | 73 | # Match country names 74 | for country_name in country_names: 75 | if country_name in self.country_name_mapping.keys(): 76 | s.add( 77 | CountryAssociation( 78 | google_country=country_name, 79 | 
wb_country=self.country_name_mapping[country_name], 80 | ) 81 | ) 82 | else: 83 | s.add( 84 | CountryAssociation( 85 | google_country=country_name, wb_country=country_name 86 | ) 87 | ) 88 | s.commit() 89 | 90 | 91 | class CountryDetailsOperator(BaseOperator): 92 | """Fetch country information from restcountries.""" 93 | 94 | @apply_defaults 95 | def __init__( 96 | self, db_config, country_name_mapping=google2restcountries, *args, **kwargs 97 | ): 98 | super().__init__(**kwargs) 99 | self.db_config = db_config 100 | self.country_name_mapping = country_name_mapping 101 | 102 | def execute(self, context): 103 | # Connect to postgresdb 104 | engine = create_engine(self.db_config) 105 | Session = sessionmaker(engine) 106 | s = Session() 107 | 108 | s.query(CountryDetails).delete() 109 | s.commit() 110 | 111 | # Query restcountries API with Google Places country names. 112 | d = {} 113 | for country_name in [ 114 | country_name[0] 115 | for country_name in s.query(CountryAssociation.google_country) 116 | ]: 117 | try: 118 | d[country_name] = get_country_details(country_name) 119 | except requests.exceptions.HTTPError as h: 120 | logging.info(f"{country_name} - {h}: Trying with country_mapping") 121 | try: 122 | d[country_name] = get_country_details( 123 | self.country_name_mapping[country_name] 124 | ) 125 | except requests.exceptions.HTTPError as h: 126 | logging.info(f"Failed: {country_name}") 127 | continue 128 | except KeyError as e: 129 | logging.info(f"{country_name} not in mapping.") 130 | continue 131 | # Parse country info 132 | country_info = [] 133 | for k, v in d.items(): 134 | # These countries are not the first match so we choose `pos=1` 135 | if k in ["India", "United States", "Sudan"]: 136 | parsed_response = parse_country_details(v, pos=1) 137 | else: 138 | parsed_response = parse_country_details(v) 139 | 140 | parsed_response.update({"google_name": k}) 141 | parsed_response.update( 142 | { 143 | "wb_name": s.query(CountryAssociation.wb_country) 144 | .filter(CountryAssociation.google_country == k) 145 | .first()[0] 146 | } 147 | ) 148 | country_info.append(parsed_response) 149 | logging.info(f"Parsed countries: {len(country_info)}") 150 | 151 | s.bulk_insert_mappings(CountryDetails, country_info) 152 | s.commit() 153 | -------------------------------------------------------------------------------- /orion/core/operators/postgresql2es_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Migrates data from PostgreSQL to Elasticsearch. 3 | 4 | Postgreqsl2ElasticSearchOperator: Creates an index that contains 5 | the following data for every paper: 6 | - original_title 7 | - abstract 8 | - citations 9 | - publication date 10 | - publication year 11 | - field_of_study_id 12 | - field_of_study name 13 | - author name 14 | - author affiliation 15 | 16 | Users have the option to delete the index before uploading documents. The task also 17 | checks if the index exists before creating it. 
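At its core, the load step streams documents from a generator into Elasticsearch's bulk helper. A stripped-down sketch of that pattern (the local client, index name and document fields are made up; the operator below builds real documents from the merged PostgreSQL tables and an AWS-signed client):

    from elasticsearch import Elasticsearch, helpers

    es = Elasticsearch(["http://localhost:9200"])  # illustrative local client

    def docs():
        # One action per paper; "_index" and "_id" tell Elasticsearch where to put the document.
        yield {"_index": "papers", "_id": 1, "original_title": "An example paper", "citations": 3}

    helpers.bulk(es, docs(), request_timeout=180)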
18 | """ 19 | import logging 20 | from datetime import datetime 21 | import pandas as pd 22 | from sqlalchemy import create_engine 23 | from sqlalchemy.orm import sessionmaker 24 | from airflow.models import BaseOperator 25 | from airflow.utils.decorators import apply_defaults 26 | from orion.core.orms.mag_orm import ( 27 | Paper, 28 | PaperFieldsOfStudy, 29 | FieldOfStudy, 30 | AuthorAffiliation, 31 | Author, 32 | Affiliation, 33 | ) 34 | from elasticsearch_dsl import Index 35 | from orion.core.orms.es_mapping import PaperES 36 | from elasticsearch import helpers 37 | from orion.packages.utils.utils import aws_es_client 38 | 39 | 40 | class Postgreqsl2ElasticSearchOperator(BaseOperator): 41 | """Migrate data from PostgreSQL to Elastic Search.""" 42 | 43 | @apply_defaults 44 | def __init__( 45 | self, 46 | db_config, 47 | es_index, 48 | es_host, 49 | es_port, 50 | region, 51 | erase_es_index, 52 | *args, 53 | **kwargs, 54 | ): 55 | super().__init__(**kwargs) 56 | self.db_config = db_config 57 | self.es_index = es_index 58 | self.es_host = es_host 59 | self.es_port = es_port 60 | self.region = region 61 | self.erase_es_index = erase_es_index 62 | 63 | def execute(self, context): 64 | # Connect to postgresql db 65 | engine = create_engine(self.db_config) 66 | Session = sessionmaker(engine) 67 | s = Session() 68 | 69 | # Read MAG data 70 | mag = pd.read_sql(s.query(Paper).statement, s.bind) 71 | 72 | # Read Fields of study and merge them with papers 73 | fos = pd.read_sql(s.query(FieldOfStudy).statement, s.bind) 74 | pfos = pd.read_sql(s.query(PaperFieldsOfStudy).statement, s.bind) 75 | pfos = pfos.merge(fos, left_on="field_of_study_id", right_on="id") 76 | mag = mag.merge( 77 | pfos[["paper_id", "field_of_study_id", "name"]], 78 | left_on="id", 79 | right_on="paper_id", 80 | ) 81 | 82 | author_aff = pd.read_sql(s.query(AuthorAffiliation).statement, s.bind) 83 | author = pd.read_sql(s.query(Author).statement, s.bind) 84 | affiliation = pd.read_sql(s.query(Affiliation).statement, s.bind) 85 | author_aff = ( 86 | author_aff.merge(author, left_on="author_id", right_on="id") 87 | .merge(affiliation, how="left", left_on="affiliation_id", right_on="id")[ 88 | ["affiliation", "name", "paper_id"] 89 | ] 90 | .fillna("") 91 | ) 92 | author_aff = pd.DataFrame( 93 | author_aff.groupby("paper_id")["name"].apply(list) 94 | ).merge( 95 | pd.DataFrame(author_aff.groupby("paper_id")["affiliation"].apply(list)), 96 | left_index=True, 97 | right_index=True, 98 | ) 99 | 100 | # Groupby fos name and fos id and merge them in a table 101 | fos_names = pd.DataFrame( 102 | mag.groupby( 103 | ["paper_id", "year", "date", "original_title", "abstract", "citations"] 104 | )["name"].apply(list) 105 | ) 106 | fos_ids = pd.DataFrame( 107 | mag.groupby( 108 | ["paper_id", "year", "date", "original_title", "abstract", "citations"] 109 | )["field_of_study_id"].apply(list) 110 | ) 111 | 112 | table = ( 113 | fos_names.merge(fos_ids, left_index=True, right_index=True) 114 | .merge(author_aff, how="left", left_index=True, right_index=True) 115 | .rename( 116 | index=str, columns={"name_x": "field_of_study", "name_y": "author_name"} 117 | ) 118 | ) 119 | 120 | # Setup ES connection 121 | es = aws_es_client(self.es_host, self.es_port, self.region) 122 | 123 | # Delete index if needed (usually not) 124 | if self.erase_es_index: 125 | Index(self.es_index, using=es).delete() 126 | logging.info(f"Deleted ES index: {self.es_index}") 127 | 128 | # Create the index if it does not exist 129 | if not Index(self.es_index, 
using=es).exists(): 130 | PaperES.init(using=es) 131 | logging.info(f"Created ES index: {self.es_index}") 132 | 133 | def _docs_for_load(table): 134 | """Indexes documents in bulk.""" 135 | for ( 136 | (paper_id, year, date, title, abstract, citations), 137 | row, 138 | ) in table.iterrows(): 139 | yield PaperES( 140 | meta={"id": paper_id}, 141 | year=datetime.strptime(date, "%Y-%m-%d").date().year, 142 | publication_date=datetime.strptime(date, "%Y-%m-%d").date(), 143 | original_title=title, 144 | abstract=abstract, 145 | citations=citations, 146 | fields_of_study=[ 147 | {"name": name, "id": id_} 148 | for name, id_ in zip( 149 | row["field_of_study"], row["field_of_study_id"] 150 | ) 151 | ], 152 | author=[ 153 | {"name": name, "affiliation": aff} 154 | for name, aff in zip(row["author_name"], row["affiliation"]) 155 | ], 156 | ).to_dict(include_meta=True) 157 | 158 | # Increase timeout from 10 to 180 159 | helpers.bulk(es, _docs_for_load(table), request_timeout=180) 160 | -------------------------------------------------------------------------------- /orion/packages/geo/tests/test_enrich_countries.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | 4 | from orion.packages.geo.enrich_countries import get_country_details 5 | from orion.packages.geo.enrich_countries import parse_country_details 6 | 7 | 8 | @mock.patch("orion.packages.geo.enrich_countries.requests.get", autospec=True) 9 | def test_get_country_details(mock_requests): 10 | get_country_details("Laos") 11 | 12 | expected_call_args = mock.call("https://restcountries.eu/rest/v2/name/Laos") 13 | assert mock_requests.call_args == expected_call_args 14 | 15 | 16 | def test_parse_rest_countries_api_response(): 17 | data = [ 18 | { 19 | "name": "United States Minor Outlying Islands", 20 | "topLevelDomain": [".us"], 21 | "alpha2Code": "UM", 22 | "alpha3Code": "UMI", 23 | "callingCodes": [""], 24 | "capital": "", 25 | "altSpellings": ["UM"], 26 | "region": "Americas", 27 | "subregion": "Northern America", 28 | "population": 300, 29 | "latlng": [], 30 | "demonym": "American", 31 | "area": None, 32 | "gini": None, 33 | "timezones": ["UTC-11:00", "UTC-10:00", "UTC+12:00"], 34 | "borders": [], 35 | "nativeName": "United States Minor Outlying Islands", 36 | "numericCode": "581", 37 | "currencies": [ 38 | {"code": "USD", "name": "United States Dollar", "symbol": "$"} 39 | ], 40 | "languages": [ 41 | { 42 | "iso639_1": "en", 43 | "iso639_2": "eng", 44 | "name": "English", 45 | "nativeName": "English", 46 | } 47 | ], 48 | "translations": { 49 | "de": "Kleinere Inselbesitzungen der Vereinigten Staaten", 50 | "es": "Islas Ultramarinas Menores de Estados Unidos", 51 | "fr": "Îles mineures éloignées des États-Unis", 52 | "ja": "合衆国領有小離島", 53 | "it": "Isole minori esterne degli Stati Uniti d'America", 54 | "br": "Ilhas Menores Distantes dos Estados Unidos", 55 | "pt": "Ilhas Menores Distantes dos Estados Unidos", 56 | "nl": "Kleine afgelegen eilanden van de Verenigde Staten", 57 | "hr": "Mali udaljeni otoci SAD-a", 58 | "fa": "جزایر کوچک حاشیه\u200cای ایالات متحده آمریکا", 59 | }, 60 | "flag": "https://restcountries.eu/data/umi.svg", 61 | "regionalBlocs": [], 62 | "cioc": "", 63 | }, 64 | { 65 | "name": "United States of America", 66 | "topLevelDomain": [".us"], 67 | "alpha2Code": "US", 68 | "alpha3Code": "USA", 69 | "callingCodes": ["1"], 70 | "capital": "Washington, D.C.", 71 | "altSpellings": ["US", "USA", "United States of America"], 72 | "region": 
"Americas", 73 | "subregion": "Northern America", 74 | "population": 323947000, 75 | "latlng": [38.0, -97.0], 76 | "demonym": "American", 77 | "area": 9629091.0, 78 | "gini": 48.0, 79 | "timezones": [ 80 | "UTC-12:00", 81 | "UTC-11:00", 82 | "UTC-10:00", 83 | "UTC-09:00", 84 | "UTC-08:00", 85 | "UTC-07:00", 86 | "UTC-06:00", 87 | "UTC-05:00", 88 | "UTC-04:00", 89 | "UTC+10:00", 90 | "UTC+12:00", 91 | ], 92 | "borders": ["CAN", "MEX"], 93 | "nativeName": "United States", 94 | "numericCode": "840", 95 | "currencies": [ 96 | {"code": "USD", "name": "United States dollar", "symbol": "$"} 97 | ], 98 | "languages": [ 99 | { 100 | "iso639_1": "en", 101 | "iso639_2": "eng", 102 | "name": "English", 103 | "nativeName": "English", 104 | } 105 | ], 106 | "translations": { 107 | "de": "Vereinigte Staaten von Amerika", 108 | "es": "Estados Unidos", 109 | "fr": "États-Unis", 110 | "ja": "アメリカ合衆国", 111 | "it": "Stati Uniti D'America", 112 | "br": "Estados Unidos", 113 | "pt": "Estados Unidos", 114 | "nl": "Verenigde Staten", 115 | "hr": "Sjedinjene Američke Države", 116 | "fa": "ایالات متحده آمریکا", 117 | }, 118 | "flag": "https://restcountries.eu/data/usa.svg", 119 | "regionalBlocs": [ 120 | { 121 | "acronym": "NAFTA", 122 | "name": "North American Free Trade Agreement", 123 | "otherAcronyms": [], 124 | "otherNames": [ 125 | "Tratado de Libre Comercio de América del Norte", 126 | "Accord de Libre-échange Nord-Américain", 127 | ], 128 | } 129 | ], 130 | "cioc": "USA", 131 | }, 132 | ] 133 | 134 | expected_result_pos_0 = { 135 | "alpha2Code": "UM", 136 | "alpha3Code": "UMI", 137 | "name": "United States Minor Outlying Islands", 138 | "capital": "", 139 | "region": "Americas", 140 | "subregion": "Northern America", 141 | "population": 300, 142 | } 143 | 144 | expected_result_pos_1 = { 145 | "alpha2Code": "US", 146 | "alpha3Code": "USA", 147 | "name": "United States of America", 148 | "capital": "Washington, D.C.", 149 | "region": "Americas", 150 | "subregion": "Northern America", 151 | "population": 323947000, 152 | } 153 | 154 | result_pos_0 = parse_country_details(data, pos=0) 155 | result_pos_1 = parse_country_details(data, pos=1) 156 | 157 | assert result_pos_0 == expected_result_pos_0 158 | assert result_pos_1 == expected_result_pos_1 159 | -------------------------------------------------------------------------------- /orion/packages/mag/parsing_mag_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from orion.packages.utils.utils import inverted2abstract 4 | 5 | 6 | def parse_papers(response): 7 | """Parse paper information from a MAG API response. 8 | 9 | Args: 10 | response (json): Response from MAG API in JSON format. Contains paper information. 11 | 12 | Returns: 13 | d (dict): Paper metadata. 
14 | 15 | """ 16 | d = {} 17 | d["id"] = response["Id"] 18 | d["prob"] = response["prob"] 19 | d["title"] = response["Ti"] 20 | d["publication_type"] = response["Pt"] 21 | d["year"] = response["Y"] 22 | d["date"] = response["D"] 23 | d["citations"] = response["CC"] 24 | d["original_title"] = response["DN"] 25 | try: 26 | d["doi"] = response["DOI"] 27 | except KeyError as e: 28 | d["doi"] = np.nan 29 | try: 30 | # Get the first URL - MAG sorts them by relevance 31 | d["source"] = response["S"][0]["U"] 32 | except KeyError as e: 33 | d["source"] = np.nan 34 | 35 | try: 36 | d["bibtex_doc_type"] = response["BT"] 37 | except KeyError as e: 38 | # logging.info(f"{response['Id']}: {e}") 39 | d["bibtex_doc_type"] = np.nan 40 | try: 41 | d["abstract"] = inverted2abstract(response["IA"]) 42 | except KeyError as e: 43 | # logging.info(f"{response['Id']}: {e}") 44 | d["abstract"] = np.nan 45 | try: 46 | d["references"] = json.dumps(response["RId"]) 47 | except KeyError as e: 48 | # logging.info(f"{response['Id']}: {e}") 49 | d["references"] = np.nan 50 | try: 51 | d["publisher"] = response["PB"] 52 | except KeyError as e: 53 | # logging.info(f"{response['Id']}: {e}") 54 | d["publisher"] = np.nan 55 | 56 | return d 57 | 58 | 59 | def parse_journal(response, paper_id): 60 | """Parse journal information from a MAG API response. 61 | 62 | Args: 63 | response (json): Response from MAG API in JSON format. Contains all paper information. 64 | paper_id (int): Paper ID. 65 | 66 | Returns: 67 | d (dict): Journal details. 68 | 69 | """ 70 | return { 71 | "id": response["J"]["JId"], 72 | "journal_name": response["J"]["JN"], 73 | "paper_id": paper_id, 74 | } 75 | 76 | 77 | def parse_conference(response, paper_id): 78 | """Parse conference information from a MAG API response. 79 | 80 | Args: 81 | response (json): Response from MAG API in JSON format. Contains all paper information. 82 | paper_id (int): Paper ID. 83 | 84 | Returns: 85 | d (dict): Conference details. 86 | 87 | """ 88 | return { 89 | "id": response["C"]["CId"], 90 | "conference_name": response["C"]["CN"], 91 | "paper_id": paper_id, 92 | } 93 | 94 | 95 | def parse_authors(response, paper_id): 96 | """Parse author information from a MAG API response. 97 | 98 | Args: 99 | response (json): Response from MAG API in JSON format. Contains all paper information. 100 | paper_id (int): Paper ID. 101 | 102 | Returns: 103 | authors (:obj:`list` of :obj:`dict`): List of dictionaries with author information. 104 | There's one dictionary per author. 105 | paper_with_authors (:obj:`list` of :obj:`dict`): Matching paper and author IDs. 106 | 107 | """ 108 | authors = [] 109 | paper_with_authors = [] 110 | for author in response["AA"]: 111 | # mag_paper_authors 112 | paper_with_authors.append( 113 | {"paper_id": paper_id, "author_id": author["AuId"], "order": author["S"]} 114 | ) 115 | # mag_authors 116 | authors.append({"id": author["AuId"], "name": author["DAuN"]}) 117 | 118 | return authors, paper_with_authors 119 | 120 | 121 | def parse_fos(response, paper_id): 122 | """Parse the fields of study of a paper from a MAG API response. 123 | 124 | Args: 125 | response (json): Response from MAG API in JSON format. Contains all paper information. 126 | paper_id (int): Paper ID. 127 | 128 | Returns: 129 | fields_of_study (:obj:`list` of :obj:`dict`): List of dictionaries with fields of study information. 130 | There's one dictionary per field of study. 131 | paper_with_fos (:obj:`list` of :obj:`dict`): Matching fields of study and paper IDs. 
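    Example (illustrative; the 'F' entry mirrors the MAG field-of-study attribute format
    and the paper ID is hypothetical). Note that the return order is
    (paper_with_fos, fields_of_study):

        >>> response = {"F": [{"DFN": "Biology", "FN": "biology", "FId": 86803240}]}
        >>> parse_fos(response, paper_id=1)
        ([{'field_of_study_id': 86803240, 'paper_id': 1}], [{'id': 86803240, 'name': 'Biology'}])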
132 | 133 | """ 134 | # two outputs: fos_id with fos_name, fos_id with paper_id 135 | paper_with_fos = [] 136 | fields_of_study = [] 137 | for fos in response["F"]: 138 | # mag_fields_of_study 139 | fields_of_study.append({"id": fos["FId"], "name": fos["DFN"]}) 140 | # mag_paper_fields_of_study 141 | paper_with_fos.append({"field_of_study_id": fos["FId"], "paper_id": paper_id}) 142 | 143 | return paper_with_fos, fields_of_study 144 | 145 | 146 | def parse_affiliations(response, paper_id): 147 | """Parse the author affiliations from a MAG API response. 148 | 149 | Args: 150 | response (json): Response from MAG API in JSON format. Contains all paper information. 151 | paper_id (int): Paper ID. 152 | 153 | Returns: 154 | affiliations (:obj:`list` of :obj:`dict`): List of dictionaries with affiliation information. 155 | There's one dictionary per field of study. 156 | author_with_aff (:obj:`list` of :obj:`dict`): Matching affiliation and author IDs. 157 | 158 | """ 159 | affiliations = [] 160 | paper_author_aff = [] 161 | for aff in response["AA"]: 162 | if aff["AfId"]: 163 | # mag_author_affiliation 164 | paper_author_aff.append( 165 | { 166 | "affiliation_id": aff["AfId"], 167 | "author_id": aff["AuId"], 168 | "paper_id": paper_id, 169 | } 170 | ) 171 | # mag_affiliation 172 | affiliations.append({"id": aff["AfId"], "affiliation": aff["AfN"]}) 173 | else: 174 | paper_author_aff.append( 175 | { 176 | "affiliation_id": None, 177 | "author_id": aff["AuId"], 178 | "paper_id": paper_id, 179 | } 180 | ) 181 | return affiliations, paper_author_aff 182 | -------------------------------------------------------------------------------- /orion/packages/metrics/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import numpy as np 4 | 5 | from orion.packages.metrics.metrics import calculate_rca_by_sum 6 | from orion.packages.metrics.metrics import calculate_rca_by_count 7 | from orion.packages.metrics.metrics import simpson 8 | from orion.packages.metrics.metrics import dominance 9 | from orion.packages.metrics.metrics import shannon 10 | from orion.packages.metrics.metrics import _validate_counts_vector 11 | from orion.packages.metrics.metrics import observed_otus 12 | from orion.packages.metrics.metrics import simpson_e 13 | from orion.packages.metrics.metrics import enspie 14 | 15 | 16 | def test_calculate_rca_by_sum_calculates_correct_results(): 17 | data = pd.DataFrame( 18 | { 19 | "field_of_study_id": [123, 13, 13, 123, 123, 123], 20 | "country": ["UK", "UK", "US", "US", "US", "AU"], 21 | "paper_id": [1, 2, 3, 4, 5, 6], 22 | "citations": [10, 2, 7, 78, 32, 1], 23 | "year": ["2018", "2018", "2018", "2018", "2018", "2018"], 24 | } 25 | ) 26 | 27 | expected_result = pd.DataFrame( 28 | { 29 | "country": ["UK", "US"], 30 | "year": ["2018", "2018"], 31 | "citations": [2.407407, 0.864198], 32 | } 33 | ).set_index(["country", "year"]) 34 | 35 | result = pd.DataFrame( 36 | calculate_rca_by_sum( 37 | data, 38 | entity_column="country", 39 | commodity=13, 40 | value="citations", 41 | paper_thresh=1, 42 | year_thresh="2013", 43 | ) 44 | ) 45 | 46 | pd.testing.assert_frame_equal( 47 | expected_result, result, check_exact=False, check_less_precise=3 48 | ) 49 | 50 | 51 | def test_calculate_rca_by_count_calculates_correct_results(): 52 | data = pd.DataFrame( 53 | { 54 | "field_of_study_id": [123, 13, 13, 123, 123, 123], 55 | "country": ["UK", "UK", "US", "US", "US", "AU"], 56 | "paper_id": [1, 2, 3, 4, 5, 6], 57 | 
"citations": [10, 2, 7, 78, 32, 1], 58 | "year": ["2018", "2018", "2018", "2018", "2018", "2018"], 59 | } 60 | ) 61 | 62 | expected_result = pd.DataFrame( 63 | {"country": ["UK", "US"], "year": ["2018", "2018"], "paper_id": [1.5, 1]} 64 | ).set_index(["country", "year"]) 65 | 66 | result = pd.DataFrame( 67 | calculate_rca_by_count( 68 | data, 69 | entity_column="country", 70 | commodity=13, 71 | paper_thresh=1, 72 | year_thresh="2013", 73 | ) 74 | ) 75 | 76 | pd.testing.assert_frame_equal( 77 | expected_result, result, check_exact=False, check_less_precise=3 78 | ) 79 | 80 | 81 | def test_validate_counts_vector_list(): 82 | obs = _validate_counts_vector([0, 2, 1, 3]) 83 | 84 | np.testing.assert_array_equal(obs, np.array([0, 2, 1, 3])) 85 | assert obs.dtype == int 86 | 87 | 88 | def test_validate_counts_vector_numpy_array(): 89 | # numpy array (no copy made) 90 | data = np.array([0, 2, 1, 3]) 91 | obs = _validate_counts_vector(data) 92 | 93 | np.testing.assert_array_equal(obs, data) 94 | assert obs.dtype == int 95 | 96 | 97 | def test_validate_counts_vector_single_element(): 98 | obs = _validate_counts_vector([42]) 99 | 100 | np.testing.assert_array_equal(obs, np.array([42])) 101 | assert obs.dtype == int 102 | assert obs.shape == (1,) 103 | 104 | 105 | def test_validate_counts_vector_suppress_casting_to_int(): 106 | obs = _validate_counts_vector([42.2, 42.1, 0], suppress_cast=True) 107 | 108 | np.testing.assert_array_equal(obs, np.array([42.2, 42.1, 0])) 109 | assert obs.dtype == float 110 | 111 | 112 | def test_validate_counts_vector_all_zeros(): 113 | obs = _validate_counts_vector([0, 0, 0]) 114 | 115 | np.testing.assert_array_equal(obs, np.array([0, 0, 0])) 116 | assert obs.dtype == int 117 | 118 | 119 | def test_validate_counts_vector_all_zeros_single_value(): 120 | obs = _validate_counts_vector([0]) 121 | 122 | np.testing.assert_array_equal(obs, np.array([0])) 123 | assert obs.dtype == int 124 | 125 | 126 | def test_validate_counts_vector_invalid_input_wrong_dtype(): 127 | with pytest.raises(Exception): 128 | _validate_counts_vector([0, 2, 1.2, 3]) 129 | 130 | 131 | def test_validate_counts_vector_invalid_input_wrong_number_of_dimensions(): 132 | with pytest.raises(Exception): 133 | _validate_counts_vector([[0, 2, 1, 3], [4, 5, 6, 7]]) 134 | 135 | 136 | def test_validate_counts_vector_invalid_input_wrong_number_of_dimensions_scalar(): 137 | with pytest.raises(Exception): 138 | _validate_counts_vector(1) 139 | 140 | 141 | def test_validate_counts_vector_invalid_input_negative_values(): 142 | with pytest.raises(Exception): 143 | _validate_counts_vector([0, 0, 2, -1, 3]) 144 | 145 | 146 | def test_dominance(): 147 | assert dominance(np.array([5])) == 1 148 | assert pytest.approx(dominance(np.array([1, 0, 2, 5, 2])), 0.34) 149 | 150 | 151 | def test_shannon(): 152 | assert shannon(np.array([5])) == 0 153 | assert shannon(np.array([5, 5])) == 1 154 | assert shannon(np.array([1, 1, 1, 1, 0])) == 2 155 | 156 | 157 | def test_simpson(): 158 | assert pytest.approx(simpson(np.array([1, 0, 2, 5, 2])), 0.66) 159 | assert pytest.approx(simpson(np.array([5])), 0) 160 | 161 | 162 | def test_observed_otus(): 163 | obs = observed_otus(np.array([4, 3, 4, 0, 1, 0, 2])) 164 | assert obs == 5 165 | 166 | obs = observed_otus(np.array([0, 0, 0])) 167 | assert obs == 0 168 | 169 | obs = observed_otus(np.array([0, 1, 1, 4, 2, 5, 2, 4, 1, 2])) 170 | assert obs == 9 171 | 172 | 173 | def test_enspie(): 174 | # Totally even community should have ENS_pie = number of OTUs. 
175 | assert pytest.approx(enspie(np.array([1, 1, 1, 1, 1, 1])), 6) 176 | assert pytest.approx(enspie(np.array([13, 13, 13, 13])), 4) 177 | 178 | # Hand calculated. 179 | arr = np.array([1, 41, 0, 0, 12, 13]) 180 | exp = 1 / ((arr / arr.sum()) ** 2).sum() 181 | np.testing.assert_almost_equal(enspie(arr), exp) 182 | 183 | # Using dominance. 184 | exp = 1 / dominance(arr) 185 | np.testing.assert_almost_equal(enspie(arr), exp) 186 | 187 | arr = np.array([1, 0, 2, 5, 2]) 188 | exp = 1 / dominance(arr) 189 | np.testing.assert_array_almost_equal(enspie(arr), exp) 190 | -------------------------------------------------------------------------------- /orion/packages/utils/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import chain, combinations 2 | from collections import OrderedDict, Counter 3 | import numpy as np 4 | from elasticsearch import Elasticsearch, RequestsHttpConnection 5 | from requests_aws4auth import AWS4Auth 6 | import boto3 7 | from datetime import datetime 8 | 9 | 10 | def unique_dicts(d): 11 | """Removes duplicate dictionaries from a list. 12 | 13 | Args: 14 | d (:obj:`list` of :obj:`dict`): List of dictionaries with the same keys. 15 | 16 | Returns 17 | (:obj:`list` of :obj:`dict`) 18 | 19 | """ 20 | return [dict(y) for y in set(tuple(x.items()) for x in d)] 21 | 22 | 23 | def unique_dicts_by_value(d, key): 24 | """Removes duplicate dictionaries from a list by filtering one of the key values. 25 | 26 | Args: 27 | d (:obj:`list` of :obj:`dict`): List of dictionaries with the same keys. 28 | 29 | Returns 30 | (:obj:`list` of :obj:`dict`) 31 | 32 | """ 33 | return list({v[key]: v for v in d}.values()) 34 | 35 | 36 | def flatten_lists(lst): 37 | """Unpacks nested lists into one list of elements. 38 | 39 | Args: 40 | lst (:obj:`list` of :obj:`list`) 41 | 42 | Returns 43 | (list) 44 | 45 | """ 46 | return list(chain(*lst)) 47 | 48 | 49 | def dict2psql_format(d): 50 | """Transform a dictionary with pandas Series to a list of dictionaries 51 | in order to add it in PostgreSQL. 52 | 53 | Args: 54 | d (dict): Dictionary with pandas Series, usually containing RCA measurement. 55 | 56 | Returns: 57 | (:obj:`list` of :obj:`dict`) 58 | 59 | """ 60 | return flatten_lists( 61 | [ 62 | [ 63 | { 64 | "entity": idx[0], 65 | "year": idx[1], 66 | "rca_sum": elem, 67 | "field_of_study_id": int(fos), 68 | } 69 | for idx, elem in series.iteritems() 70 | ] 71 | for fos, series in d.items() 72 | ] 73 | ) 74 | 75 | 76 | def inverted2abstract(obj): 77 | """Transforms an inverted abstract to abstract. 78 | 79 | Args: 80 | obj (json): Inverted Abstract. 81 | 82 | Returns: 83 | (str): Formatted abstract. 84 | 85 | """ 86 | if isinstance(obj, dict): 87 | inverted_index = obj["InvertedIndex"] 88 | d = {} 89 | for k, v in inverted_index.items(): 90 | if len(v) == 1: 91 | d[v[0]] = k 92 | else: 93 | for idx in v: 94 | d[idx] = k 95 | 96 | return " ".join([v for _, v in OrderedDict(sorted(d.items())).items()]).replace( 97 | "\x00", "" 98 | ) 99 | else: 100 | return np.nan 101 | 102 | 103 | def cooccurrence_graph(elements): 104 | """Creates a cooccurrence table from a nested list. 105 | 106 | Args: 107 | elements (:obj:`list` of :obj:`list`): Nested list. 
108 | 109 | Returns: 110 | (`collections.Counter`) of the form Counter({('country_a, country_b), weight}) 111 | 112 | """ 113 | # Get a list of all of the combinations you have 114 | expanded = [tuple(combinations(d, 2)) for d in elements] 115 | expanded = chain(*expanded) 116 | 117 | # Sort the combinations so that A,B and B,A are treated the same 118 | expanded = [tuple(sorted(d)) for d in expanded] 119 | 120 | # count the combinations 121 | return Counter(expanded) 122 | 123 | 124 | def get_all_children(df, topics, lvl=1): 125 | """Traverses the Fields of Study tree to collect all the children FoS. For example, 126 | given a level 1 FoS, it will fetch all the level 2 children (A), the children of A, 127 | the children of the children of A [...] till it reaches the lowest level. 128 | 129 | Args: 130 | df (`pd.DataFrame`): Table with FoS IDs and their children. 131 | topics (:obj:`list` of int | int): Initially, it receives a single FoS. Then a list. 132 | lvl (int): Level of the initial FoS. 133 | 134 | Returns: 135 | t (:obj:`list` of int) 136 | 137 | """ 138 | # For the first pass of the recursion, put the topic in a list 139 | if not isinstance(topics, list): 140 | topics = [topics] 141 | 142 | t = [] 143 | t.extend(topics) 144 | t.extend( 145 | flatten_lists( 146 | [ 147 | df[df.id == id_]["child_id"].values[0] 148 | for id_ in topics 149 | if df[df.id == id_]["child_id"].values[0] is not None 150 | and df[df.id == id_]["child_id"].values[0] 151 | ] 152 | ) 153 | ) 154 | 155 | if lvl == 5: 156 | # t.remove(t[0]) 157 | return t 158 | else: 159 | return get_all_children(df, t, lvl + 1) 160 | 161 | 162 | def average_vectors(vectors): 163 | """Averages vectors. 164 | 165 | Args: 166 | vectors (:obj:`list` of `numpy.array`) 167 | 168 | Returns: 169 | (numpy.ndarray) Average of the vectors of the shape. 170 | 171 | """ 172 | return np.mean([v for v in vectors], axis=0) 173 | 174 | 175 | def aws_es_client(host, port, region): 176 | """Create a client with IAM based authentication on AWS. 177 | Boto3 will fetch the AWS credentials. 178 | 179 | Args: 180 | host (str): AWS ES domain. 181 | port (int): AWS ES port (default: 443). 182 | region (str): AWS ES region. 183 | 184 | Returns: 185 | es (elasticsearch.client.Elasticsearch): Authenticated AWS client. 186 | 187 | """ 188 | credentials = boto3.Session().get_credentials() 189 | awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, "es") 190 | 191 | es = Elasticsearch( 192 | hosts=[{"host": host, "port": port}], 193 | http_auth=awsauth, 194 | use_ssl=True, 195 | verify_certs=True, 196 | connection_class=RequestsHttpConnection, 197 | ) 198 | 199 | return es 200 | 201 | 202 | def str2datetime(input_date): 203 | """Transform a string to datetime object. 204 | 205 | Args: 206 | input_date (str): String date of the format Y-m-d. It can 207 | also be 'today' which will return today's date. 208 | 209 | Returns: 210 | (`datetime.datetime`) 211 | 212 | """ 213 | if input_date == "today": 214 | return datetime.today() 215 | else: 216 | return datetime.strptime(input_date, "%Y-%m-%d") 217 | 218 | 219 | def date_range(start, end, intv): 220 | """Splits a date range into intervals. 221 | 222 | Args: 223 | start (str): Start date of the format (Y-m-d). 224 | intv (int): Number of intervals. 225 | end (str): End date of the format (Y-m-d). 226 | 227 | Returns: 228 | (:obj:`generator` of `str`) Dates with the (Y-m-d) format. 
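    Example (mirrors the unit test in orion/packages/utils/tests/test_utils.py; note
    that start and end are passed as datetime objects, e.g. via str2datetime):

        >>> start = str2datetime("2000-01-01")
        >>> end = str2datetime("2000-12-31")
        >>> list(date_range(start, end, 6))
        ['2000-01-01', '2000-03-01', '2000-05-01', '2000-07-01', '2000-08-31', '2000-10-31', '2000-12-31']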
229 | 230 | """ 231 | diff = (end - start) / intv 232 | for i in range(intv): 233 | yield (start + diff * i).strftime("%Y-%m-%d") 234 | yield end.strftime("%Y-%m-%d") 235 | -------------------------------------------------------------------------------- /orion/core/operators/mag_parse_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | MagParserOperator: Fetches MAG responses (pickle files on S3), parses them and stores them in a PostgreSQL database. 3 | FosFrequencyOperator: Fetches all the Fields of Study from one table, calculates their frequency and stores them in another. 4 | 5 | """ 6 | import logging 7 | from sqlalchemy.orm.exc import NoResultFound 8 | from sqlalchemy import create_engine, func 9 | from sqlalchemy.orm import sessionmaker 10 | from airflow.models import BaseOperator 11 | from airflow.utils.decorators import apply_defaults 12 | from orion.packages.utils.s3_utils import load_from_s3, s3_bucket_obj 13 | from orion.packages.utils.utils import ( 14 | unique_dicts, 15 | unique_dicts_by_value, 16 | flatten_lists, 17 | ) 18 | from orion.packages.mag.parsing_mag_data import ( 19 | parse_affiliations, 20 | parse_authors, 21 | parse_fos, 22 | parse_journal, 23 | parse_papers, 24 | parse_conference, 25 | ) 26 | from orion.core.orms.mag_orm import ( 27 | Paper, 28 | Journal, 29 | Author, 30 | AuthorAffiliation, 31 | Affiliation, 32 | PaperAuthor, 33 | PaperFieldsOfStudy, 34 | FieldOfStudy, 35 | FosMetadata, 36 | Conference, 37 | ) 38 | 39 | 40 | class MagParserOperator(BaseOperator): 41 | """Parses files from S3 that contain MAG paper information.""" 42 | 43 | # template_fields = [''] 44 | def __init__(self, s3_bucket, db_config, *args, **kwargs): 45 | super().__init__(**kwargs) 46 | self.s3_bucket = s3_bucket 47 | self.db_config = db_config 48 | 49 | def execute(self, context): 50 | # Connect to postgresql 51 | engine = create_engine(self.db_config) 52 | Session = sessionmaker(bind=engine) 53 | s = Session() 54 | 55 | # Collect IDs from tables to ensure we're not inserting duplicates 56 | paper_ids = {id_[0] for id_ in s.query(Paper.id)} 57 | author_ids = {id_[0] for id_ in s.query(Author.id)} 58 | fos_ids = {id_[0] for id_ in s.query(FieldOfStudy.id)} 59 | aff_ids = {id_[0] for id_ in s.query(Affiliation.id)} 60 | 61 | # Read data from S3 62 | data = [] 63 | for obj in s3_bucket_obj(self.s3_bucket): 64 | data.extend(load_from_s3(self.s3_bucket, obj.key.split(".")[0])) 65 | logging.info(f"Number of collected papers: {len(data)}") 66 | 67 | # Remove duplicates and keep only papers that are not already in the mag_papers table. 
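        # (Illustrative) unique_dicts_by_value keeps one entry per "Id", with later
        # duplicates overwriting earlier ones, e.g.
        # unique_dicts_by_value([{"Id": 1, "Ti": "a"}, {"Id": 1, "Ti": "b"}], "Id")
        # returns [{"Id": 1, "Ti": "b"}]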
68 | data = [ 69 | d for d in unique_dicts_by_value(data, "Id") if d["Id"] not in paper_ids 70 | ] 71 | logging.info(f"Number of unique papers not existing in DB: {len(data)}") 72 | 73 | papers = [parse_papers(response) for response in data] 74 | logging.info(f"Completed parsing papers: {len(papers)}") 75 | 76 | journals = [ 77 | parse_journal(response, response["Id"]) 78 | for response in data 79 | if "J" in response.keys() 80 | ] 81 | logging.info(f"Completed parsing journals: {len(journals)}") 82 | 83 | conferences = [ 84 | parse_conference(response, response["Id"]) 85 | for response in data 86 | if "C" in response.keys() 87 | ] 88 | logging.info(f"Completed parsing conferences: {len(conferences)}") 89 | 90 | # Parse author information 91 | items = [parse_authors(response, response["Id"]) for response in data] 92 | authors = [ 93 | d 94 | for d in unique_dicts_by_value( 95 | flatten_lists([item[0] for item in items]), "id" 96 | ) 97 | if d["id"] not in author_ids 98 | ] 99 | 100 | paper_with_authors = unique_dicts(flatten_lists([item[1] for item in items])) 101 | logging.info(f"Completed parsing authors: {len(authors)}") 102 | logging.info( 103 | f"Completed parsing papers_with_authors: {len(paper_with_authors)}" 104 | ) 105 | 106 | # Parse Fields of Study 107 | items = [ 108 | parse_fos(response, response["Id"]) 109 | for response in data 110 | if "F" in response.keys() 111 | ] 112 | paper_with_fos = unique_dicts(flatten_lists([item[0] for item in items])) 113 | fields_of_study = [ 114 | d 115 | for d in unique_dicts(flatten_lists([item[1] for item in items])) 116 | if d["id"] not in fos_ids 117 | ] 118 | logging.info(f"Completed parsing fields_of_study: {len(fields_of_study)}") 119 | logging.info(f"Completed parsing paper_with_fos: {len(paper_with_fos)}") 120 | 121 | # Parse affiliations 122 | items = [parse_affiliations(response, response["Id"]) for response in data] 123 | affiliations = [ 124 | d 125 | for d in unique_dicts(flatten_lists([item[0] for item in items])) 126 | if d["id"] not in aff_ids 127 | ] 128 | paper_author_aff = unique_dicts(flatten_lists([item[1] for item in items])) 129 | logging.info(f"Completed parsing affiliations: {len(affiliations)}") 130 | logging.info(f"Completed parsing author_with_aff: {len(paper_author_aff)}") 131 | 132 | logging.info(f"Parsing completed!") 133 | 134 | # Insert dicts into postgresql 135 | s.bulk_insert_mappings(Paper, papers) 136 | s.bulk_insert_mappings(Journal, journals) 137 | s.bulk_insert_mappings(Conference, conferences) 138 | s.bulk_insert_mappings(Author, authors) 139 | s.bulk_insert_mappings(PaperAuthor, paper_with_authors) 140 | s.bulk_insert_mappings(FieldOfStudy, fields_of_study) 141 | s.bulk_insert_mappings(PaperFieldsOfStudy, paper_with_fos) 142 | s.bulk_insert_mappings(Affiliation, affiliations) 143 | s.bulk_insert_mappings(AuthorAffiliation, paper_author_aff) 144 | s.commit() 145 | logging.info("Committed to DB!") 146 | 147 | 148 | class FosFrequencyOperator(BaseOperator): 149 | """Find the frequency of the Field of Studies.""" 150 | 151 | @apply_defaults 152 | def __init__(self, db_config, *args, **kwargs): 153 | super().__init__(**kwargs) 154 | self.db_config = db_config 155 | 156 | def execute(self, context): 157 | # Connect to PostgreSQL DB 158 | engine = create_engine(self.db_config) 159 | Session = sessionmaker(bind=engine) 160 | s = Session() 161 | 162 | # Get a count of field of study 163 | fos_freq = ( 164 | s.query( 165 | PaperFieldsOfStudy.field_of_study_id, 166 | 
func.count(PaperFieldsOfStudy.field_of_study_id), 167 | ) 168 | .group_by(PaperFieldsOfStudy.field_of_study_id) 169 | .all() 170 | ) 171 | 172 | # Transform it to a dictionary - This step can actually be skipped 173 | fos_freq = {tup[0]: tup[1] for tup in fos_freq} 174 | 175 | for k, v in fos_freq.items(): 176 | logging.info(f"FIELD_OF_STUDY: {k}") 177 | # Update the frequency column. Skip if the field_of_study id is not found 178 | try: 179 | fos = s.query(FosMetadata).filter(FosMetadata.id == k).one() 180 | fos.frequency = v 181 | s.commit() 182 | except NoResultFound: 183 | continue 184 | -------------------------------------------------------------------------------- /orion/core/operators/text2vec_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms a variable length text to a fixed-length vector. 3 | 4 | Text2VectorOperator: Uses a pretrained model (DistilBERT) from the transformers library 5 | to create word vectors which are then averaged to produce a document vector. It fetches 6 | abstracts from PostgreSQL which decodes to text. The output vectors are 7 | stored on PostgreSQL. 8 | 9 | Text2TfidfOperator: Transforms text to vectors using TF-IDF and SVD. TF-IDF from scikit-learn 10 | preprocesses the data and SVD reduces the dimensionality of the document vectors. It fetches 11 | abstracts from PostgreSQL which decodes to text. The output vectors are 12 | stored on PostgreSQL. 13 | 14 | Text2SentenceBertOperator: Uses a pretrained model (DistilBERT) to create sentence-level embeddings 15 | which are max-pooled to produce a document vector. It fetches abstracts from PostgreSQL and 16 | decodes to text. The output vectors are stored on PostgreSQL. 17 | 18 | Note: Running Text2VectorOperator and Text2SentenceBertOperator on a GPU will massively speed up 19 | the computation. 20 | 21 | """ 22 | import logging 23 | from sqlalchemy import create_engine, and_ 24 | from sqlalchemy.sql import exists 25 | from sqlalchemy.orm import sessionmaker 26 | from airflow.models import BaseOperator 27 | from airflow.utils.decorators import apply_defaults 28 | from orion.packages.nlp.text2vec import Text2Vector 29 | from orion.core.orms.mag_orm import Paper, HighDimDocVector 30 | from orion.packages.utils.s3_utils import store_on_s3 31 | from sklearn.feature_extraction.text import TfidfVectorizer 32 | from sentence_transformers import SentenceTransformer 33 | from sklearn.decomposition import TruncatedSVD 34 | import orion 35 | import toolz 36 | 37 | svd_components = orion.config["svd"]["n_components"] 38 | seed = orion.config["seed"] 39 | max_features = orion.config["tfidf"]["max_features"] 40 | 41 | 42 | class Text2VectorOperator(BaseOperator): 43 | """Transforms text to document embeddings.""" 44 | 45 | # template_fields = [''] 46 | @apply_defaults 47 | def __init__(self, db_config, *args, **kwargs): 48 | super().__init__(**kwargs) 49 | self.db_config = db_config 50 | 51 | def execute(self, context): 52 | # Connect to postgresql 53 | engine = create_engine(self.db_config) 54 | Session = sessionmaker(bind=engine) 55 | s = Session() 56 | 57 | # Get the abstracts of bioRxiv papers. 
58 | papers = s.query(Paper.abstract, Paper.id).filter( 59 | and_( 60 | ~exists().where(Paper.id == HighDimDocVector.id), 61 | Paper.abstract != "NaN", 62 | ) 63 | ) 64 | logging.info(f"Number of documents to be vectorised: {papers.count()}") 65 | 66 | # Convert text to vectors 67 | tv = Text2Vector() 68 | for i, (abstract, id_) in enumerate(papers, start=1): 69 | vec = tv.average_vectors(tv.feature_extraction(tv.encode_text(abstract))) 70 | 71 | # Commit to db 72 | s.add(HighDimDocVector(**{"vector": vec.astype(float), "id": id_})) 73 | s.commit() 74 | logging.info("Committed to DB!") 75 | 76 | 77 | class Text2TfidfOperator(BaseOperator): 78 | """Transforms text to document embeddings.""" 79 | 80 | # template_fields = [''] 81 | @apply_defaults 82 | def __init__(self, db_config, bucket, *args, **kwargs): 83 | super().__init__(**kwargs) 84 | self.db_config = db_config 85 | self.bucket = bucket 86 | # self.prefix = prefix 87 | 88 | def execute(self, context): 89 | # Connect to postgresql 90 | engine = create_engine(self.db_config) 91 | Session = sessionmaker(bind=engine) 92 | s = Session() 93 | s.query(HighDimDocVector).delete() 94 | s.commit() 95 | # Get the paper abstracts. 96 | papers = s.query(Paper.abstract, Paper.id).filter( 97 | and_( 98 | ~exists().where(Paper.id == HighDimDocVector.id), 99 | Paper.abstract != "NaN", 100 | ) 101 | ) 102 | logging.info(f"Number of documents to be vectorised: {papers.count()}") 103 | 104 | # Unroll abstracts and IDs 105 | abstracts, ids = zip(*papers) 106 | 107 | # Get tfidf vectors 108 | vectorizer = TfidfVectorizer( 109 | stop_words="english", analyzer="word", max_features=max_features 110 | ) 111 | X = vectorizer.fit_transform(abstracts) 112 | logging.info("Embedding documents - Done!") 113 | 114 | # Reduce dimensionality with SVD to speed up UMAP computation 115 | svd = TruncatedSVD(n_components=svd_components, random_state=seed) 116 | features = svd.fit_transform(X) 117 | logging.info(f"SVD dimensionality reduction shape: {features.shape}") 118 | 119 | vectors = [{"vector": vec, "id": id_} for vec, id_ in zip(features, ids)] 120 | logging.info(f"Embeddings: {len(vectors)}") 121 | 122 | # Store models to S3 123 | store_on_s3(vectorizer, self.bucket, "tfidf_model") 124 | store_on_s3(svd, self.bucket, "svd_model") 125 | logging.info("Stored models to S3!") 126 | 127 | # Store vectors to DB 128 | s.bulk_insert_mappings(HighDimDocVector, vectors) 129 | s.commit() 130 | logging.info("Stored vectors to DB!") 131 | 132 | 133 | class Text2SentenceBertOperator(BaseOperator): 134 | """Transforms text to document embeddings.""" 135 | 136 | # template_fields = [''] 137 | @apply_defaults 138 | def __init__(self, db_config, batch_size, bert_model, *args, **kwargs): 139 | super().__init__(**kwargs) 140 | self.db_config = db_config 141 | self.batch_size = batch_size 142 | self.bert_model = bert_model 143 | 144 | def execute(self, context): 145 | # Instantiate SentenceTransformer 146 | model = SentenceTransformer(self.bert_model) 147 | 148 | # Connect to postgresql 149 | engine = create_engine(self.db_config) 150 | Session = sessionmaker(bind=engine) 151 | s = Session() 152 | 153 | # Get the abstracts of bioRxiv papers. 
154 | papers = s.query(Paper.abstract, Paper.id).filter( 155 | and_( 156 | ~exists().where(Paper.id == HighDimDocVector.id), 157 | Paper.abstract != "NaN", 158 | ) 159 | ) 160 | logging.info(f"Number of documents to be vectorised: {papers.count()}") 161 | 162 | if papers.count() > 0: 163 | 164 | # Unroll abstracts and paper IDs 165 | abstracts, ids = zip(*papers) 166 | for i, (id_chunk, abstracts_chunk) in enumerate( 167 | zip( 168 | list(toolz.partition_all(self.batch_size, ids)), 169 | list(toolz.partition_all(self.batch_size, abstracts)), 170 | ), 171 | start=1, 172 | ): 173 | 174 | # Convert text to vectors 175 | embeddings = model.encode(abstracts_chunk) 176 | 177 | # Group IDs with embeddings 178 | batch = [ 179 | {"id": id_, "vector": vector.astype(float)} 180 | for id_, vector in zip(id_chunk, embeddings) 181 | ] 182 | 183 | # Commit to db 184 | s.bulk_insert_mappings(HighDimDocVector, batch) 185 | s.commit() 186 | logging.info(f"Committed batch {i}") 187 | else: 188 | logging.info("No documents need vectorisation") 189 | -------------------------------------------------------------------------------- /orion/core/operators/dim_reduction_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | DimReductionOperator: Transforms high dimensional arrays to 2D and 3D using UMAP. 3 | Fetches vectors and paper IDs from PostgreSQL and stores the low dimensional 4 | representation in PostgreSQL. 5 | 6 | """ 7 | import logging 8 | from sqlalchemy.orm import sessionmaker 9 | from sqlalchemy import create_engine, and_ 10 | from sqlalchemy.sql import exists 11 | from airflow.models import BaseOperator 12 | from airflow.utils.decorators import apply_defaults 13 | from orion.core.orms.mag_orm import DocVector, HighDimDocVector, Paper 14 | from orion.packages.projection.dim_reduction import umap_embeddings 15 | from orion.packages.utils.s3_utils import s3_bucket_obj, store_on_s3, load_from_s3 16 | 17 | 18 | def umap_exists(bucket, umap_model="umap_model.pickle"): 19 | s3_objects = set([obj.key for obj in s3_bucket_obj(bucket)]) 20 | if umap_model in s3_objects: 21 | return "call_umap" 22 | else: 23 | return "fit_umap" 24 | 25 | 26 | class DimReductionFittedUmapOperator(BaseOperator): 27 | """Transforms a high dimensional array to 2D or 3D with a fitted UMAP.""" 28 | 29 | @apply_defaults 30 | def __init__( 31 | self, 32 | db_config, 33 | s3_bucket, 34 | short_doc_len=300, 35 | remove_short_docs=True, 36 | exclude_docs=[], 37 | *args, 38 | **kwargs, 39 | ): 40 | super().__init__(**kwargs) 41 | self.db_config = db_config 42 | self.s3_bucket = s3_bucket 43 | self.exclude_docs = exclude_docs 44 | self.remove_short_docs = remove_short_docs 45 | self.short_doc_len = short_doc_len 46 | 47 | def execute(self, context): 48 | # Load fitted UMAP 49 | reducer = load_from_s3(self.s3_bucket, "umap_model") 50 | 51 | # Connect to postgresql 52 | engine = create_engine(self.db_config) 53 | Session = sessionmaker(bind=engine) 54 | s = Session() 55 | 56 | # Fetch paper citations 57 | paper_citations = { 58 | id_: citation_count 59 | for id_, citation_count in s.query(Paper.id, Paper.citations) 60 | } 61 | 62 | # Load high dimensional vectors and paper IDs 63 | vectors_and_ids = s.query(HighDimDocVector.vector, HighDimDocVector.id).filter( 64 | and_( 65 | ~exists().where(DocVector.id == HighDimDocVector.id), 66 | HighDimDocVector.id.notin_(self.exclude_docs), 67 | ) 68 | ) 69 | logging.info(f"New vectors: {vectors_and_ids.count()}") 70 | 71 | if self.remove_short_docs: 72 | 
# Find documents with very short abstracts 73 | short_docs_ids = [ 74 | id_ 75 | for id_, abstract in s.query(Paper.id, Paper.abstract) 76 | if len(abstract) < self.short_doc_len 77 | ] 78 | # Filter documents with very short abstracts 79 | vectors, ids = [], [] 80 | for vector, id_ in vectors_and_ids: 81 | if id_ not in short_docs_ids: 82 | vectors.append(vector) 83 | ids.append(id_) 84 | else: 85 | # Load vectors 86 | vectors, ids = zip(*vectors_and_ids) 87 | 88 | # Reduce the dimensionality of new vectors 89 | embeddings_3d = reducer.transform(vectors) 90 | logging.info(f"UMAP embeddings: {embeddings_3d.shape}") 91 | 92 | # Construct DB insertions 93 | doc_vectors = [ 94 | { 95 | "id": id_, 96 | "vector_3d": embed_3d.tolist(), 97 | "citations": paper_citations[id_], 98 | } 99 | for embed_3d, id_ in zip(embeddings_3d, ids) 100 | ] 101 | logging.info(f"Constructed DocVector input") 102 | 103 | # Store document vectors in PostgreSQL 104 | s.bulk_insert_mappings(DocVector, doc_vectors) 105 | s.commit() 106 | logging.info("Commited to DB!") 107 | 108 | 109 | class DimReductionOperator(BaseOperator): 110 | """Transforms a high dimensional array to 2D or 3D.""" 111 | 112 | @apply_defaults 113 | def __init__( 114 | self, 115 | db_config, 116 | s3_bucket, 117 | n_neighbors, 118 | min_dist, 119 | n_components, 120 | metric, 121 | short_doc_len=300, 122 | remove_short_docs=True, 123 | exclude_docs=[], 124 | *args, 125 | **kwargs, 126 | ): 127 | super().__init__(**kwargs) 128 | self.db_config = db_config 129 | self.s3_bucket = s3_bucket 130 | self.n_neighbors = n_neighbors 131 | self.min_dist = min_dist 132 | self.n_components = n_components 133 | self.metric = metric 134 | self.exclude_docs = exclude_docs 135 | self.remove_short_docs = remove_short_docs 136 | self.short_doc_len = short_doc_len 137 | 138 | def execute(self, context): 139 | # Connect to postgresql 140 | engine = create_engine(self.db_config) 141 | Session = sessionmaker(bind=engine) 142 | s = Session() 143 | 144 | # Fetch paper citations 145 | paper_citations = { 146 | id_: citation_count 147 | for id_, citation_count in s.query(Paper.id, Paper.citations) 148 | } 149 | 150 | # Delete existing UMAP projection 151 | s.query(DocVector).delete() 152 | s.commit() 153 | 154 | # Load high dimensional vectors and paper IDs 155 | vectors_and_ids = s.query(HighDimDocVector.vector, HighDimDocVector.id).filter( 156 | HighDimDocVector.id.notin_(self.exclude_docs) 157 | ) 158 | logging.info(f"Excluding: {self.exclude_docs}") 159 | 160 | if self.remove_short_docs: 161 | # Find documents with very short abstracts 162 | short_docs_ids = [ 163 | id_ 164 | for id_, abstract in s.query(Paper.id, Paper.abstract) 165 | if len(abstract) < self.short_doc_len 166 | ] 167 | # Filter documents with very short abstracts 168 | vectors, ids = [], [] 169 | for vector, id_ in vectors_and_ids: 170 | if id_ not in short_docs_ids: 171 | vectors.append(vector) 172 | ids.append(id_) 173 | else: 174 | # Load vectors 175 | vectors, ids = zip(*vectors_and_ids) 176 | 177 | logging.info( 178 | f"UMAP hyperparameters: n_neighbors:{self.n_neighbors}, min_dist:{self.min_dist}, metric:{self.metric}" 179 | ) 180 | 181 | # Reduce dimensionality to 3D with umap 182 | reducer, embeddings_3d = umap_embeddings( 183 | vectors, self.n_neighbors, self.min_dist, self.n_components + 1, self.metric 184 | ) 185 | 186 | logging.info(f"UMAP embeddings: {embeddings_3d.shape}") 187 | 188 | # Construct DB insertions 189 | doc_vectors = [ 190 | { 191 | "id": id_, 192 | "vector_3d": 
embed_3d.tolist(), 193 | "citations": paper_citations[id_], 194 | } 195 | for embed_3d, id_ in zip(embeddings_3d, ids) 196 | ] 197 | logging.info(f"Constructed DocVector input") 198 | 199 | # Store UMAP on S3 200 | store_on_s3(reducer, self.s3_bucket, "umap_model") 201 | 202 | # Store document vectors in PostgreSQL 203 | s.bulk_insert_mappings(DocVector, doc_vectors) 204 | s.commit() 205 | logging.info("Commited to DB!") 206 | -------------------------------------------------------------------------------- /orion/core/operators/mag_collect_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | MagCollectionOperator: Queries Microsoft Academic Knowledge API and stores the responses on S3. 3 | The pickle file is a list of JSONs where every JSON object is the API response corresponsing to a paper. 4 | Every pickle contains a maximum of 1,000 objects (that's the maximum number of 5 | papers we can retrieve from MAG with a query). 6 | 7 | Example API response: 8 | {'logprob': -24.006, 9 | 'prob': 3.75255e-11, 10 | 'Id': 2904236373, 11 | 'Ti': 'conspiracy ideation and fake science news', 12 | 'Pt': '0', 13 | 'Y': 2018, 14 | 'D': '2018-01-18', 15 | 'CC': 0, 16 | 'RId': [2067319876, 2130121899], 17 | 'PB': 'OSF', 18 | 'BT': 'a', 19 | 'AA': [{'DAuN': 'Asheley R. Landrum', 20 | 'AuId': 2226866834, 21 | 'AfId': None, 22 | 'S': 1}, 23 | {'DAuN': 'Alex Olshansky', 'AuId': 2883323127, 'AfId': None, 'S': 2}], 24 | 'F': [{'DFN': 'Science communication', 25 | 'FN': 'science communication', 26 | 'FId': 472806}, 27 | {'DFN': 'Public relations', 'FN': 'public relations', 'FId': 39549134}, 28 | {'DFN': 'Public awareness of science', 29 | 'FN': 'public awareness of science', 30 | 'FId': 176049440}, 31 | {'DFN': 'Political science', 'FN': 'political science', 'FId': 17744445}, 32 | {'DFN': 'Misinformation', 'FN': 'misinformation', 'FId': 2776990098}, 33 | {'DFN': 'Ideation', 'FN': 'ideation', 'FId': 170477896}, 34 | {'DFN': 'Fake news', 'FN': 'fake news', 'FId': 2779756789}]} 35 | 36 | MagFosCollectionOperator: Fetches Fields of Study IDs from PostgreSQL and collects their level in the hierarchy, 37 | child and parent nodes from Microsoft Academic Graph. 
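(Illustrative sketch) MagCollectionOperator splits the period [mag_start_date, mag_end_date]
into (number of years x intervals_in_a_year) windows with date_range() and builds one
composite expression per window, which is then paged through in batches of up to 1,000
results, roughly:

    for window in toolz.sliding_window(2, list(date_range(mag_start_date, mag_end_date, total_intervals))):
        expression = build_composite_expr(query_values, entity_name, window)
        # query_mag_api(...) is then called with increasing offsets until the
        # response contains no entities.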
38 | 39 | """ 40 | import logging 41 | from datetime import datetime 42 | from sqlalchemy import create_engine 43 | from sqlalchemy.sql import exists 44 | from sqlalchemy.orm import sessionmaker 45 | from airflow.models import BaseOperator 46 | from airflow.utils.decorators import apply_defaults 47 | from orion.core.orms.mag_orm import FieldOfStudy, FosHierarchy, FosMetadata 48 | from orion.packages.mag.query_mag_api import ( 49 | query_mag_api, 50 | query_fields_of_study, 51 | build_composite_expr, 52 | ) 53 | from orion.packages.utils.s3_utils import store_on_s3 54 | from orion.packages.utils.utils import date_range, str2datetime 55 | import toolz 56 | 57 | 58 | class MagCollectionOperator(BaseOperator): 59 | """Queries MAG API.""" 60 | 61 | # template_fields = [''] 62 | 63 | @apply_defaults 64 | def __init__( 65 | self, 66 | subscription_key, 67 | output_bucket, 68 | query_values, 69 | entity_name, 70 | metadata, 71 | with_doi, 72 | mag_start_date, 73 | mag_end_date, 74 | intervals_in_a_year, 75 | *args, 76 | **kwargs, 77 | ): 78 | super().__init__(**kwargs) 79 | self.metadata = metadata 80 | self.query_values = query_values 81 | self.entity_name = entity_name 82 | self.subscription_key = subscription_key 83 | self.output_bucket = output_bucket 84 | self.with_doi = with_doi 85 | self.mag_start_date = mag_start_date 86 | self.mag_end_date = mag_end_date 87 | self.intervals_in_a_year = intervals_in_a_year 88 | 89 | def execute(self, context): 90 | # Convert strings to datetime objects 91 | self.mag_start_date = str2datetime(self.mag_start_date) 92 | self.mag_end_date = str2datetime(self.mag_end_date) 93 | 94 | # Number of time intervals for the data collection 95 | total_intervals = ( 96 | abs(self.mag_start_date.year - self.mag_end_date.year) + 1 97 | ) * self.intervals_in_a_year 98 | 99 | i = 0 100 | query_count = 1000 101 | for date in toolz.sliding_window( 102 | 2, list(date_range(self.mag_start_date, self.mag_end_date, total_intervals)) 103 | ): 104 | logging.info(f"Date interval: {date}") 105 | expression = build_composite_expr(self.query_values, self.entity_name, date) 106 | logging.info(f"{expression}") 107 | 108 | has_content = True 109 | # i = 1 110 | offset = 0 111 | # Request the API as long as we receive non-empty responses 112 | while has_content: 113 | logging.info(f"Query {i} - Offset {offset}...") 114 | 115 | data = query_mag_api( 116 | expression, 117 | self.metadata, 118 | self.subscription_key, 119 | query_count=query_count, 120 | offset=offset, 121 | ) 122 | 123 | if self.with_doi: 124 | # Keep only papers with a DOI 125 | results = [ 126 | ents for ents in data["entities"] if "DOI" in ents.keys() 127 | ] 128 | else: 129 | results = [ents for ents in data["entities"]] 130 | 131 | filename = "-".join([self.output_bucket, str(i),]) 132 | logging.info(f"File on s3: {filename}") 133 | 134 | store_on_s3(results, self.output_bucket, filename) 135 | logging.info(f"Number of stored results from query {i}: {len(results)}") 136 | 137 | i += 1 138 | offset += query_count 139 | 140 | if len(results) == 0: 141 | has_content = False 142 | 143 | 144 | class MagFosCollectionOperator(BaseOperator): 145 | """Queries MAG API with Fields of Study to collect their level 146 | in hierarchy, child and parent nodes.""" 147 | 148 | @apply_defaults 149 | def __init__(self, db_config, subscription_key, *args, **kwargs): 150 | super().__init__(**kwargs) 151 | self.db_config = db_config 152 | self.subscription_key = subscription_key 153 | 154 | def execute(self, context): 155 | # Connect to 
PostgreSQL DB 156 | engine = create_engine(self.db_config) 157 | Session = sessionmaker(bind=engine) 158 | s = Session() 159 | 160 | # Fetch FoS IDs 161 | all_fos_ids = set([id_[0] for id_ in s.query(FieldOfStudy.id)]) 162 | # Keep the FoS IDs that haven't been collected yet 163 | fields_of_study_ids = [ 164 | id_[0] 165 | for id_ in s.query(FieldOfStudy.id).filter( 166 | ~exists().where(FieldOfStudy.id == FosMetadata.id) 167 | ) 168 | ] 169 | logging.info(f"Fields of study left: {len(fields_of_study_ids)}") 170 | 171 | # Collect FoS metadata 172 | fos = query_fields_of_study(self.subscription_key, ids=fields_of_study_ids) 173 | 174 | # Parse api response 175 | for response in fos: 176 | s.add( 177 | FosMetadata(id=response["id"], level=response["level"], frequency=None) 178 | ) 179 | 180 | # Keep only the child and parent IDs that exist in our DB 181 | if "child_ids" in response.keys(): 182 | unique_child_ids = list(set(response["child_ids"]) & all_fos_ids) 183 | else: 184 | unique_child_ids = None 185 | 186 | if "parent_ids" in response.keys(): 187 | unique_parent_ids = list(set(response["parent_ids"]) & all_fos_ids) 188 | else: 189 | unique_parent_ids = None 190 | 191 | s.add( 192 | FosHierarchy( 193 | id=response["id"], 194 | child_id=unique_child_ids, 195 | parent_id=unique_parent_ids, 196 | ) 197 | ) 198 | 199 | # Commit all additions 200 | s.commit() 201 | -------------------------------------------------------------------------------- /orion/packages/mag/query_mag_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the work we have done at Nesta. You can find the original here: https://github.com/nestauk/nesta/blob/dev/nesta/packages/mag/query_mag_api.py 3 | """ 4 | import logging 5 | import requests 6 | from collections import defaultdict 7 | from retrying import retry 8 | 9 | ENDPOINT = "https://api.labs.cognitive.microsoft.com/academic/v1.0/evaluate" 10 | 11 | 12 | def build_expr(query_items, entity_name, max_length=16000): 13 | """Builds and yields OR expressions for MAG from a list of items. Strings and 14 | integer items are formatted quoted and unquoted respectively, as per the MAG query 15 | specification. 16 | 17 | The maximum accepted query length for the api appears to be around 16,000 characters. 18 | 19 | Args: 20 | query_items (:obj:`list`): All items to be queried. 21 | entity_name (:obj:`str`): The mag entity to be queried ie 'Ti' or 'Id'. 22 | max_length (:obj:`int`): Length of the expression which should not be exceeded. Yields 23 | occur at or below this expression length. 24 | 25 | Returns: 26 | (:obj:`str`): Expression in the format expr=OR(entity_name=item1, entity_name=item2...). 
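    Example (illustrative; the IDs are taken from the sample response in
    orion/core/operators/mag_collect_task.py):

        >>> list(build_expr([2904236373, 2067319876], "Id"))
        ['expr=OR(Id=2904236373,Id=2067319876)']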
27 | 28 | """ 29 | expr = [] 30 | length = 0 31 | query_prefix_format = "expr=OR({})" 32 | 33 | for item in query_items: 34 | if type(item) == str: 35 | formatted_item = f"{entity_name}='{item}'" 36 | elif type(item) == int: 37 | formatted_item = f"{entity_name}={item}" 38 | length = ( 39 | sum(len(e) + 1 for e in expr) 40 | + len(formatted_item) 41 | + len(query_prefix_format) 42 | ) 43 | if length >= max_length: 44 | yield query_prefix_format.format(",".join(expr)) 45 | expr.clear() 46 | expr.append(formatted_item) 47 | 48 | # pick up any remainder below max_length 49 | if len(expr) > 0: 50 | yield query_prefix_format.format(",".join(expr)) 51 | 52 | 53 | @retry(stop_max_attempt_number=10) 54 | def query_mag_api(expr, fields, subscription_key, query_count=1000, offset=0): 55 | """Posts a query to the Microsoft Academic Graph Evaluate API. 56 | 57 | Args: 58 | expr (:obj:`str`): Expression as built by build_expr. 59 | fields: (:obj:`list` of `str`): Codes of fields to return, as per mag documentation. 60 | query_count: (:obj:`int`): Number of items to return. 61 | offset (:obj:`int`): Offset in the results if paging through them. 62 | 63 | Returns: 64 | (:obj:`dict`): JSON response from the api containing 'expr' (the original expression) 65 | and 'entities' (the results) keys. 66 | If there are no results 'entities' is an empty list. 67 | 68 | """ 69 | headers = { 70 | "Ocp-Apim-Subscription-Key": subscription_key, 71 | "Content-Type": "application/x-www-form-urlencoded", 72 | } 73 | query = f"{expr}&count={query_count}&offset={offset}&attributes={','.join(fields)}" 74 | 75 | r = requests.post(ENDPOINT, data=query.encode("utf-8"), headers=headers) 76 | r.raise_for_status() 77 | 78 | return r.json() 79 | 80 | 81 | def dedupe_entities(entities): 82 | """Finds the highest probability match for each title in returned entities from MAG. 83 | 84 | Args: 85 | entities (:obj:`list` of `dict`): Entities from the MAG api. 86 | 87 | Returns: 88 | (set): IDs of entities with the highest probability score, one for each title. 89 | 90 | """ 91 | titles = defaultdict(dict) 92 | for row in entities: 93 | titles[row["Ti"]].update({row["Id"]: row["logprob"]}) 94 | 95 | deduped_mag_ids = set() 96 | for title in titles.values(): 97 | # find highest probability match for each title 98 | deduped_mag_ids.add(sorted(title, key=title.get, reverse=True)[0]) 99 | 100 | return deduped_mag_ids 101 | 102 | 103 | def query_fields_of_study( 104 | subscription_key, 105 | ids=None, 106 | levels=None, 107 | fields=["Id", "DFN", "FL", "FP.FId", "FC.FId"], 108 | # id, display_name, level, parent_ids, children_ids 109 | query_count=1000, 110 | results_limit=None, 111 | ): 112 | """Queries the MAG for fields of study. Expect >650k results for all levels. 113 | 114 | Args: 115 | subscription_key (str): MAG api subscription key 116 | ids: (:obj:`list` of `int`): field of study ids to query 117 | levels (:obj:`list` of `int`): levels to extract. 
0 is highest, 5 is lowest 118 | fields (:obj:`list` of `str`): codes of fields to return, as per mag documentation 119 | query_count (int): number of items to return from each query 120 | results_limit (int): break and return as close to this number of results as the 121 | offset and query_count allow (for testing) 122 | 123 | Returns: 124 | (:obj:`list` of `dict`): processed results from the api query 125 | 126 | """ 127 | if ids is not None and levels is None: 128 | expr_args = (ids, "Id") 129 | elif levels is not None and ids is None: 130 | expr_args = (levels, "FL") 131 | else: 132 | raise TypeError("Field of study ids OR levels should be supplied") 133 | 134 | field_mapping = { 135 | "Id": "id", 136 | "DFN": "name", 137 | "FL": "level", 138 | "FP": "parent_ids", 139 | "FC": "child_ids", 140 | } 141 | fields_to_drop = ["logprob", "prob"] 142 | fields_to_compact = ["parent_ids", "child_ids"] 143 | 144 | for expr in build_expr(*expr_args): 145 | count = 1000 146 | offset = 0 147 | while True: 148 | fos_data = query_mag_api( 149 | expr, 150 | fields, 151 | subscription_key=subscription_key, 152 | query_count=count, 153 | offset=offset, 154 | ) 155 | if fos_data["entities"] == []: 156 | logging.info("Empty entities returned, no more data") 157 | break 158 | 159 | # clean up and formatting 160 | for row in fos_data["entities"]: 161 | for f in fields_to_drop: 162 | del row[f] 163 | 164 | for code, description in field_mapping.items(): 165 | try: 166 | row[description] = row.pop(code) 167 | except KeyError: 168 | pass 169 | 170 | for field in fields_to_compact: 171 | try: 172 | row[field] = [ids["FId"] for ids in row[field]] 173 | except KeyError: 174 | # no parents and/or children 175 | pass 176 | 177 | logging.info(f"new fos: {row}") 178 | yield row 179 | 180 | offset += len(fos_data["entities"]) 181 | logging.info(offset) 182 | 183 | if results_limit is not None and offset >= results_limit: 184 | break 185 | 186 | 187 | def build_composite_expr(query_values, entity_name, date): 188 | """Builds a composite expression with ANDs in OR to be used as MAG query. 189 | 190 | Args: 191 | query_values (:obj:`list` of str): Phrases to query MAG with. 192 | entity_name (str): MAG attribute that will be used in query. 193 | date (:obj:`tuple` of `str`): Time period of the data collection. 194 | 195 | Returns: 196 | (str) MAG expression. 
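    Example (illustrative; 'F.FN', the normalised field-of-study name in MAG, is an
    assumed entity_name, and the query value and dates are hypothetical):

        >>> build_composite_expr(["bioinformatics"], "F.FN", ("2019-01-01", "2019-03-01"))
        "expr=OR(And(Composite(F.FN='bioinformatics'), D=['2019-01-01', '2019-03-01']))"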
197 | 198 | """ 199 | query_prefix_format = "expr=OR({})" 200 | and_queries = [ 201 | "".join( 202 | [ 203 | f"And(Composite({entity_name}='{query_value}'), D=['{date[0]}', '{date[1]}'])" 204 | ] 205 | ) 206 | for query_value in query_values 207 | ] 208 | return query_prefix_format.format(", ".join(and_queries)) 209 | -------------------------------------------------------------------------------- /orion/packages/utils/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import numpy as np 4 | from collections import Counter 5 | from datetime import datetime 6 | 7 | from orion.packages.utils.utils import flatten_lists 8 | from orion.packages.utils.utils import unique_dicts 9 | from orion.packages.utils.utils import unique_dicts_by_value 10 | from orion.packages.utils.utils import dict2psql_format 11 | from orion.packages.utils.utils import inverted2abstract 12 | from orion.packages.utils.utils import cooccurrence_graph 13 | from orion.packages.utils.utils import get_all_children 14 | from orion.packages.utils.utils import date_range 15 | from orion.packages.utils.utils import str2datetime 16 | 17 | example_list_dict = [ 18 | {"DFN": "Biology", "FId": 86803240}, 19 | {"DFN": "Biofilm", "FId": 58123911}, 20 | {"DFN": "Bacterial growth", "FId": 17741926}, 21 | {"DFN": "Bacteria", "FId": 523546767}, 22 | {"DFN": "Agar plate", "FId": 62643968}, 23 | {"DFN": "Agar", "FId": 2778660310}, 24 | {"DFN": "Agar", "FId": 2778660310}, 25 | {"DFN": "Agar foo bar", "FId": 2778660310}, 26 | ] 27 | 28 | 29 | def test_flatten_lists(): 30 | nested_list = [["a"], ["b"], ["c"]] 31 | result = flatten_lists(nested_list) 32 | assert result == ["a", "b", "c"] 33 | 34 | 35 | def test_unique_dicts(): 36 | result = unique_dicts(example_list_dict) 37 | expected_result = [ 38 | {"DFN": "Biology", "FId": 86803240}, 39 | {"DFN": "Biofilm", "FId": 58123911}, 40 | {"DFN": "Bacterial growth", "FId": 17741926}, 41 | {"DFN": "Bacteria", "FId": 523546767}, 42 | {"DFN": "Agar plate", "FId": 62643968}, 43 | {"DFN": "Agar", "FId": 2778660310}, 44 | {"DFN": "Agar foo bar", "FId": 2778660310}, 45 | ] 46 | assert sorted([d["FId"] for d in result]) == sorted( 47 | [d["FId"] for d in expected_result] 48 | ) 49 | 50 | 51 | def test_unique_dicts_by_value(): 52 | result = unique_dicts_by_value(example_list_dict, "FId") 53 | expected_result = [ 54 | {"DFN": "Biology", "FId": 86803240}, 55 | {"DFN": "Biofilm", "FId": 58123911}, 56 | {"DFN": "Bacterial growth", "FId": 17741926}, 57 | {"DFN": "Bacteria", "FId": 523546767}, 58 | {"DFN": "Agar plate", "FId": 62643968}, 59 | {"DFN": "Agar foo bar", "FId": 2778660310}, 60 | ] 61 | 62 | assert result == expected_result 63 | 64 | 65 | def test_dict2postgresql_format(): 66 | data = {} 67 | data[123] = pd.Series(data=[1, 2, 3], index=[["a", "b", "c"], ["v", "b", "a"]]) 68 | 69 | expected_result = [ 70 | {"entity": "a", "year": "v", "rca_sum": 1, "field_of_study_id": 123}, 71 | {"entity": "b", "year": "b", "rca_sum": 2, "field_of_study_id": 123}, 72 | {"entity": "c", "year": "a", "rca_sum": 3, "field_of_study_id": 123}, 73 | ] 74 | 75 | result = dict2psql_format(data) 76 | 77 | assert result == expected_result 78 | 79 | 80 | def test_inverted_abstract(): 81 | data = { 82 | "IndexLength": 111, 83 | "InvertedIndex": { 84 | "In": [0], 85 | "comparative": [1], 86 | "high-throughput": [2], 87 | "sequencing": [3], 88 | "assays,": [4], 89 | "a": [5, 44, 51, 78], 90 | "fundamental": [6], 91 | "task": [7], 92 | "is": 
[8, 105], 93 | "the": [9, 39, 74, 84, 88], 94 | "analysis": [10, 55, 81], 95 | "of": [11, 25, 41, 56, 73, 91], 96 | "count": [12, 57], 97 | "data,": [13, 22], 98 | "such": [14, 98], 99 | "as": [15, 99, 107], 100 | "read": [16], 101 | "counts": [17], 102 | "per": [18], 103 | "gene": [19, 100], 104 | "in": [20], 105 | "RNA-Seq": [21], 106 | "for": [23, 53, 63], 107 | "evidence": [24], 108 | "systematic": [26], 109 | "changes": [27, 67], 110 | "across": [28], 111 | "experimental": [29], 112 | "conditions.": [30], 113 | "Small": [31], 114 | "replicate": [32], 115 | "numbers,": [33], 116 | "discreteness,": [34], 117 | "large": [35], 118 | "dynamic": [36], 119 | "range": [37], 120 | "and": [38, 65, 71, 94, 102], 121 | "presence": [40, 90], 122 | "outliers": [42], 123 | "require": [43], 124 | "suitable": [45], 125 | "statistical": [46], 126 | "approach.": [47], 127 | "We": [48], 128 | "present": [49], 129 | "DESeq2,": [50], 130 | "method": [52], 131 | "differential": [54, 92], 132 | "data.": [58], 133 | "DESeq2": [59, 104], 134 | "uses": [60], 135 | "shrinkage": [61], 136 | "estimation": [62], 137 | "dispersions": [64], 138 | "fold": [66], 139 | "to": [68], 140 | "improve": [69], 141 | "stability": [70], 142 | "interpretability": [72], 143 | "estimates.": [75], 144 | "This": [76], 145 | "enables": [77], 146 | "more": [79], 147 | "quantitative": [80], 148 | "focused": [82], 149 | "on": [83], 150 | "strength": [85], 151 | "rather": [86], 152 | "than": [87], 153 | "mere": [89], 154 | "expression": [93], 155 | "facilitates": [95], 156 | "downstream": [96], 157 | "tasks": [97], 158 | "ranking": [101], 159 | "visualization.": [103], 160 | "available": [106], 161 | "an": [108], 162 | "R/Bioconductor": [109], 163 | "package.": [110], 164 | }, 165 | } 166 | 167 | expected_result = "In comparative high-throughput sequencing assays, a fundamental task is the analysis of count data, such as read counts per gene in RNA-Seq data, for evidence of systematic changes across experimental conditions. Small replicate numbers, discreteness, large dynamic range and the presence of outliers require a suitable statistical approach. We present DESeq2, a method for differential analysis of count data. DESeq2 uses shrinkage estimation for dispersions and fold changes to improve stability and interpretability of the estimates. This enables a more quantitative analysis focused on the strength rather than the mere presence of differential expression and facilitates downstream tasks such as gene ranking and visualization. DESeq2 is available as an R/Bioconductor package." 
168 | 169 | result = inverted2abstract(data) 170 | 171 | assert result == expected_result 172 | 173 | 174 | def test_inverted_abstract_empty_field(): 175 | data = None 176 | result = inverted2abstract(data) 177 | 178 | assert np.isnan(result) 179 | 180 | 181 | def test_cooccurrence_graph(): 182 | data = [["a", "b"], ["a", "b", "c"]] 183 | 184 | expected_result = Counter({("a", "b"): 2, ("a", "c"): 1, ("b", "c"): 1}) 185 | result = cooccurrence_graph(data) 186 | 187 | assert result == expected_result 188 | 189 | 190 | def test_get_all_children(): 191 | data = pd.DataFrame( 192 | { 193 | "id": [165864922, 114009990, 178809742, 2909274368, 196033, 190796033], 194 | "child_id": [ 195 | [190796033, 114009990, 178809742], 196 | [2909274368, 196033], 197 | [], 198 | [], 199 | [114009990, 2909274368], 200 | [190796033], 201 | ], 202 | } 203 | ) 204 | 205 | expected_result = [2909274368, 190796033, 196033, 114009990, 178809742, 165864922] 206 | result = list(set(get_all_children(data, 165864922))) 207 | 208 | assert result == expected_result 209 | 210 | 211 | def test_date_range(): 212 | start = datetime.strptime("2000-01-01", "%Y-%m-%d") 213 | end = datetime.strptime("2000-12-31", "%Y-%m-%d") 214 | result = list(date_range(start, end, 6)) 215 | expected_result = [ 216 | "2000-01-01", 217 | "2000-03-01", 218 | "2000-05-01", 219 | "2000-07-01", 220 | "2000-08-31", 221 | "2000-10-31", 222 | "2000-12-31", 223 | ] 224 | 225 | assert result == expected_result 226 | 227 | 228 | def test_str2datetime(): 229 | result = str2datetime("2000-12-31") 230 | expected_result = datetime.strptime("2000-12-31", "%Y-%m-%d") 231 | 232 | assert result == expected_result 233 | -------------------------------------------------------------------------------- /orion/core/operators/draw_collaboration_graph_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | CountryCollaborationOperator: Draws a collaboration graph between countries based on 3 | the author affiliations. Papers are filtered by their publication year. 4 | 5 | CountrySimilarityOperator: Finds the similarity between countries based on their abstracts. 6 | It averages the abstract vectors of a country to create a country vector. 7 | Uses the text vectors that were calculated from the text2vector task. It filters papers 8 | by publication year and users can choose the number of similar countries to return. 
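For example (illustrative), two papers whose affiliations resolve to the country sets
{"UK", "US"} and {"DE", "UK", "US"} yield the co-occurrence counts
Counter({("UK", "US"): 2, ("DE", "UK"): 1, ("DE", "US"): 1}), and each pair is stored
as a (country_a, country_b, weight) row per year in the country_collaboration table.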
9 | 10 | """ 11 | import logging 12 | import pandas as pd 13 | import numpy as np 14 | from sqlalchemy import create_engine 15 | from sqlalchemy.orm import sessionmaker 16 | from airflow.models import BaseOperator 17 | from airflow.utils.decorators import apply_defaults 18 | from orion.packages.utils.utils import cooccurrence_graph 19 | from orion.core.orms.mag_orm import ( 20 | AuthorAffiliation, 21 | AffiliationLocation, 22 | CountryCollaboration, 23 | Paper, 24 | FilteredFos, 25 | CountrySimilarity, 26 | PaperFieldsOfStudy, 27 | HighDimDocVector, 28 | ) 29 | from orion.packages.projection.faiss_index import faiss_index 30 | 31 | 32 | class CountryCollaborationOperator(BaseOperator): 33 | """Create a cooccurrence graph of country-level collaboration.""" 34 | 35 | @apply_defaults 36 | def __init__(self, db_config, year, *args, **kwargs): 37 | super().__init__(**kwargs) 38 | self.db_config = db_config 39 | self.year = year 40 | 41 | def execute(self, context): 42 | # Connect to postgresql db 43 | engine = create_engine(self.db_config) 44 | Session = sessionmaker(engine) 45 | s = Session() 46 | 47 | # Load all the tables needed for the collaboration graph. 48 | aff_location = pd.read_sql(s.query(AffiliationLocation).statement, s.bind) 49 | author_aff = pd.read_sql(s.query(AuthorAffiliation).statement, s.bind) 50 | papers = pd.read_sql(s.query(Paper).statement, s.bind) 51 | 52 | # Merge tables 53 | df = ( 54 | aff_location[["affiliation_id", "country"]] 55 | .merge(author_aff, left_on="affiliation_id", right_on="affiliation_id") 56 | .merge(papers[["id", "year"]], left_on="paper_id", right_on="id") 57 | ) 58 | 59 | # Group countries by paper, remove duplicates and missing entries. 60 | for year in [year for year in sorted(df.year.unique()) if year > self.year]: 61 | logging.info(f"Collaboration network for year: {year}") 62 | grouped_df = ( 63 | df[(df.country != "") & (df.year == year)][["country", "paper_id"]] 64 | .dropna() 65 | .groupby("paper_id")["country"] 66 | .apply(set) 67 | ) 68 | logging.info(f"Grouped DF shape: {grouped_df.shape}") 69 | graph = cooccurrence_graph(grouped_df) 70 | for k, v in graph.items(): 71 | s.add( 72 | CountryCollaboration( 73 | **{ 74 | "country_a": k[0], 75 | "country_b": k[1], 76 | "weight": v, 77 | "year": year, 78 | } 79 | ) 80 | ) 81 | s.commit() 82 | 83 | logging.info("Done :)") 84 | 85 | 86 | class CountrySimilarityOperator(BaseOperator): 87 | """Find the semantic similarity between abstracts.""" 88 | 89 | @apply_defaults 90 | def __init__(self, db_config, year, k=5, thresh=2, *args, **kwargs): 91 | super().__init__(**kwargs) 92 | self.db_config = db_config 93 | self.year = year 94 | self.k = k 95 | self.thresh = thresh 96 | 97 | def execute(self, context): 98 | # Connect to postgresql 99 | engine = create_engine(self.db_config) 100 | Session = sessionmaker(bind=engine) 101 | s = Session() 102 | 103 | # Drop and recreate the country similarity table to update the metric 104 | CountrySimilarity.__table__.drop(engine, checkfirst=True) 105 | CountrySimilarity.__table__.create(engine, checkfirst=True) 106 | 107 | # Load all the tables needed for the metrics 108 | papers = pd.read_sql(s.query(Paper).statement, s.bind) 109 | aff_location = pd.read_sql(s.query(AffiliationLocation).statement, s.bind) 110 | author_aff = pd.read_sql(s.query(AuthorAffiliation).statement, s.bind) 111 | paper_fos = pd.read_sql(s.query(PaperFieldsOfStudy).statement, s.bind) 112 | filtered_fos = pd.read_sql(s.query(FilteredFos).statement, s.bind) 113 | vectors = pd.read_sql( 114 | 
s.query(HighDimDocVector.vector, HighDimDocVector.id).statement, s.bind 115 | ) 116 | vectors["vector"] = vectors.vector.apply(np.array) 117 | 118 | # dict(topic id, all children) 119 | d = {} 120 | for _, row in filtered_fos.drop_duplicates("field_of_study_id").iterrows(): 121 | d[row["field_of_study_id"]] = row["all_children"] 122 | 123 | for parent, children in d.items(): 124 | logging.info(f"Parent ID: {parent} - Number of children: {len(children)}") 125 | # Merge tables for a particular "discipline" (level 1 FoS and its children) 126 | df = ( 127 | aff_location[aff_location.country != ""][["affiliation_id", "country"]] 128 | .merge(author_aff, left_on="affiliation_id", right_on="affiliation_id") 129 | .merge( 130 | papers[["id", "year", "citations"]], 131 | left_on="paper_id", 132 | right_on="id", 133 | ) 134 | .merge(vectors, left_on="paper_id", right_on="id") 135 | .merge( 136 | paper_fos[paper_fos["field_of_study_id"].isin(children)], 137 | left_on="paper_id", 138 | right_on="paper_id", 139 | )[ 140 | [ 141 | "affiliation_id", 142 | "field_of_study_id", 143 | "country", 144 | "paper_id", 145 | "citations", 146 | "year", 147 | "vector", 148 | ] 149 | ] 150 | ) 151 | # Filter country/year pairs based on paper frequency 152 | filter_ = ( 153 | df.drop_duplicates(["country", "paper_id", "year"]) 154 | .groupby(["year", "country"])["paper_id"] 155 | .count() 156 | ) 157 | logging.info(f"Remaining country/year pairs: {filter_.shape}") 158 | 159 | # Group and drop countries with less than N papers 160 | grouped = ( 161 | df.drop_duplicates(["country", "paper_id", "year"]) 162 | .groupby(["year", "country"])["vector"] 163 | .apply(lambda x: np.mean(x, axis=0)) 164 | .loc[filter_.where(filter_ > self.thresh).dropna().index] 165 | ) 166 | 167 | # Find similar countries on annual basis 168 | for year in set([tup[0] for tup in grouped.index if tup[0] > self.year]): 169 | v = np.array([v for v in grouped.loc[year]]).astype("float32") 170 | ids = range(len(grouped.loc[year].index)) 171 | logging.info(f"Vectors shape: {v.shape}") 172 | # Check that we have at least more than 5 countries in that year and topic 173 | if v.shape[0] > self.k: 174 | # Create FAISS index 175 | index = faiss_index(v, ids) 176 | # Find similar countries for each country 177 | for vector, country in zip(v, grouped.loc[year].index): 178 | D, I = index.search(np.array([vector]), self.k + 1) 179 | for i, (idx, similarity) in enumerate(zip(I[0][1:], D[0][1:])): 180 | 181 | s.add( 182 | CountrySimilarity( 183 | country_a=country, 184 | country_b=grouped.loc[year].index[idx], 185 | closeness=float(similarity), 186 | year=year, 187 | field_of_study_id=parent, 188 | ) 189 | ) 190 | s.commit() 191 | # logging.info(f"Stored in DB for {country} - {year}") 192 | else: 193 | continue 194 | -------------------------------------------------------------------------------- /orion/packages/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _rca_division(val1, val2, val3, val4): 5 | """Multi-step division.""" 6 | return (val1 / val2) / (val3 / val4) 7 | 8 | 9 | def calculate_rca_by_sum( 10 | data, entity_column, commodity, value, paper_thresh, year_thresh 11 | ): 12 | """Groups a dataframe by entity (country or institution) and 13 | calculates the Revealed Comparative Advantage (RCA) for each 14 | entity, based on a commodity. The value used for measurement is 15 | usually citations and it's done on an annual basis. 
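    Written out, the ratio that this function computes via :code:`_rca_division` is

    .. math::
        RCA_{e,f,t} = \frac{x_{e,f,t} / x_{e,t}}{x_{f,t} / x_{t}}

    where :math:`x_{e,f,t}` is the summed value (e.g. citations) of entity
    :math:`e` in field of study :math:`f` and year :math:`t`, and dropping an
    index denotes summing over it.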
16 | 17 | Args: 18 | data (:code:`pandas.DataFrame`): DataFrame with entity, year, field of study, paper ID and value columns. 19 | entity_column (str): Label of the column containing countries or institutions. 20 | commodity (str): The Field of Study to measure RCA for. 21 | value (str): Label of the column containing the values to use for measurement. 22 | paper_thresh (int): Calculate RCA only for entities with at least N papers on the commodity. 23 | year_thresh (str): Consider only years higher than the threshold. 24 | 25 | Returns: 26 | (:code:`pandas.DataFrame`): grouped dataframe with entity, year and calculated RCA. 27 | 28 | """ 29 | data = data[data.year > year_thresh] 30 | # Filter out entities with fewer than N papers on a topic 31 | entity_count_topic = ( 32 | data[(data.field_of_study_id == commodity)] 33 | .groupby([entity_column, "year"])["paper_id"] 34 | .count() 35 | ) 36 | idx = entity_count_topic.where(entity_count_topic >= paper_thresh).dropna().index 37 | 38 | entity_sum_topic = ( 39 | data[(data.field_of_study_id == commodity)] 40 | .groupby([entity_column, "year"])[value] 41 | .sum() 42 | .loc[idx] 43 | ) 44 | entity_sum_all = data.groupby([entity_column, "year"])[value].sum() 45 | entity_sum_all = entity_sum_all.where(entity_sum_all > paper_thresh) 46 | 47 | world_sum_topic = ( 48 | data[data.field_of_study_id == commodity].groupby("year")[value].sum() 49 | ) 50 | world_sum_all = data.groupby("year")[value].sum() 51 | 52 | rca = _rca_division( 53 | entity_sum_topic, entity_sum_all, world_sum_topic, world_sum_all 54 | ) 55 | 56 | return rca.dropna().clip(upper=6) 57 | 58 | 59 | def calculate_rca_by_count(data, entity_column, commodity, paper_thresh, year_thresh): 60 | """Groups a dataframe by entity (country or institution) and 61 | calculates the Revealed Comparative Advantage (RCA) for each 62 | entity, based on a commodity. The value used for measurement is publication 63 | volume and it's done on an annual basis. 64 | 65 | Args: 66 | data (:code:`pandas.DataFrame`): DataFrame with entity, year, field of study and paper ID columns. 67 | entity_column (str): Label of the column containing countries or institutions. 68 | commodity (str): The Field of Study to measure RCA for. 69 | paper_thresh (int): Calculate RCA only for entities with at least N papers. 70 | year_thresh (str): Consider only years higher than the threshold. 71 | 72 | Returns: 73 | (:code:`pandas.DataFrame`): grouped dataframe with entity, year and calculated RCA. 74 | 75 | """ 76 | data = data[data.year > year_thresh] 77 | entity_count_topic = ( 78 | data[(data.field_of_study_id == commodity)] 79 | .groupby([entity_column, "year"])["paper_id"] 80 | .count() 81 | ) 82 | entity_count_all = data.groupby([entity_column, "year"])["paper_id"].count() 83 | entity_count_all = entity_count_all.where(entity_count_all > paper_thresh) 84 | 85 | world_count_topic = ( 86 | data[data.field_of_study_id == commodity].groupby("year")["paper_id"].count() 87 | ) 88 | world_count_all = data.groupby("year")["paper_id"].count() 89 | 90 | rca = _rca_division( 91 | entity_count_topic, entity_count_all, world_count_topic, world_count_all 92 | ) 93 | 94 | return rca.dropna() 95 | 96 | 97 | def _validate_counts_vector(counts, suppress_cast=False): 98 | """Validates and converts input to an acceptable counts vector type. 99 | Note: may not always return a copy of `counts`! 100 | 101 | This is taken from scikit-bio.
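    Examples (illustrative doctest added for clarity; not in the scikit-bio original):

    >>> _validate_counts_vector([1, 0, 2])
    array([1, 0, 2])
    >>> _validate_counts_vector([1, -1])
    Traceback (most recent call last):
        ...
    ValueError: Counts vector cannot contain negative values.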
102 | 103 | """ 104 | counts = np.asarray(counts) 105 | 106 | if not suppress_cast: 107 | counts = counts.astype(int, casting="safe", copy=False) 108 | 109 | if counts.ndim != 1: 110 | raise ValueError("Only 1-D vectors are supported.") 111 | elif (counts < 0).any(): 112 | raise ValueError("Counts vector cannot contain negative values.") 113 | 114 | return counts 115 | 116 | 117 | def dominance(counts): 118 | """Calculates dominance. 119 | 120 | Dominance is defined as 121 | .. math:: 122 | \sum{p_i^2} 123 | where :math:`p_i` is the proportion of the entire community that OTU 124 | :math:`i` represents. 125 | Dominance can also be defined as 1 - Simpson's index. It ranges between 126 | 0 and 1. 127 | 128 | Args: 129 | counts (:obj:`numpy.array` of `int`): Vector of counts. 130 | 131 | Returns: 132 | (float) dominance score. 133 | 134 | Notes 135 | ----- 136 | The implementation here is based on the description given in [1]_. 137 | References 138 | ---------- 139 | .. [1] http://folk.uio.no/ohammer/past/diversity.html 140 | 141 | This is taken from scikit-bio. 142 | 143 | """ 144 | counts = _validate_counts_vector(counts) 145 | freqs = counts / counts.sum() 146 | return (freqs * freqs).sum() 147 | 148 | 149 | def shannon(counts, base=2): 150 | """Calculate Shannon entropy of counts, default in bits. 151 | Shannon-Wiener diversity index is defined as: 152 | .. math:: 153 | H = -\sum_{i=1}^s\left(p_i\log_2 p_i\right) 154 | where :math:`s` is the number of OTUs and :math:`p_i` is the proportion of 155 | the community represented by OTU :math:`i`. 156 | 157 | 158 | Args: 159 | counts (:obj:`numpy.array` of `int`): Vector of counts. 160 | base (int): Logarithm base to use in the calculations. 161 | 162 | Returns: 163 | (float) Shannon diversity index H. 164 | 165 | Notes 166 | ----- 167 | The implementation here is based on the description given in the SDR-IV 168 | online manual [1]_ except that the default logarithm base used here is 2 169 | instead of :math:`e`. 170 | References 171 | ---------- 172 | .. [1] http://www.pisces-conservation.com/sdrhelp/index.html 173 | 174 | This is taken from scikit-bio. 175 | 176 | """ 177 | counts = _validate_counts_vector(counts) 178 | freqs = counts / counts.sum() 179 | nonzero_freqs = freqs[freqs.nonzero()] 180 | return -(nonzero_freqs * np.log(nonzero_freqs)).sum() / np.log(base) 181 | 182 | 183 | def simpson(counts): 184 | """Calculate Simpson's index. 185 | Simpson's index is defined as ``1 - dominance``: 186 | .. math:: 187 | 1 - \sum{p_i^2} 188 | where :math:`p_i` is the proportion of the community represented by OTU 189 | :math:`i`. 190 | 191 | Args: 192 | counts (:obj:`numpy.array` of `int`): Vector of counts. 193 | 194 | Returns: 195 | (float) Simpson's index. 196 | 197 | Notes 198 | ----- 199 | The implementation here is ``1 - dominance`` as described in [1]_. Other 200 | references (such as [2]_) define Simpson's index as ``1 / dominance``. 201 | References 202 | ---------- 203 | .. [1] http://folk.uio.no/ohammer/past/diversity.html 204 | .. [2] http://www.pisces-conservation.com/sdrhelp/index.html 205 | 206 | This is taken from scikit-bio. 207 | 208 | """ 209 | counts = _validate_counts_vector(counts) 210 | return 1 - dominance(counts) 211 | 212 | 213 | def enspie(counts): 214 | """Calculate ENS_pie alpha diversity measure. 215 | ENS_pie is equivalent to ``1 / dominance``: 216 | .. 
math:: 217 | ENS_{pie} = \frac{1}{\sum_{i=1}^s{p_i^2}} 218 | where :math:`s` is the number of OTUs and :math:`p_i` is the proportion of 219 | the community represented by OTU :math:`i`. 220 | 221 | Args: 222 | counts (:obj:`numpy.array` of `int`): Vector of counts. 223 | 224 | Returns: 225 | (float) ENS_pie alpha diversity measure. 226 | 227 | Notes 228 | ----- 229 | ENS_pie is defined in [1]_. 230 | References 231 | ---------- 232 | .. [1] Chase and Knight (2013). "Scale-dependent effect sizes of ecological 233 | drivers on biodiversity: why standardised sampling is not enough". 234 | Ecology Letters, Volume 16, Issue Supplement s1, pgs 17-26. 235 | 236 | This is taken from scikit-bio. 237 | 238 | """ 239 | counts = _validate_counts_vector(counts) 240 | return 1 / dominance(counts) 241 | 242 | 243 | def observed_otus(counts): 244 | """Calculate the number of distinct OTUs. 245 | 246 | Args: 247 | counts (:obj:`numpy.array` of `int`): Vector of counts. 248 | 249 | Returns: 250 | (int) Distinct OTU count. 251 | 252 | """ 253 | counts = _validate_counts_vector(counts) 254 | return (counts != 0).sum() 255 | 256 | 257 | def simpson_e(counts): 258 | """Calculate Simpson's evenness measure E. 259 | Simpson's E is defined as 260 | .. math:: 261 | E=\frac{1 / D}{S_{obs}} 262 | where :math:`D` is dominance and :math:`S_{obs}` is the number of observed 263 | OTUs. 264 | 265 | Args: 266 | counts (:obj:`numpy.array` of `int`): Vector of counts. 267 | 268 | Returns: 269 | (float) Simpson's evenness measure E. 270 | 271 | Notes 272 | ----- 273 | The implementation here is based on the description given in [1]_. 274 | References 275 | ---------- 276 | .. [1] http://www.tiem.utk.edu/~gross/bioed/bealsmodules/simpsonDI.html 277 | 278 | This is taken from scikit-bio. 
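    A worked example (added for illustration): for the perfectly even counts
    vector ``[5, 5]`` the proportions are :math:`p = (0.5, 0.5)`, so dominance
    :math:`D = \sum p_i^2 = 0.5`, :math:`1/D = 2`, :math:`S_{obs} = 2`, and
    therefore :math:`E = (1/D) / S_{obs} = 1.0`.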
279 | 280 | """ 281 | counts = _validate_counts_vector(counts) 282 | return enspie(counts) / observed_otus(counts) 283 | -------------------------------------------------------------------------------- /orion/core/dags/orion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime, timedelta 3 | from airflow import DAG 4 | from airflow.operators.dummy_operator import DummyOperator 5 | from airflow.operators.python_operator import PythonOperator, BranchPythonOperator 6 | from orion.packages.utils.s3_utils import create_s3_bucket 7 | from orion.core.operators.mag_parse_task import MagParserOperator, FosFrequencyOperator 8 | from orion.core.operators.draw_collaboration_graph_task import ( 9 | CountryCollaborationOperator, 10 | CountrySimilarityOperator, 11 | ) 12 | from orion.core.operators.mag_geocode_task import GeocodingOperator 13 | from orion.core.operators.mag_collect_task import ( 14 | MagCollectionOperator, 15 | MagFosCollectionOperator, 16 | ) 17 | import orion 18 | from orion.core.operators.infer_gender_task import ( 19 | NamesBatchesOperator, 20 | GenderInferenceOperator, 21 | ) 22 | from orion.core.operators.calculate_metrics_task import ( 23 | RCAOperator, 24 | ResearchDiversityOperator, 25 | GenderDiversityOperator, 26 | ) 27 | from orion.core.operators.text2vec_task import Text2SentenceBertOperator 28 | from orion.core.operators.dim_reduction_task import ( 29 | DimReductionOperator, 30 | DimReductionFittedUmapOperator, 31 | umap_exists, 32 | ) 33 | from orion.core.operators.topic_filtering_task import ( 34 | FilterTopicsByDistributionOperator, 35 | FilteredTopicsMetadataOperator, 36 | ) 37 | from orion.core.operators.faiss_index_task import FaissIndexOperator 38 | from orion.core.operators.create_viz_tables_task import CreateVizTables, Pandas2Arrow 39 | from orion.core.operators.affiliation_type_task import AffiliationTypeOperator 40 | from orion.core.operators.collect_wb_indicators_task import WBIndicatorOperator 41 | from orion.core.operators.country_details_task import ( 42 | HomogeniseCountryNamesOperator, 43 | CountryDetailsOperator, 44 | ) 45 | from orion.core.operators.postgresql2es_task import Postgreqsl2ElasticSearchOperator 46 | from orion.packages.mag.create_tables import create_db_and_tables 47 | from orion.core.operators.open_access_journals_task import OpenAccessJournalOperator 48 | from dotenv import load_dotenv, find_dotenv 49 | import os 50 | 51 | load_dotenv(find_dotenv()) 52 | 53 | default_args = { 54 | "owner": "Kostas St", 55 | "start_date": datetime(2020, 2, 2), 56 | "depends_on_past": False, 57 | "retries": 0, 58 | "email": ["k.stathoylopoylos@gmail.com"], 59 | "email_on_failure": True, 60 | } 61 | 62 | DAG_ID = "orion" 63 | db_name = orion.config["data"]["db_name"] 64 | DB_CONFIG = os.getenv(db_name) 65 | MAG_API_KEY = os.getenv("mag_api_key") 66 | MAG_OUTPUT_BUCKET = orion.config["s3_buckets"]["mag"] 67 | mag_config = orion.config["data"]["mag"] 68 | query_values = mag_config["query_values"] 69 | entity_name = mag_config["entity_name"] 70 | metadata = mag_config["metadata"] 71 | with_doi = mag_config["with_doi"] 72 | mag_start_date = mag_config["mag_start_date"] 73 | mag_end_date = mag_config["mag_end_date"] 74 | intervals_in_a_year = mag_config["intervals_in_a_year"] 75 | 76 | # geocode_places 77 | google_key = os.getenv("google_api_key") 78 | 79 | # country_collaboration 80 | collab_year = orion.config["country_collaboration"]["year"] 81 | 82 | # batch_names 83 | BATCH_SIZE 
= orion.config["batch_size"] 84 | S3_BUCKET = orion.config["s3_buckets"]["gender"] 85 | PREFIX = orion.config["prefix"]["gender"] 86 | 87 | # gender_inference_N 88 | parallel_tasks = orion.config["parallel_tasks"] 89 | auth_token = os.getenv("gender_api") 90 | 91 | # text2vector 92 | text_vectors_bucket = orion.config["s3_buckets"]["text_vectors"] 93 | bert_model = orion.config["sentence_bert"]["bert_model"] 94 | bert_batch_size = orion.config["sentence_bert"]["batch_size"] 95 | 96 | # dim_reduction 97 | umap_config = orion.config["umap"] 98 | # umap hyperparameters 99 | n_neighbors = umap_config["n_neighbors"] 100 | n_components = umap_config["n_components"] 101 | metric = umap_config["metric"] 102 | min_dist = umap_config["min_dist"] 103 | exclude_docs = umap_config["exclude"] 104 | 105 | # topic_filtering 106 | topic_prefix = orion.config["prefix"]["topic"] 107 | topic_bucket = orion.config["s3_buckets"]["topic"] 108 | topic_config = orion.config["topic_filter"] 109 | levels = topic_config["levels"] 110 | percentiles = topic_config["percentiles"] 111 | 112 | # metrics 113 | thresh = orion.config["gender_diversity"]["threshold"] 114 | paper_thresh_low = orion.config["metrics"]["paper_count_low"] 115 | year_thresh = orion.config["metrics"]["year"] 116 | fos_thresh = orion.config["metrics"]["fos_count"] 117 | 118 | # wb indicators 119 | wb_country = orion.config["data"]["wb"]["country"] 120 | wb_end_year = orion.config["data"]["wb"]["end_year"] 121 | wb_indicators = orion.config["data"]["wb"]["indicators"] 122 | wb_table_names = orion.config["data"]["wb"]["table_names"] 123 | 124 | # PostgreSQL to Elasticsearch 125 | es_index = os.getenv("es_index") 126 | es_host = os.getenv("es_host") 127 | es_port = os.getenv("es_port") 128 | erase_es_index = orion.config["elasticsearch"]["erase_index"] 129 | aws_region = os.getenv("region") 130 | 131 | with DAG( 132 | dag_id=DAG_ID, 133 | default_args=default_args, 134 | catchup=False, 135 | schedule_interval="@monthly", 136 | ) as dag: 137 | 138 | dummy_task = DummyOperator(task_id="start") 139 | 140 | dummy_task_2 = DummyOperator(task_id="gender_agg") 141 | 142 | dummy_task_3 = DummyOperator(task_id="world_bank_indicators") 143 | 144 | dummy_task_4 = DummyOperator(task_id="create_s3_buckets") 145 | 146 | dummy_task_5 = DummyOperator(task_id="s3_buckets") 147 | 148 | create_tables = PythonOperator( 149 | task_id="create_tables", 150 | python_callable=create_db_and_tables, 151 | op_kwargs={"db": db_name}, 152 | ) 153 | 154 | create_buckets = [ 155 | PythonOperator( 156 | task_id=bucket, 157 | python_callable=create_s3_bucket, 158 | op_kwargs={"bucket": bucket}, 159 | ) 160 | for bucket in [MAG_OUTPUT_BUCKET, S3_BUCKET, topic_bucket, text_vectors_bucket] 161 | ] 162 | 163 | query_mag = MagCollectionOperator( 164 | task_id="query_mag", 165 | output_bucket=MAG_OUTPUT_BUCKET, 166 | subscription_key=MAG_API_KEY, 167 | query_values=query_values, 168 | entity_name=entity_name, 169 | metadata=metadata, 170 | with_doi=with_doi, 171 | mag_start_date=mag_start_date, 172 | mag_end_date=mag_end_date, 173 | intervals_in_a_year=intervals_in_a_year, 174 | ) 175 | 176 | parse_mag = MagParserOperator( 177 | task_id="parse_mag", s3_bucket=MAG_OUTPUT_BUCKET, db_config=DB_CONFIG 178 | ) 179 | 180 | geocode_places = GeocodingOperator( 181 | task_id="geocode_places", db_config=DB_CONFIG, subscription_key=google_key 182 | ) 183 | 184 | collect_fos = MagFosCollectionOperator( 185 | task_id="collect_fos_metadata", 186 | db_config=DB_CONFIG, 187 | subscription_key=MAG_API_KEY, 
188 | ) 189 | 190 | fos_frequency = FosFrequencyOperator(task_id="fos_frequency", db_config=DB_CONFIG) 191 | 192 | batch_names = NamesBatchesOperator( 193 | task_id="batch_names", 194 | db_config=DB_CONFIG, 195 | s3_bucket=S3_BUCKET, 196 | prefix=PREFIX, 197 | batch_size=BATCH_SIZE, 198 | ) 199 | 200 | batch_task_gender = [] 201 | for parallel_task in range(parallel_tasks): 202 | task_id = f"gender_inference_{parallel_task}" 203 | batch_task_gender.append( 204 | GenderInferenceOperator( 205 | task_id=task_id, 206 | db_config=DB_CONFIG, 207 | s3_bucket=S3_BUCKET, 208 | prefix=f"{PREFIX}_{parallel_task}", 209 | auth_token=auth_token, 210 | ) 211 | ) 212 | 213 | rca = RCAOperator( 214 | task_id="rca_measurement", 215 | db_config=DB_CONFIG, 216 | year_thresh=year_thresh, 217 | paper_thresh=paper_thresh_low, 218 | ) 219 | 220 | text2vector = Text2SentenceBertOperator( 221 | task_id="text2vector", 222 | db_config=DB_CONFIG, 223 | batch_size=bert_batch_size, 224 | bert_model=bert_model, 225 | ) 226 | 227 | dim_reduction = BranchPythonOperator( 228 | task_id="dim_reduction", 229 | python_callable=umap_exists, 230 | op_kwargs={"bucket": text_vectors_bucket}, 231 | trigger_rule="all_done", 232 | ) 233 | 234 | fit_umap = DimReductionOperator( 235 | task_id="fit_umap", 236 | db_config=DB_CONFIG, 237 | s3_bucket=text_vectors_bucket, 238 | n_neighbors=n_neighbors, 239 | min_dist=min_dist, 240 | n_components=n_components, 241 | metric=metric, 242 | exclude_docs=exclude_docs, 243 | ) 244 | 245 | call_umap = DimReductionFittedUmapOperator( 246 | task_id="call_umap", 247 | db_config=DB_CONFIG, 248 | s3_bucket=text_vectors_bucket, 249 | exclude_docs=exclude_docs, 250 | ) 251 | 252 | country_collaboration_graph = CountryCollaborationOperator( 253 | task_id="country_collaboration", db_config=DB_CONFIG, year=collab_year 254 | ) 255 | 256 | country_similarity = CountrySimilarityOperator( 257 | task_id="country_similarity", db_config=DB_CONFIG, year=collab_year 258 | ) 259 | 260 | topic_filtering = FilterTopicsByDistributionOperator( 261 | task_id="filter_topics", 262 | db_config=DB_CONFIG, 263 | s3_bucket=topic_bucket, 264 | prefix=topic_prefix, 265 | levels=levels, 266 | percentiles=percentiles, 267 | ) 268 | 269 | filtered_topic_metadata = FilteredTopicsMetadataOperator( 270 | task_id="topic_metadata", 271 | db_config=DB_CONFIG, 272 | s3_bucket=topic_bucket, 273 | prefix=topic_prefix, 274 | ) 275 | 276 | research_diversity = ResearchDiversityOperator( 277 | task_id="research_diversity", 278 | db_config=DB_CONFIG, 279 | fos_thresh=fos_thresh, 280 | year_thresh=year_thresh, 281 | ) 282 | 283 | gender_diversity = GenderDiversityOperator( 284 | task_id="gender_diversity", 285 | db_config=DB_CONFIG, 286 | paper_thresh=paper_thresh_low, 287 | thresh=thresh, 288 | ) 289 | 290 | faiss_index = FaissIndexOperator( 291 | task_id="faiss_index", bucket=text_vectors_bucket, db_config=DB_CONFIG 292 | ) 293 | 294 | viz_tables = CreateVizTables(task_id="viz_tables", db_config=DB_CONFIG) 295 | 296 | aff_types = AffiliationTypeOperator(task_id="affiliation_type", db_config=DB_CONFIG) 297 | 298 | batch_task_wb = [] 299 | for wb_indicator, wb_table_name in zip(wb_indicators, wb_table_names): 300 | task_id = f"{wb_table_name}" 301 | batch_task_wb.append( 302 | WBIndicatorOperator( 303 | task_id=task_id, 304 | db_config=DB_CONFIG, 305 | indicator=wb_indicator, 306 | start_year=year_thresh, 307 | end_year=wb_end_year, 308 | country=wb_country, 309 | table_name=wb_table_name, 310 | ) 311 | ) 312 | 313 | country_association = 
HomogeniseCountryNamesOperator( 314 | task_id="homogenise_countries", db_config=DB_CONFIG 315 | ) 316 | 317 | country_details = CountryDetailsOperator( 318 | task_id="country_details", db_config=DB_CONFIG 319 | ) 320 | 321 | postgres2es = Postgreqsl2ElasticSearchOperator( 322 | task_id="postgres2es", 323 | db_config=DB_CONFIG, 324 | es_host=es_host, 325 | es_index=es_index, 326 | es_port=es_port, 327 | region=aws_region, 328 | erase_es_index=erase_es_index, 329 | ) 330 | 331 | pandas2arrow = Pandas2Arrow(task_id="pandas2arrow", db_config=DB_CONFIG) 332 | 333 | open_access = OpenAccessJournalOperator(task_id="open_access", db_config=DB_CONFIG) 334 | 335 | dummy_task >> create_tables >> query_mag >> parse_mag 336 | dummy_task >> dummy_task_4 >> create_buckets >> dummy_task_5 >> query_mag 337 | parse_mag >> geocode_places >> rca 338 | parse_mag >> geocode_places >> country_collaboration_graph 339 | parse_mag >> collect_fos >> fos_frequency >> topic_filtering >> filtered_topic_metadata >> viz_tables 340 | filtered_topic_metadata >> rca >> viz_tables 341 | filtered_topic_metadata >> research_diversity 342 | filtered_topic_metadata >> gender_diversity 343 | geocode_places >> research_diversity >> viz_tables 344 | geocode_places >> gender_diversity >> viz_tables 345 | geocode_places >> country_similarity 346 | geocode_places >> viz_tables >> pandas2arrow 347 | text2vector >> country_similarity 348 | text2vector >> pandas2arrow 349 | filtered_topic_metadata >> country_similarity 350 | parse_mag >> batch_names >> batch_task_gender >> dummy_task_2 >> gender_diversity 351 | parse_mag >> text2vector >> dim_reduction 352 | dim_reduction >> fit_umap 353 | dim_reduction >> call_umap 354 | text2vector >> faiss_index 355 | parse_mag >> aff_types 356 | parse_mag >> postgres2es 357 | parse_mag >> open_access 358 | dummy_task >> create_tables >> dummy_task_3 >> batch_task_wb >> country_association 359 | geocode_places >> country_association >> country_details 360 | --------------------------------------------------------------------------------
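# A minimal, self-contained sketch of the dependency idiom used at the bottom of
# orion/core/dags/orion.py (illustrative task ids; not part of the Orion DAG):
# `a >> b` marks b as downstream of a, and a list on either side fans the
# dependency out to several tasks at once, as with create_buckets and
# batch_task_gender above.
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG(
    dag_id="dependency_sketch",
    start_date=datetime(2020, 2, 2),
    schedule_interval=None,
) as dag:
    start = DummyOperator(task_id="start")
    branches = [DummyOperator(task_id=f"branch_{i}") for i in range(3)]
    join = DummyOperator(task_id="join")

    # Equivalent to setting start >> branch_i and branch_i >> join for every branch.
    start >> branches >> join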