├── .dockerignore ├── .github └── workflows │ └── deploy.yaml ├── .gitignore ├── ADDING-A-DATA-SOURCE.md ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── README.md ├── __init__.py ├── alembic.ini ├── alembic │ ├── README │ ├── env.py │ ├── script.py.mako │ └── versions │ │ ├── 4d9562314bd3_parnet.py │ │ ├── 513db5127df7_status.py │ │ ├── 792a820e9374_document_id_in_data_source.py │ │ ├── 836a5f803c4d_status.py │ │ └── 9c2f5b290b16_add_fields_to_datasourcetype_model.py ├── api │ ├── __init__.py │ ├── data_source.py │ └── search.py ├── data_source │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ ├── base_data_source.py │ │ ├── basic_document.py │ │ ├── context.py │ │ ├── dynamic_loader.py │ │ ├── exception.py │ │ └── utils.py │ └── sources │ │ ├── __init__.py │ │ ├── bookstack │ │ ├── __init__.py │ │ └── bookstack.py │ │ ├── confluence │ │ ├── __init__.py │ │ ├── confluence.py │ │ └── confluence_cloud.py │ │ ├── gitlab │ │ ├── __init__.py │ │ └── gitlab.py │ │ ├── google_drive │ │ ├── __init__.py │ │ └── google_drive.py │ │ ├── jira │ │ ├── __init__.py │ │ ├── jira.py │ │ └── jira_cloud.py │ │ ├── mattermost │ │ ├── __init__.py │ │ └── mattermost.py │ │ ├── rocketchat │ │ ├── __init__.py │ │ └── rocketchat.py │ │ └── slack │ │ ├── __init__.py │ │ └── slack.py ├── db_engine.py ├── indexing │ ├── __init__.py │ ├── background_indexer.py │ ├── bm25_index.py │ ├── faiss_index.py │ └── index_documents.py ├── main.py ├── models.py ├── parsers │ ├── __init__.py │ ├── docx.py │ ├── html.py │ ├── pdf.py │ ├── pptx.py │ └── txt.py ├── paths.py ├── queues │ ├── __init__.py │ ├── index_queue.py │ └── task_queue.py ├── requirements.txt ├── schemas │ ├── __init__.py │ ├── base.py │ ├── data_source.py │ ├── data_source_type.py │ ├── document.py │ └── paragraph.py ├── search_logic.py ├── static │ └── data_source_icons │ │ ├── bookstack.png │ │ ├── confluence.png │ │ ├── confluence_cloud.png │ │ ├── default_icon.png │ │ ├── gitlab.png │ │ ├── google_drive.png │ │ ├── jira.png │ │ ├── jira_cloud.png │ │ ├── mattermost.png │ │ ├── rocketchat.png │ │ └── slack.png ├── telemetry.py ├── util.py └── workers.py ├── deploy.sh ├── docker-compose.yaml ├── docs └── data-sources │ ├── confluence │ ├── confluence-settings.png │ ├── confluence.md │ ├── create-token-screen.png │ ├── create-token.png │ ├── personal-access-tokens.png │ └── settings.png │ ├── gitlab │ ├── access_tokens.png │ ├── gitlab.md │ └── read_api.png │ └── google-drive │ ├── copy-email.png │ ├── create-key-json.png │ ├── create-key.png │ ├── create-project.png │ ├── create-service-account.png │ ├── gerev-settings.png │ ├── google-drive-share.png │ ├── google-drive.md │ ├── keys-tab.png │ ├── loading-screen.png │ └── share-drive-folder.png ├── images ├── CodeCard.png ├── Everything.png ├── api.gif ├── bill.png ├── contact-card.png ├── everything.png ├── integ.jpeg ├── product-example.png └── sql-card.png ├── run.sh └── ui ├── .gitignore ├── README.md ├── package-lock.json ├── package.json ├── public ├── favicon.ico ├── index.html ├── logo192.png ├── logo512.png ├── manifest.json └── robots.txt ├── src ├── App.tsx ├── api.ts ├── assets │ ├── css │ │ ├── App.css │ │ ├── custom-fonts.css │ │ └── index.css │ ├── fonts │ │ ├── Inter-SemiBold.ttf │ │ ├── Poppins-Medium.ttf │ │ ├── Poppins-Regular.ttf │ │ ├── SourceSansPro-Black.ttf │ │ ├── SourceSansPro-Bold.ttf │ │ └── SourceSansPro-SemiBold.ttf │ └── images │ │ ├── blue-folder.svg │ │ ├── bookstack.svg │ │ ├── calendar.svg │ │ ├── copy-this.png │ │ ├── discord.png │ │ ├── docx.svg │ │ ├── enter.svg │ 
│ ├── gitlab.svg │ │ ├── google-doc.svg │ │ ├── google-drive.svg │ │ ├── left-pane-instructions.png │ │ ├── logo.svg │ │ ├── pdf.svg │ │ ├── pptx.svg │ │ ├── profile-picture-default.svg │ │ ├── pur-dir.svg │ │ ├── usa.png │ │ ├── user.webp │ │ └── warning.svg ├── autocomplete.ts ├── components │ ├── data-source-panel.tsx │ ├── search-bar.tsx │ └── search-result.tsx ├── custom.d.ts ├── data-source.ts ├── index.tsx ├── reportWebVitals.js └── setupTests.js ├── tailwind.config.js ├── tsconfig.json └── webpack.config.js /.dockerignore: -------------------------------------------------------------------------------- 1 | ui/node_modules 2 | app/venv 3 | storage 4 | app/.env 5 | .env -------------------------------------------------------------------------------- /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker image 2 | 3 | on: 4 | # Build from input box 5 | workflow_dispatch: 6 | inputs: 7 | version: 8 | description: 'Version' 9 | required: true 10 | push: 11 | description: 'Push to Docker Hub' 12 | required: true 13 | type: boolean 14 | default: true 15 | 16 | env: 17 | VERSION: ${{ github.event.inputs.version }} 18 | PUSH: ${{ github.event.inputs.push }} 19 | 20 | jobs: 21 | build-and-push: 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v2 26 | 27 | - name: 🐷 TruffleHog OSS 28 | uses: trufflesecurity/trufflehog@v3.29.1 29 | with: 30 | path: ./ 31 | base: ${{ github.event.repository.default_branch }} 32 | head: HEAD 33 | extra_args: --debug --only-verified 34 | 35 | 36 | # Use Caching for npm 37 | - name: Cache node modules 38 | uses: actions/cache@v2 39 | with: 40 | working-directory: ./ui 41 | path: | 42 | node_modules 43 | key: ${{ runner.os }}-npm-${{ hashFiles('**/package-lock.json') }} 44 | restore-keys: | 45 | ${{ runner.os }}-npm- 46 | 47 | - name: Install npm dependencies and build UI 48 | run: | 49 | cd ui 50 | npm install 51 | npm run build 52 | 53 | - name: Login to Docker Hub 54 | uses: docker/login-action@v1 55 | with: 56 | username: ${{ secrets.DOCKER_USERNAME }} 57 | password: ${{ secrets.DOCKER_PASSWORD }} 58 | 59 | - name: Set up QEMU 60 | uses: docker/setup-qemu-action@v2 61 | 62 | - name: Set up Docker Buildx 63 | uses: docker/setup-buildx-action@v2 64 | 65 | - name: Build and push 66 | uses: docker/build-push-action@v4 67 | env: 68 | VERSION: ${{ env.VERSION }} 69 | with: 70 | context: . 71 | file: ./Dockerfile 72 | platforms: linux/amd64,linux/arm64 73 | push: ${{ env.PUSH == 'true' }} 74 | tags: | 75 | gerev/gerev:${{ env.VERSION }} 76 | gerev/gerev:latest 77 | cache-from: type=gha 78 | cache-to: type=gha,mode=max -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | .mypy_cache 4 | .vscode 5 | __pycache__ 6 | .pytest_cache 7 | htmlcov 8 | .coverage 9 | coverage.xml 10 | storage 11 | log.txt 12 | Pipfile.lock 13 | env3.* 14 | env 15 | venv 16 | 17 | # vim temporary files 18 | *~ 19 | .*.sw? 
20 | 21 | 22 | .env -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9 2 | 3 | ENV DOCKER_DEPLOYMENT=1 4 | 5 | RUN pip install torch 6 | 7 | COPY ./app/requirements.txt /tmp/requirements.txt 8 | 9 | RUN pip install -r /tmp/requirements.txt 10 | 11 | COPY ./app/models.py /tmp/models.py 12 | 13 | # cache the models 14 | RUN python3 /tmp/models.py 15 | 16 | COPY ./app /app 17 | 18 | COPY ./ui/build /ui 19 | 20 | COPY ./run.sh /app/run.sh 21 | 22 | WORKDIR /app 23 | 24 | VOLUME [ "/opt/storage" ] 25 | 26 | EXPOSE 80 27 | 28 | CMD ./run.sh 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gerev Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [⚡🔎 Live online demo!](https://app.klu.so/signup?utm_source=github_gerevai) 2 | # AI-powered enterprise search engine 🔎 3 | 4 | ### Join Discord for early access code! 5 | [![Discord Follow](https://dcbadge.vercel.app/api/server/7hNdF7yu8r?style=flat)](https://discord.gg/7hNdF7yu8r) 6 | [![DockerHub Pulls][docker-pull-img]][docker-pull] 7 | 8 | 9 | [docker-pull]: https://hub.docker.com/r/gerev/gerev 10 | [docker-pull-img]: https://img.shields.io/docker/pulls/gerev/gerev.svg 11 | 12 | [Join here!](https://app.klu.so/signup?utm_source=github_gerevai) 13 | 14 | # Search engine for your organization! 15 | ![first image](./images/api.gif) 16 | **Find any conversation, doc, or internal page in seconds** ⏲️⚡️ 17 | **Join 100+** devs by hosting your own gerev instance, become a **hero** within your org! 
💪 18 | 19 | 20 | * [Integrations](#integrations) 21 | * [Getting Started](#getting-started) 22 | * [Cloud](#managed-cloud-pro) 23 | * [Self-Hosted](#self-hosted-community) 24 | 25 | 26 | ## Made for help desk techies 👨‍💻 27 | ### Troubleshoot Issues 🐛 28 | ![fourth image](./images/sql-card.png) 29 | 30 | ### Or find internal issues _fast_ ⚡️ 31 | ![second image](./images/product-example.png) 32 | 33 | 34 | Integrations 35 | ============ 36 | - [x] Slack 37 | - [x] Confluence 38 | - [X] Jira 39 | - [x] Google Drive (Docs, .docx, .pptx) - by [@bary12](https://github.com/bary12) :pray: 40 | - [X] Confluence Cloud - by [@bryan-pakulski](https://github.com/bryan-pakulski) :pray: 41 | - [X] Bookstack - by [@flifloo](https://github.com/flifloo) :pray: 42 | - [X] Mattermost - by [@itaykal](https://github.com/Itaykal) :pray: 43 | - [X] RocketChat - by [@flifloo](https://github.com/flifloo) :pray: 44 | - [X] Gitlab Issues - by [@eran1232](https://github.com/eran1232) :pray: 45 | - [ ] Zendesk (In PR :pray:) 46 | - [ ] Stackoverflow Teams (In PR :pray:) 47 | - [ ] Azure DevOps (In PR :pray:) 48 | - [ ] Phabricator (In PR :pray:) 49 | - [ ] Trello (In PR... :pray:) 50 | - [ ] Notion (In Progress... :pray:) 51 | - [ ] Asana 52 | - [ ] Sharepoint 53 | - [ ] Box 54 | - [ ] Dropbox 55 | - [ ] Github Enterprise 56 | - [ ] Microsoft Teams 57 | 58 | 59 | :pray: - by the community 60 | 61 | ## Add your own data source NOW 🚀 62 | See the full guide at [ADDING-A-DATA-SOURCE.md](./ADDING-A-DATA-SOURCE.md). 63 | 64 | 65 | ## Natural Language 66 | Enables searching using natural language. such as `"How to do X"`, `"how to connect to Y"`, `"Do we support Z"` 67 | 68 | --- 69 | 70 | Getting Started 71 | ============ 72 | ## Managed Cloud (Pro) 73 | [Sign up Free](https://app.klu.so/signup?utm_source=github_gerevai) 74 | - [X] Authentication 75 | - [X] Multiple Users 76 | - [X] GPU machine 77 | - [X] 24/7 Support 78 | - Self hosted version (with multi-user also supported) 79 | 80 | ## Self-hosted (Community) 81 | 1. Install *Nvidia for docker* (on host that runs the docker runtime) 82 | 2. Run docker 83 | 84 | ### Nvidia for docker 85 | Install nvidia container toolkit on the host machine. 86 | 87 | ``` 88 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \ 89 | && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \ 90 | && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 91 | 92 | sudo apt-get update 93 | 94 | sudo apt-get install -y nvidia-docker2 95 | 96 | sudo systemctl restart docker 97 | ``` 98 | 99 | 100 | ### Run docker 101 | Then run the docker container like so: 102 | 103 | ### Nvidia hardware 104 | ```bash 105 | docker run --gpus all --name=gerev -p 80:80 -v ~/.gerev/storage:/opt/storage gerev/gerev 106 | ``` 107 | 108 | ### CPU only (no GPU) 109 | ``` 110 | docker run --name=gerev -p 80:80 -v ~/.gerev/storage:/opt/storage gerev/gerev 111 | ``` 112 | add `-d` if you want to detach the container. 113 | 114 | ## Run from source 115 | See [ADDING-A-DATA-SOURCE.md](./ADDING-A-DATA-SOURCE.md) in the Setup development environment section. 116 | 117 | 118 | - **gerev is also popular with some big names. 😉** 119 | 120 | --- 121 | 122 | ![first image](./images/bill.png) 123 | 124 | Built by the community 💜 125 | 126 | 127 | 128 | 129 | 130 | Made with [contributors-img](https://contrib.rocks). 
131 | -------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | Run: 2 | 3 | `uvicorn main:app --env-file .env` 4 | 5 | Dev mode: 6 | 7 | `uvicorn main:app --env-file .env --reload` -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/__init__.py -------------------------------------------------------------------------------- /app/alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = alembic 6 | 7 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 8 | # Uncomment the line below if you want the files to be prepended with date and time 9 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 10 | # for all available tokens 11 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s 12 | 13 | # sys.path path, will be prepended to sys.path if present. 14 | # defaults to the current working directory. 15 | prepend_sys_path = . 16 | 17 | # timezone to use when rendering the date within the migration file 18 | # as well as the filename. 19 | # If specified, requires the python-dateutil library that can be 20 | # installed by adding `alembic[tz]` to the pip requirements 21 | # string value is passed to dateutil.tz.gettz() 22 | # leave blank for localtime 23 | # timezone = 24 | 25 | # max length of characters to apply to the 26 | # "slug" field 27 | # truncate_slug_length = 40 28 | 29 | # set to 'true' to run the environment during 30 | # the 'revision' command, regardless of autogenerate 31 | # revision_environment = false 32 | 33 | # set to 'true' to allow .pyc and .pyo files without 34 | # a source .py file to be detected as revisions in the 35 | # versions/ directory 36 | # sourceless = false 37 | 38 | # version location specification; This defaults 39 | # to alembic/versions. When using multiple version 40 | # directories, initial revisions must be specified with --version-path. 41 | # The path separator used here should be the separator specified by "version_path_separator" below. 42 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions 43 | 44 | # version path separator; As mentioned above, this is the character used to split 45 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 46 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 47 | # Valid values for version_path_separator are: 48 | # 49 | # version_path_separator = : 50 | # version_path_separator = ; 51 | # version_path_separator = space 52 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 
53 | 54 | # set to 'true' to search source files recursively 55 | # in each "version_locations" directory 56 | # new in Alembic version 1.10 57 | # recursive_version_locations = false 58 | 59 | # the output encoding used when revision files 60 | # are written from script.py.mako 61 | # output_encoding = utf-8 62 | 63 | 64 | 65 | [post_write_hooks] 66 | # post_write_hooks defines scripts or Python functions that are run 67 | # on newly generated revision scripts. See the documentation for further 68 | # detail and examples 69 | 70 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 71 | # hooks = black 72 | # black.type = console_scripts 73 | # black.entrypoint = black 74 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 75 | 76 | # Logging configuration 77 | [loggers] 78 | keys = root,sqlalchemy,alembic 79 | 80 | [handlers] 81 | keys = console 82 | 83 | [formatters] 84 | keys = generic 85 | 86 | [logger_root] 87 | level = WARN 88 | handlers = console 89 | qualname = 90 | 91 | [logger_sqlalchemy] 92 | level = WARN 93 | handlers = 94 | qualname = sqlalchemy.engine 95 | 96 | [logger_alembic] 97 | level = INFO 98 | handlers = 99 | qualname = alembic 100 | 101 | [handler_console] 102 | class = StreamHandler 103 | args = (sys.stderr,) 104 | level = NOTSET 105 | formatter = generic 106 | 107 | [formatter_generic] 108 | format = %(levelname)-5.5s [%(name)s] %(message)s 109 | datefmt = %H:%M:%S 110 | -------------------------------------------------------------------------------- /app/alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /app/alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | 8 | from paths import SQLITE_DB_PATH 9 | from schemas.base import Base 10 | from db_engine import engine 11 | 12 | # this is the Alembic Config object, which provides 13 | # access to the values within the .ini file in use. 14 | config = context.config 15 | 16 | # Interpret the config file for Python logging. 17 | # This line sets up loggers basically. 18 | if config.config_file_name is not None: 19 | fileConfig(config.config_file_name) 20 | 21 | 22 | target_metadata = Base.metadata 23 | 24 | 25 | def run_migrations_offline() -> None: 26 | """Run migrations in 'offline' mode. 27 | 28 | This configures the context with just a URL 29 | and not an Engine, though an Engine is acceptable 30 | here as well. By skipping the Engine creation 31 | we don't even need a DBAPI to be available. 32 | 33 | Calls to context.execute() here emit the given string to the 34 | script output. 35 | """ 36 | print("running migrations offline") 37 | url = f'sqlite:///{SQLITE_DB_PATH}' 38 | context.configure( 39 | url=url, 40 | target_metadata=target_metadata, 41 | literal_binds=True, 42 | dialect_opts={"paramstyle": "named"}, 43 | ) 44 | 45 | with context.begin_transaction(): 46 | context.run_migrations() 47 | 48 | 49 | def run_migrations_online() -> None: 50 | """Run migrations in 'online' mode. 51 | 52 | In this scenario we need to create an Engine 53 | and associate a connection with the context. 
54 | 55 | """ 56 | print("running online") 57 | connectable = engine 58 | with connectable.connect() as connection: 59 | context.configure( 60 | connection=connection, target_metadata=target_metadata, render_as_batch=True 61 | ) 62 | 63 | with context.begin_transaction(): 64 | context.run_migrations() 65 | 66 | 67 | if context.is_offline_mode(): 68 | run_migrations_offline() 69 | else: 70 | run_migrations_online() 71 | -------------------------------------------------------------------------------- /app/alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade() -> None: 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade() -> None: 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /app/alembic/versions/4d9562314bd3_parnet.py: -------------------------------------------------------------------------------- 1 | """parnet 2 | 3 | Revision ID: 4d9562314bd3 4 | Revises: 513db5127df7 5 | Create Date: 2023-04-07 16:26:31.269456 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = '4d9562314bd3' 14 | down_revision = '513db5127df7' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | # ### commands auto generated by Alembic - please adjust! ### 21 | try: 22 | with op.batch_alter_table('document', schema=None) as batch_op: 23 | batch_op.add_column(sa.Column('parent_id', sa.Integer(), nullable=True)) 24 | batch_op.create_foreign_key('fg_parent_id', 'document', ['parent_id'], ['id']) 25 | except: 26 | pass 27 | 28 | # ### end Alembic commands ### 29 | 30 | 31 | def downgrade() -> None: 32 | # ### commands auto generated by Alembic - please adjust! ### 33 | try: 34 | with op.batch_alter_table('document', schema=None) as batch_op: 35 | batch_op.drop_constraint(None, type_='foreignkey') 36 | batch_op.drop_column('parent_id') 37 | except: 38 | pass 39 | # ### end Alembic commands ### 40 | -------------------------------------------------------------------------------- /app/alembic/versions/513db5127df7_status.py: -------------------------------------------------------------------------------- 1 | """Your migration message here 2 | 3 | Revision ID: 513db5127df7 4 | Revises: 792a820e9374 5 | Create Date: 2023-04-07 05:01:02.427956 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '513db5127df7' 14 | down_revision = '792a820e9374' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | try: 21 | op.add_column('document', sa.Column('status', sa.String(length=32), nullable=True)) 22 | except: 23 | pass 24 | 25 | 26 | def downgrade() -> None: 27 | op.drop_column('document', 'status') 28 | -------------------------------------------------------------------------------- /app/alembic/versions/792a820e9374_document_id_in_data_source.py: -------------------------------------------------------------------------------- 1 | """document id_in_data_source 2 | 3 | Revision ID: 792a820e9374 4 | Revises: 9c2f5b290b16 5 | Create Date: 2023-03-26 11:27:05.341609 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = '792a820e9374' 14 | down_revision = '9c2f5b290b16' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | try: 21 | op.add_column('document', sa.Column('id_in_data_source', sa.String(length=64), default='__none__')) 22 | except: 23 | pass 24 | 25 | 26 | def downgrade() -> None: 27 | op.drop_column('document', 'id_in_data_source') 28 | -------------------------------------------------------------------------------- /app/alembic/versions/836a5f803c4d_status.py: -------------------------------------------------------------------------------- 1 | """status 2 | 3 | Revision ID: 836a5f803c4d 4 | Revises: 4d9562314bd3 5 | Create Date: 2023-04-11 03:17:06.459499 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = '836a5f803c4d' 14 | down_revision = '4d9562314bd3' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | # ### commands auto generated by Alembic - please adjust! ### 21 | try: 22 | with op.batch_alter_table('document', schema=None) as batch_op: 23 | batch_op.add_column(sa.Column('is_active', sa.Boolean(), nullable=True)) 24 | except Exception as e: 25 | print(e) 26 | # ### end Alembic commands ### 27 | 28 | 29 | def downgrade() -> None: 30 | # ### commands auto generated by Alembic - please adjust! ### 31 | try: 32 | with op.batch_alter_table('document', schema=None) as batch_op: 33 | batch_op.drop_column('is_active') 34 | except Exception as e: 35 | print(e) 36 | # ### end Alembic commands ### 37 | -------------------------------------------------------------------------------- /app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py: -------------------------------------------------------------------------------- 1 | """Add fields to DataSourceType model 2 | 3 | Revision ID: 9c2f5b290b16 4 | Revises: 5 | Create Date: 2023-03-20 14:19:47.665501 6 | 7 | """ 8 | import json 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | 13 | from data_source.api.dynamic_loader import DynamicLoader 14 | from db_engine import Session 15 | from schemas import DataSourceType 16 | 17 | # revision identifiers, used by Alembic. 
18 | revision = '9c2f5b290b16' 19 | down_revision = None 20 | branch_labels = None 21 | depends_on = None 22 | 23 | 24 | def upgrade() -> None: 25 | try: 26 | op.add_column('data_source_type', sa.Column('display_name', sa.String(length=32), nullable=True)) 27 | op.add_column('data_source_type', sa.Column('config_fields', sa.String(length=1024), nullable=True)) 28 | with Session() as session: 29 | # update existing data sources 30 | data_source_types = session.query(DataSourceType).all() 31 | for data_source_type in data_source_types: 32 | data_source_class = DynamicLoader.get_data_source_class(data_source_type.name) 33 | config_fields = data_source_class.get_config_fields() 34 | 35 | data_source_type.config_fields = json.dumps([config_field.dict() for config_field in config_fields]) 36 | data_source_type.display_name = data_source_class.get_display_name() 37 | 38 | session.commit() 39 | except: 40 | pass 41 | 42 | 43 | def downgrade() -> None: 44 | # ### commands auto generated by Alembic - please adjust! ### 45 | op.drop_column('data_source_type', 'config_fields') 46 | op.drop_column('data_source_type', 'display_name') 47 | # ### end Alembic commands ### 48 | -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/api/__init__.py -------------------------------------------------------------------------------- /app/api/data_source.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import logging 4 | import os.path 5 | from typing import List 6 | 7 | from fastapi import APIRouter 8 | from pydantic import BaseModel 9 | from starlette.background import BackgroundTasks 10 | from starlette.requests import Request 11 | 12 | from data_source.api.base_data_source import ConfigField, BaseDataSource, Location 13 | from data_source.api.context import DataSourceContext 14 | from db_engine import Session 15 | from indexing.background_indexer import BackgroundIndexer 16 | from schemas import DataSource 17 | from telemetry import Posthog 18 | 19 | router = APIRouter( 20 | prefix='/data-sources', 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class DataSourceTypeDto(BaseModel): 27 | name: str 28 | display_name: str 29 | config_fields: List[ConfigField] 30 | image_base64: str 31 | has_prerequisites: bool 32 | 33 | @staticmethod 34 | def from_data_source_class(name: str, data_source_class: BaseDataSource) -> 'DataSourceTypeDto': 35 | icon_path_template = "static/data_source_icons/{name}.png" 36 | data_source_icon = icon_path_template.format(name=name) 37 | 38 | if not os.path.exists(data_source_icon): 39 | data_source_icon = icon_path_template.format(name="default_icon") 40 | 41 | with open(data_source_icon, "rb") as file: 42 | encoded_string = base64.b64encode(file.read()) 43 | image_base64 = f"data:image/png;base64,{encoded_string.decode()}" 44 | 45 | return DataSourceTypeDto( 46 | name=name, 47 | display_name=data_source_class.get_display_name(), 48 | config_fields=data_source_class.get_config_fields(), 49 | image_base64=image_base64, 50 | has_prerequisites=data_source_class.has_prerequisites() 51 | ) 52 | 53 | 54 | class ConnectedDataSourceDto(BaseModel): 55 | id: int 56 | name: str 57 | 58 | 59 | class AddDataSourceDto(BaseModel): 60 | name: str 61 | config: dict 62 | 63 | 64 | 
@router.get("/types") 65 | async def list_data_source_types() -> List[DataSourceTypeDto]: 66 | return [DataSourceTypeDto.from_data_source_class(name=name, data_source_class=data_source_class) 67 | for name, data_source_class in DataSourceContext.get_data_source_classes().items()] 68 | 69 | 70 | @router.get("/connected") 71 | async def list_connected_data_sources() -> List[ConnectedDataSourceDto]: 72 | with Session() as session: 73 | data_sources = session.query(DataSource).all() 74 | return [ConnectedDataSourceDto(id=data_source.id, name=data_source.type.name) 75 | for data_source in data_sources] 76 | 77 | 78 | @router.delete("/{data_source_id}") 79 | async def delete_data_source(request: Request, data_source_id: int): 80 | deleted_name = DataSourceContext.delete_data_source(data_source_id=data_source_id) 81 | BackgroundIndexer.reset_indexed_count() 82 | Posthog.removed_data_source(uuid=request.headers.get('uuid'), name=deleted_name) 83 | return {"success": "Data source deleted successfully"} 84 | 85 | 86 | @router.post("/{data_source_name}/list-locations") 87 | async def list_locations(request: Request, data_source_name: str, config: dict) -> List[Location]: 88 | data_source = DataSourceContext.get_data_source_class(data_source_name=data_source_name) 89 | locations = data_source.list_locations(config=config) 90 | Posthog.listed_locations(uuid=request.headers.get('uuid'), name=data_source_name) 91 | return locations 92 | 93 | 94 | @router.post("") 95 | async def connect_data_source(request: Request, dto: AddDataSourceDto, background_tasks: BackgroundTasks) -> int: 96 | logger.info(f"Adding data source {dto.name} with config {json.dumps(dto.config)}") 97 | data_source = await DataSourceContext.create_data_source(name=dto.name, config=dto.config) 98 | Posthog.added_data_source(uuid=request.headers.get('uuid'), name=dto.name) 99 | # in main.py we have a background task that runs every 5 minutes and indexes the data source 100 | # but here we want to index the data source immediately 101 | background_tasks.add_task(data_source.index) 102 | 103 | return data_source.get_id() 104 | -------------------------------------------------------------------------------- /app/api/search.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from starlette.requests import Request 3 | 4 | from search_logic import search_documents 5 | from telemetry import Posthog 6 | 7 | router = APIRouter( 8 | prefix='/search', 9 | ) 10 | 11 | 12 | @router.get("") 13 | async def search(request: Request, query: str, top_k: int = 10): 14 | uuid_header = request.headers.get('uuid') 15 | Posthog.increase_search_count(uuid=uuid_header) 16 | return search_documents(query, top_k) 17 | -------------------------------------------------------------------------------- /app/data_source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/__init__.py -------------------------------------------------------------------------------- /app/data_source/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/api/__init__.py -------------------------------------------------------------------------------- /app/data_source/api/base_data_source.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from abc import abstractmethod, ABC 4 | from datetime import datetime 5 | from enum import Enum 6 | from typing import Dict, List, Optional, Callable 7 | 8 | from pydantic import BaseModel 9 | 10 | from data_source.api.utils import get_utc_time_now 11 | from db_engine import Session 12 | from queues.task_queue import TaskQueue, Task 13 | from schemas import DataSource 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class Location(BaseModel): 20 | value: str 21 | label: str 22 | 23 | 24 | class HTMLInputType(Enum): 25 | TEXT = "text" 26 | TEXTAREA = "textarea" 27 | PASSWORD = "password" 28 | 29 | 30 | class BaseDataSourceConfig(BaseModel): 31 | locations_to_index: List[Location] = [] 32 | 33 | 34 | class ConfigField(BaseModel): 35 | name: str 36 | input_type: HTMLInputType = HTMLInputType.TEXT 37 | label: Optional[str] = None 38 | placeholder: Optional[str] = None 39 | 40 | def __init__(self, **data): 41 | name = data.get("name") 42 | label = data.get("label") or name.title() 43 | data["label"] = label 44 | data["placeholder"] = data.get("placeholder") or label 45 | super().__init__(**data) 46 | 47 | class Config: 48 | use_enum_values = True 49 | 50 | 51 | class BaseDataSource(ABC): 52 | 53 | @staticmethod 54 | @abstractmethod 55 | def get_config_fields() -> List[ConfigField]: 56 | """ 57 | Returns a list of fields that are required to configure the data source for UI. 58 | for example: 59 | [ 60 | ConfigField(label="Url", name="url", type="text", placeholder="https://example.com"), 61 | ConfigField(label="Token", name="token", type="password", placeholder="paste-your-token-here") 62 | ] 63 | """ 64 | raise NotImplementedError 65 | 66 | @staticmethod 67 | @abstractmethod 68 | async def validate_config(config: Dict) -> None: 69 | """ 70 | Validates the config and raises an exception if it's invalid. 71 | """ 72 | raise NotImplementedError 73 | 74 | @classmethod 75 | def get_display_name(cls) -> str: 76 | """ 77 | Returns the display name of the data source, change GoogleDriveDataSource to Google Drive. 78 | """ 79 | pascal_case_source = cls.__name__.replace("DataSource", "") 80 | words = re.findall('[A-Z][^A-Z]*', pascal_case_source) 81 | return " ".join(words) 82 | 83 | @staticmethod 84 | def has_prerequisites() -> bool: 85 | """ 86 | Data sources that require some prerequisites to be installed before they can be used should override this method 87 | """ 88 | return False 89 | 90 | @staticmethod 91 | def list_locations(config: Dict) -> List[Location]: 92 | """ 93 | Returns a list of locations that are available in the data source. 94 | Only for data sources that want the user to select only some location to index 95 | """ 96 | return [] 97 | 98 | @abstractmethod 99 | def _feed_new_documents(self) -> None: 100 | """ 101 | Feeds the indexing queue with new documents. 
102 | """ 103 | raise NotImplementedError 104 | 105 | def __init__(self, config: Dict, data_source_id: int, last_index_time: datetime = None) -> None: 106 | self._raw_config = config 107 | self._config: BaseDataSourceConfig = BaseDataSourceConfig(**self._raw_config) 108 | self._data_source_id = data_source_id 109 | 110 | if last_index_time is None: 111 | last_index_time = datetime(2012, 1, 1) 112 | self._last_index_time = last_index_time 113 | self._last_task_time = None 114 | 115 | def get_id(self): 116 | return self._data_source_id 117 | 118 | def _save_index_time_in_db(self) -> None: 119 | """ 120 | Sets the index time in the database, to be now 121 | """ 122 | with Session() as session: 123 | data_source: DataSource = session.query(DataSource).filter_by(id=self._data_source_id).first() 124 | data_source.last_indexed_at = get_utc_time_now() 125 | session.commit() 126 | 127 | def add_task_to_queue(self, function: Callable, **kwargs): 128 | task = Task(data_source_id=self._data_source_id, 129 | function_name=function.__name__, 130 | kwargs=kwargs) 131 | TaskQueue.get_instance().add_task(task) 132 | 133 | def run_task(self, function_name: str, **kwargs) -> None: 134 | self._last_task_time = get_utc_time_now() 135 | function = getattr(self, function_name) 136 | function(**kwargs) 137 | 138 | def index(self, force: bool = False) -> None: 139 | if self._last_task_time is not None and not force: 140 | # Don't index if the last task was less than an hour ago 141 | time_since_last_task = get_utc_time_now() - self._last_task_time 142 | if time_since_last_task.total_seconds() < 60 * 60: 143 | logging.info("Skipping indexing data source because it was indexed recently") 144 | return 145 | 146 | try: 147 | self._save_index_time_in_db() 148 | self._feed_new_documents() 149 | except Exception as e: 150 | logging.exception("Error while indexing data source") 151 | 152 | def _is_prior_to_last_index_time(self, doc_time: datetime) -> bool: 153 | if doc_time.tzinfo is not None and self._last_index_time.tzinfo is None: 154 | self._last_index_time = self._last_index_time.replace(tzinfo=doc_time.tzinfo) 155 | 156 | return doc_time < self._last_index_time 157 | -------------------------------------------------------------------------------- /app/data_source/api/basic_document.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Union, List 5 | 6 | class DocumentType(Enum): 7 | DOCUMENT = "document" 8 | MESSAGE = "message" 9 | COMMENT = "comment" 10 | PERSON = "person" 11 | ISSUE = "issue" 12 | GIT_PR = "git_pr" 13 | 14 | 15 | class DocumentStatus(Enum): 16 | OPEN = "open" 17 | IN_PROGRESS = "in_progress" 18 | CLOSED = "closed" 19 | 20 | 21 | class FileType(Enum): 22 | GOOGLE_DOC = "doc" 23 | DOCX = "docx" 24 | PPTX = "pptx" 25 | TXT = "txt" 26 | PDF = "pdf" 27 | 28 | @classmethod 29 | def from_mime_type(cls, mime_type: str): 30 | if mime_type == 'application/vnd.google-apps.document': 31 | return cls.GOOGLE_DOC 32 | elif mime_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 33 | return cls.DOCX 34 | elif mime_type == 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 35 | return cls.PPTX 36 | elif mime_type == 'text/plain': 37 | return cls.TXT 38 | elif mime_type == 'application/pdf': 39 | return cls.PDF 40 | else: 41 | return None 42 | 43 | 44 | @dataclass 45 | class BasicDocument: 46 | id: 
Union[int, str] # row id in database 47 | data_source_id: int # data source id in database 48 | type: DocumentType 49 | title: str 50 | content: str 51 | timestamp: datetime 52 | author: str 53 | author_image_url: str 54 | location: str 55 | url: str 56 | status: str = None 57 | is_active: bool = None 58 | file_type: FileType = None 59 | children: List['BasicDocument'] = None 60 | 61 | @property 62 | def id_in_data_source(self): 63 | return str(self.data_source_id) + '_' + str(self.id) 64 | 65 | -------------------------------------------------------------------------------- /app/data_source/api/context.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import Dict, List 5 | 6 | from pydantic import ValidationError 7 | from sqlalchemy import select 8 | 9 | from data_source.api.base_data_source import BaseDataSource 10 | from data_source.api.dynamic_loader import DynamicLoader, ClassInfo 11 | from data_source.api.exception import KnownException 12 | from data_source.api.utils import get_utc_time_now 13 | from db_engine import Session, async_session 14 | from schemas import DataSourceType, DataSource, Document 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass 20 | class CachedDataSource: 21 | indexed_docs: int 22 | failed_tasks: int 23 | instance: BaseDataSource 24 | 25 | 26 | class DataSourceContext: 27 | """ 28 | This class is responsible for loading data sources and caching them. 29 | It dynamically loads data source types from the data_source/sources directory. 30 | It loads data sources from the database and caches them. 31 | """ 32 | _initialized = False 33 | _data_source_cache: Dict[int, CachedDataSource] = {} 34 | _data_source_classes: Dict[str, BaseDataSource] = {} 35 | 36 | @classmethod 37 | def get_data_source_instance(cls, data_source_id: int) -> BaseDataSource: 38 | if not cls._initialized: 39 | cls.init() 40 | cls._initialized = True 41 | 42 | return cls._data_source_cache[data_source_id].instance 43 | 44 | @classmethod 45 | def get_data_source_class(cls, data_source_name: str) -> BaseDataSource: 46 | if not cls._initialized: 47 | cls.init() 48 | cls._initialized = True 49 | 50 | return cls._data_source_classes[data_source_name] 51 | 52 | @classmethod 53 | def get_data_source_classes(cls) -> Dict[str, BaseDataSource]: 54 | if not cls._initialized: 55 | cls.init() 56 | cls._initialized = True 57 | 58 | return cls._data_source_classes 59 | 60 | @classmethod 61 | async def create_data_source(cls, name: str, config: dict) -> BaseDataSource: 62 | async with async_session() as session: 63 | data_source_type = await session.execute( 64 | select(DataSourceType).filter_by(name=name) 65 | ) 66 | data_source_type = data_source_type.scalar_one_or_none() 67 | if data_source_type is None: 68 | raise KnownException(message=f"Data source type {name} does not exist") 69 | 70 | data_source_class = DynamicLoader.get_data_source_class(name) 71 | logger.info(f"validating config for data source {name}") 72 | await data_source_class.validate_config(config) 73 | config_str = json.dumps(config) 74 | 75 | data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=get_utc_time_now()) 76 | session.add(data_source_row) 77 | await session.commit() 78 | 79 | data_source = data_source_class(config=config, data_source_id=data_source_row.id) 80 | cls._data_source_cache[data_source_row.id] = CachedDataSource(indexed_docs=0, failed_tasks=0, 
81 | instance=data_source) 82 | 83 | return data_source 84 | 85 | @classmethod 86 | def delete_data_source(cls, data_source_id: int) -> str: 87 | with Session() as session: 88 | data_source = session.query(DataSource).filter_by(id=data_source_id).first() 89 | if data_source is None: 90 | raise KnownException(message=f"Data source {data_source_id} does not exist") 91 | 92 | data_source_name = data_source.type.name 93 | logger.info(f"Deleting data source {data_source_id} ({data_source_name})...") 94 | session.delete(data_source) 95 | session.commit() 96 | 97 | del cls._data_source_cache[data_source_id] 98 | 99 | return data_source_name 100 | 101 | @classmethod 102 | def init(cls): 103 | cls._load_data_source_classes() 104 | cls._load_connected_sources_from_db() 105 | 106 | @classmethod 107 | def _load_connected_sources_from_db(cls): 108 | logger.info("Loading data sources from database") 109 | 110 | with Session() as session: 111 | data_sources: List[DataSource] = session.query(DataSource).all() 112 | for data_source in data_sources: 113 | logger.info(f"Loading data source {data_source.id} ({data_source.type.name})") 114 | data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name) 115 | config = json.loads(data_source.config) 116 | try: 117 | data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, 118 | last_index_time=data_source.last_indexed_at) 119 | except ValidationError as e: 120 | logger.error(f"Error loading data source {data_source.id}: {e}") 121 | return 122 | 123 | cached_data_source = CachedDataSource(indexed_docs=len(data_source.documents), 124 | failed_tasks=0, 125 | instance=data_source_instance) 126 | cls._data_source_cache[data_source.id] = cached_data_source 127 | logger.info(f"Loaded data source {data_source.id} ({data_source.type.name})") 128 | 129 | cls._initialized = True 130 | 131 | @classmethod 132 | def _load_data_source_classes(cls): 133 | data_sources: Dict[str, ClassInfo] = DynamicLoader.find_data_sources() 134 | 135 | with Session() as session: 136 | for source_name in data_sources.keys(): 137 | class_info = data_sources[source_name] 138 | data_source_class = DynamicLoader.get_class(file_path=class_info.file_path, 139 | class_name=class_info.name) 140 | cls._data_source_classes[source_name] = data_source_class 141 | 142 | if session.query(DataSourceType).filter_by(name=source_name).first(): 143 | continue 144 | 145 | config_fields = data_source_class.get_config_fields() 146 | config_fields_str = json.dumps([config_field.dict() for config_field in config_fields]) 147 | new_data_source = DataSourceType(name=source_name, 148 | display_name=data_source_class.get_display_name(), 149 | config_fields=config_fields_str) 150 | session.add(new_data_source) 151 | session.commit() 152 | -------------------------------------------------------------------------------- /app/data_source/api/dynamic_loader.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import re 4 | from dataclasses import dataclass 5 | from typing import Dict 6 | import importlib 7 | 8 | from data_source.api.utils import snake_case_to_pascal_case 9 | 10 | 11 | @dataclass 12 | class ClassInfo: 13 | name: str 14 | file_path: str 15 | 16 | 17 | class DynamicLoader: 18 | """ 19 | This class is used to dynamically load classes from files. 20 | Specifically, it is used to load data sources from the data_source/sources directory. 
21 | """ 22 | SOURCES_PATH = os.path.join('data_source', 'sources') 23 | 24 | @staticmethod 25 | def extract_classes(file_path: str): 26 | with open(file_path, 'r') as f: 27 | file_ast = ast.parse(f.read()) 28 | classes = {} 29 | for node in file_ast.body: 30 | if isinstance(node, ast.ClassDef): 31 | classes[node.name] = {'node': node, 'file': file_path} 32 | return classes 33 | 34 | @staticmethod 35 | def get_data_source_class(data_source_name: str): 36 | class_name = f"{snake_case_to_pascal_case(data_source_name)}DataSource" 37 | class_file_path = DynamicLoader.find_class_file(DynamicLoader.SOURCES_PATH, class_name) 38 | return DynamicLoader.get_class(class_file_path, class_name) 39 | 40 | @staticmethod 41 | def get_class(file_path: str, class_name: str): 42 | loader = importlib.machinery.SourceFileLoader(class_name, file_path) 43 | module = loader.load_module() 44 | try: 45 | return getattr(module, class_name) 46 | except AttributeError: 47 | raise AttributeError(f"Class {class_name} not found in module {module}," 48 | f"make sure you named the class correctly (it should be DataSource)") 49 | 50 | @staticmethod 51 | def find_class_file(directory, class_name): 52 | for root, dirs, files in os.walk(directory): 53 | for file in files: 54 | if file.endswith('.py'): 55 | file_path = os.path.join(root, file) 56 | classes = DynamicLoader.extract_classes(file_path) 57 | if class_name in classes: 58 | return file_path 59 | return None 60 | 61 | @staticmethod 62 | def find_data_sources() -> Dict[str, ClassInfo]: 63 | all_classes = {} 64 | # First, extract all classes and their file paths 65 | for root, dirs, files in os.walk(DynamicLoader.SOURCES_PATH): 66 | for file in files: 67 | if file.endswith('.py'): 68 | file_path = os.path.join(root, file) 69 | all_classes.update(DynamicLoader.extract_classes(file_path)) 70 | 71 | def is_base_data_source(class_name: str): 72 | if class_name not in all_classes: 73 | return False 74 | 75 | class_info = all_classes[class_name] 76 | node = class_info['node'] 77 | 78 | for base in node.bases: 79 | if isinstance(base, ast.Name): 80 | if base.id == 'BaseDataSource': 81 | return True 82 | elif is_base_data_source(base.id): 83 | return True 84 | 85 | return False 86 | 87 | data_sources = {} 88 | # Then, check if each class inherits from BaseDataSource 89 | for class_name, class_info in all_classes.items(): 90 | if is_base_data_source(class_name): 91 | snake_case = re.sub('([a-z0-9])([A-Z])', r'\1_\2', 92 | re.sub('(.)([A-Z][a-z]+)', r'\1_\2', class_name)).lower() 93 | clas_name = snake_case.replace('_data_source', '') 94 | data_sources[clas_name] = ClassInfo(name=class_name, 95 | file_path=class_info['file']) 96 | 97 | return data_sources 98 | 99 | 100 | if __name__ == '__main__': 101 | print(DynamicLoader.find_data_sources()) 102 | -------------------------------------------------------------------------------- /app/data_source/api/exception.py: -------------------------------------------------------------------------------- 1 | class KnownException(Exception): 2 | def __init__(self, message): 3 | self.message = message 4 | super().__init__(message) 5 | 6 | 7 | class InvalidDataSourceConfig(Exception): 8 | pass 9 | -------------------------------------------------------------------------------- /app/data_source/api/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | import concurrent.futures 4 | from functools import lru_cache 5 | from io import BytesIO 6 | from typing import Optional 
7 | from datetime import datetime, timezone 8 | import requests 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def snake_case_to_pascal_case(snake_case_string: str): 14 | """Converts a snake case string to a PascalCase string""" 15 | components = snake_case_string.split('_') 16 | return "".join(x.title() for x in components) 17 | 18 | 19 | def _wrap_with_try_except(func): 20 | def wrapper(*args, **kwargs): 21 | try: 22 | return func(*args, **kwargs) 23 | except Exception as e: 24 | logger.exception("Failed to parse data source", exc_info=e) 25 | raise e 26 | 27 | return wrapper 28 | 29 | 30 | def parse_with_workers(method: callable, items: list, **kwargs): 31 | workers = 10 # should be a config value 32 | 33 | logger.info(f'Parsing {len(items)} documents using {method} (with {workers} workers)...') 34 | 35 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: 36 | futures = [] 37 | for i in range(workers): 38 | futures.append(executor.submit(_wrap_with_try_except(method), items[i::workers], **kwargs)) 39 | concurrent.futures.wait(futures) 40 | for w in futures: 41 | e = w.exception() 42 | if e: 43 | logging.exception("Worker failed", exc_info=e) 44 | 45 | 46 | @lru_cache(maxsize=512) 47 | def get_confluence_user_image(image_url: str, token: str) -> Optional[str]: 48 | try: 49 | if "anonymous.svg" in image_url: 50 | image_url = image_url.replace(".svg", ".png") 51 | 52 | response = requests.get(url=image_url, timeout=1, headers={'Accept': 'application/json', 53 | "Authorization": f"Bearer {token}"}) 54 | image_bytes = BytesIO(response.content) 55 | return f"data:image/jpeg;base64,{base64.b64encode(image_bytes.getvalue()).decode()}" 56 | except: 57 | logger.warning(f"Failed to get confluence user image {image_url}") 58 | 59 | 60 | def get_utc_time_now() -> datetime: 61 | return datetime.now(tz=timezone.utc) 62 | -------------------------------------------------------------------------------- /app/data_source/sources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/bookstack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/bookstack/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/bookstack/bookstack.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from datetime import datetime 4 | from time import sleep 5 | from typing import List, Dict 6 | from urllib.parse import urljoin 7 | 8 | from pydantic import BaseModel 9 | from requests import Session, HTTPError 10 | from requests.auth import AuthBase 11 | 12 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig 13 | from data_source.api.basic_document import BasicDocument, DocumentType 14 | from data_source.api.exception import InvalidDataSourceConfig 15 | from parsers.html import html_to_text 16 | from queues.index_queue import IndexQueue 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class BookStackAuth(AuthBase): 22 | def __init__(self, token_id, token_secret,
header_key="Authorization"): 23 | self.header_key = header_key 24 | self.token_id = token_id 25 | self.token_secret = token_secret 26 | 27 | def __call__(self, r): 28 | r.headers[self.header_key] = f"Token {self.token_id}:{self.token_secret}" 29 | return r 30 | 31 | 32 | class BookStack(Session): 33 | VERIFY_SSL = os.environ.get('BOOKSTACK_VERIFY_SSL') is not None 34 | 35 | def __init__(self, url: str, token_id: str, token_secret: str, *args, **kwargs): 36 | super().__init__(*args, **kwargs) 37 | self.base_url = url 38 | self.auth = BookStackAuth(token_id, token_secret) 39 | self.rate_limit_reach = False 40 | 41 | def request(self, method, url_path, *args, **kwargs): 42 | while self.rate_limit_reach: 43 | sleep(1) 44 | 45 | url = urljoin(self.base_url, url_path) 46 | r = super().request(method, url, verify=BookStack.VERIFY_SSL, *args, **kwargs) 47 | 48 | if r.status_code != 200: 49 | if r.status_code == 429: 50 | if not self.rate_limit_reach: 51 | logger.info("API rate limit reach, waiting...") 52 | self.rate_limit_reach = True 53 | sleep(60) 54 | self.rate_limit_reach = False 55 | logger.info("Done waiting for the API rate limit") 56 | return self.request(method, url, verify=BookStack.VERIFY_SSL, *args, **kwargs) 57 | r.raise_for_status() 58 | return r 59 | 60 | def get_list(self, url: str, count: int = 500, sort: str = None, filters: Dict = None): 61 | # Add filter[...] to keys, avoiding the insertion of unwanted parameters 62 | if filters is not None: 63 | filters = {f"filter[{k}]": v for k, v in filters.items()} 64 | else: 65 | filters = {} 66 | 67 | data = [] 68 | records = 0 69 | total = 1 # Set 1 to enter the loop 70 | while records < total: 71 | r = self.get(url, params={"count": count, "offset": records, "sort": sort, **filters}, 72 | headers={"Content-Type": "application/json"}) 73 | json = r.json() 74 | data += json.get("data") 75 | records = len(data) 76 | total = json.get("total") 77 | return data 78 | 79 | def get_all_books(self) -> List[Dict]: 80 | return self.get_list("/api/books", sort="+updated_at") 81 | 82 | def get_all_pages_from_book(self, book) -> List[Dict]: 83 | pages = self.get_list("/api/pages", sort="+updated_at", filters={"book_id": book["id"]}) 84 | 85 | # Add parent book object to each page 86 | for page in pages: 87 | page.update({"book": book}) 88 | 89 | return pages 90 | 91 | def get_page(self, page_id: int): 92 | r = self.get(f"/api/pages/{page_id}", headers={"Content-Type": "application/json"}) 93 | return r.json() 94 | 95 | def get_user(self, user_id: int): 96 | try: 97 | return self.get(f"/api/users/{user_id}", headers={"Content-Type": "application/json"}).json() 98 | # If the user lack the privileges to make this call, return None 99 | except HTTPError: 100 | return None 101 | 102 | 103 | class BookStackConfig(BaseDataSourceConfig): 104 | url: str 105 | token_id: str 106 | token_secret: str 107 | 108 | 109 | class BookstackDataSource(BaseDataSource): 110 | @staticmethod 111 | def get_config_fields() -> List[ConfigField]: 112 | return [ 113 | ConfigField(label="BookStack instance URL", name="url"), 114 | ConfigField(label="Token ID", name="token_id", input_type=HTMLInputType.PASSWORD), 115 | ConfigField(label="Token Secret", name="token_secret", input_type=HTMLInputType.PASSWORD) 116 | ] 117 | 118 | @classmethod 119 | def get_display_name(cls) -> str: 120 | return "BookStack" 121 | 122 | @staticmethod 123 | def list_books(book_stack: BookStack) -> List[Dict]: 124 | # Usually the book_stack connection fails, so we retry a few times 125 | retries = 3 126 
| for i in range(retries): 127 | try: 128 | return book_stack.get_all_books() 129 | except Exception as e: 130 | logging.error(f"BookStack connection failed: {e}") 131 | if i == retries - 1: 132 | raise e 133 | 134 | @staticmethod 135 | async def validate_config(config: Dict) -> None: 136 | try: 137 | parsed_config = BookStackConfig(**config) 138 | book_stack = BookStack(url=parsed_config.url, token_id=parsed_config.token_id, 139 | token_secret=parsed_config.token_secret) 140 | BookstackDataSource.list_books(book_stack=book_stack) 141 | except Exception as e: 142 | raise InvalidDataSourceConfig from e 143 | 144 | def __init__(self, *args, **kwargs): 145 | super().__init__(*args, **kwargs) 146 | book_stack_config = BookStackConfig(**self._raw_config) 147 | self._book_stack = BookStack(url=book_stack_config.url, token_id=book_stack_config.token_id, 148 | token_secret=book_stack_config.token_secret) 149 | 150 | def _list_books(self) -> List[Dict]: 151 | logger.info("Listing books with BookStack") 152 | return BookstackDataSource.list_books(book_stack=self._book_stack) 153 | 154 | def _feed_new_documents(self) -> None: 155 | logger.info("Feeding new documents with BookStack") 156 | 157 | books = self._list_books() 158 | for book in books: 159 | self.add_task_to_queue(self._feed_book, book=book) 160 | 161 | def _feed_book(self, book: Dict): 162 | logger.info(f"Getting documents from book {book['name']} ({book['id']})") 163 | pages = self._book_stack.get_all_pages_from_book(book) 164 | for page in pages: 165 | self.add_task_to_queue(self._feed_page, raw_page=page) 166 | 167 | def _feed_page(self, raw_page: Dict): 168 | last_modified = datetime.strptime(raw_page["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") 169 | if last_modified < self._last_index_time: 170 | return 171 | 172 | page_id = raw_page["id"] 173 | page_content = self._book_stack.get_page(page_id) 174 | author_name = page_content["created_by"]["name"] 175 | 176 | author_image_url = "" 177 | author = self._book_stack.get_user(raw_page["created_by"]) 178 | if author: 179 | author_image_url = author["avatar_url"] 180 | 181 | plain_text = html_to_text(page_content["html"]) 182 | 183 | url = urljoin(self._raw_config.get('url'), f"/books/{raw_page['book_slug']}/page/{raw_page['slug']}") 184 | 185 | document = BasicDocument(title=raw_page["name"], 186 | content=plain_text, 187 | author=author_name, 188 | author_image_url=author_image_url, 189 | timestamp=last_modified, 190 | id=page_id, 191 | data_source_id=self._data_source_id, 192 | location=raw_page["book"]["name"], 193 | url=url, 194 | type=DocumentType.DOCUMENT) 195 | IndexQueue.get_instance().put_single(doc=document) 196 | 197 | # if __name__ == "__main__": 198 | # import os 199 | # config = {"url": os.environ["BOOKSTACK_URL"], "token_id": os.environ["BOOKSTACK_TOKEN_ID"], 200 | # "token_secret": os.environ["BOOKSTACK_TOKEN_SECRET"]} 201 | # book_stack = BookstackDataSource(config=config, data_source_id=0) 202 | # book_stack._feed_new_documents() 203 | -------------------------------------------------------------------------------- /app/data_source/sources/confluence/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/confluence/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/confluence/confluence.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import dateutil.parser 4 | from typing import List, Dict 5 | 6 | from atlassian import Confluence 7 | from atlassian.errors import ApiError 8 | from requests import HTTPError 9 | 10 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, Location, BaseDataSourceConfig 11 | from data_source.api.basic_document import BasicDocument, DocumentType 12 | from data_source.api.exception import InvalidDataSourceConfig 13 | from parsers.html import html_to_text 14 | from queues.index_queue import IndexQueue 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class ConfluenceConfig(BaseDataSourceConfig): 20 | url: str 21 | token: str 22 | 23 | 24 | class ConfluenceDataSource(BaseDataSource): 25 | 26 | @staticmethod 27 | def get_config_fields() -> List[ConfigField]: 28 | return [ 29 | ConfigField(label="Confluence URL", name="url", placeholder="https://example.confluence.com"), 30 | ConfigField(label="Personal Access Token", name="token", input_type=HTMLInputType.PASSWORD) 31 | ] 32 | 33 | @classmethod 34 | def get_display_name(cls) -> str: 35 | return "Confluence Self-Hosted" 36 | 37 | @staticmethod 38 | def list_spaces(confluence: Confluence, start=0) -> List[Location]: 39 | # Usually the confluence connection fails, so we retry a few times 40 | retries = 3 41 | for i in range(retries): 42 | try: 43 | return [Location(label=space['name'], value=space['key']) 44 | for space in confluence.get_all_spaces(expand='status', start=start)['results']] 45 | except Exception as e: 46 | logging.error(f'Confluence connection failed: {e}') 47 | if i == retries - 1: 48 | raise e 49 | 50 | @staticmethod 51 | def list_all_spaces(confluence: Confluence) -> List[Location]: 52 | logger.info('Listing spaces') 53 | 54 | spaces = [] 55 | start = 0 56 | while True: 57 | new_spaces = ConfluenceDataSource.list_spaces(confluence=confluence, start=start) 58 | if len(new_spaces) == 0: 59 | break 60 | 61 | spaces.extend(new_spaces) 62 | start += len(new_spaces) 63 | 64 | logger.info(f'Found {len(spaces)} spaces') 65 | return spaces 66 | 67 | @staticmethod 68 | async def validate_config(config: Dict) -> None: 69 | try: 70 | client = ConfluenceDataSource.confluence_client_from_config(config) 71 | ConfluenceDataSource.list_spaces(confluence=client) 72 | except Exception as e: 73 | raise InvalidDataSourceConfig from e 74 | 75 | @staticmethod 76 | def confluence_client_from_config(config: Dict) -> Confluence: 77 | parsed_config = ConfluenceConfig(**config) 78 | should_verify_ssl = os.environ.get('CONFLUENCE_VERIFY_SSL') is not None 79 | return Confluence(url=parsed_config.url, token=parsed_config.token, verify_ssl=should_verify_ssl) 80 | 81 | @staticmethod 82 | def list_locations(config: Dict) -> List[Location]: 83 | confluence = ConfluenceDataSource.confluence_client_from_config(config) 84 | return ConfluenceDataSource.list_all_spaces(confluence=confluence) 85 | 86 | @staticmethod 87 | def has_prerequisites() -> bool: 88 | return True 89 | 90 | def __init__(self, *args, **kwargs): 91 | super().__init__(*args, **kwargs) 92 | self._confluence = ConfluenceDataSource.confluence_client_from_config(self._raw_config) 93 | 94 | def _list_spaces(self) -> List[Location]: 95 | return ConfluenceDataSource.list_all_spaces(confluence=self._confluence) 96 | 97 | def _feed_new_documents(self) -> None: 98 | logger.info('Feeding new documents with Confluence') 99 | spaces = 
self._config.locations_to_index or self._list_spaces() 100 | for space in spaces: 101 | self.add_task_to_queue(self._feed_space_docs, space=space) 102 | 103 | def _feed_space_docs(self, space: Location) -> List[Dict]: 104 | logging.info(f'Getting documents from space {space.label} ({space.value})') 105 | start = 0 106 | limit = 200 # limit when expanding the version 107 | 108 | last_index_time = self._last_index_time.strftime("%Y-%m-%d %H:%M") 109 | cql_query = f'type = page AND Space = "{space.value}" AND lastModified >= "{last_index_time}" ' \ 110 | f'ORDER BY lastModified DESC' 111 | logger.info(f'Querying confluence with CQL: {cql_query}') 112 | while True: 113 | new_batch = self._confluence.cql(cql_query, start=start, limit=limit, 114 | expand='version')['results'] 115 | len_new_batch = len(new_batch) 116 | logger.info(f'Got {len_new_batch} documents from space {space.label} (total {start + len_new_batch})') 117 | for raw_doc in new_batch: 118 | raw_doc['space_name'] = space.label 119 | self.add_task_to_queue(self._feed_doc, raw_doc=raw_doc) 120 | 121 | if len(new_batch) < limit: 122 | break 123 | 124 | start += limit 125 | 126 | def _feed_doc(self, raw_doc: Dict): 127 | last_modified = dateutil.parser.parse(raw_doc['lastModified']) 128 | doc_id = raw_doc['content']['id'] 129 | try: 130 | fetched_raw_page = self._confluence.get_page_by_id(doc_id, expand='body.storage,history') 131 | except HTTPError as e: 132 | logging.warning( 133 | f'Confluence returned status code {e.response.status_code} for document {doc_id} ({raw_doc["title"]}). skipping.') 134 | return 135 | except ApiError as e: 136 | logging.warning( 137 | f'unable to access document {doc_id} ({raw_doc["title"]}). reason: "{e.reason}". skipping.') 138 | return 139 | 140 | author = fetched_raw_page['history']['createdBy']['displayName'] 141 | author_image = fetched_raw_page['history']['createdBy']['profilePicture']['path'] 142 | author_image_url = fetched_raw_page['_links']['base'] + author_image 143 | html_content = fetched_raw_page['body']['storage']['value'] 144 | plain_text = html_to_text(html_content) 145 | 146 | url = fetched_raw_page['_links']['base'] + fetched_raw_page['_links']['webui'] 147 | 148 | doc = BasicDocument(title=fetched_raw_page['title'], 149 | content=plain_text, 150 | author=author, 151 | author_image_url=author_image_url, 152 | timestamp=last_modified, 153 | id=doc_id, 154 | data_source_id=self._data_source_id, 155 | location=raw_doc['space_name'], 156 | url=url, 157 | type=DocumentType.DOCUMENT) 158 | IndexQueue.get_instance().put_single(doc=doc) 159 | 160 | # if __name__ == '__main__': 161 | # import os 162 | # config = {"url": os.environ['CONFLUENCE_URL'], "token": os.environ['CONFLUENCE_TOKEN']} 163 | # confluence = ConfluenceDataSource(config=config, data_source_id=0) 164 | # spaces = ConfluenceDataSource.list_all_spaces(confluence=confluence._confluence) 165 | # confluence._feed_space_docs(space=spaces[0]) 166 | -------------------------------------------------------------------------------- /app/data_source/sources/confluence/confluence_cloud.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from atlassian import Confluence 4 | 5 | from data_source.api.base_data_source import ConfigField, HTMLInputType, Location, BaseDataSourceConfig 6 | from data_source.api.exception import InvalidDataSourceConfig 7 | from data_source.sources.confluence.confluence import ConfluenceDataSource 8 | 9 | 10 | class 
ConfluenceCloudConfig(BaseDataSourceConfig): 11 | url: str 12 | token: str 13 | username: str 14 | 15 | 16 | class ConfluenceCloudDataSource(ConfluenceDataSource): 17 | 18 | @staticmethod 19 | def get_config_fields() -> List[ConfigField]: 20 | return [ 21 | ConfigField(label="Confluence URL", name="url", placeholder="https://example.confluence.com"), 22 | ConfigField(label="Personal API Token", name="token", input_type=HTMLInputType.PASSWORD), 23 | ConfigField(label="Username", name="username", placeholder="example.user@email.com") 24 | ] 25 | 26 | @classmethod 27 | def get_display_name(cls) -> str: 28 | return "Confluence Cloud" 29 | 30 | @staticmethod 31 | async def validate_config(config: Dict) -> None: 32 | try: 33 | client = ConfluenceCloudDataSource.confluence_client_from_config(config) 34 | ConfluenceCloudDataSource.list_spaces(confluence=client) 35 | except Exception as e: 36 | raise InvalidDataSourceConfig from e 37 | 38 | @staticmethod 39 | def confluence_client_from_config(config: Dict) -> Confluence: 40 | parsed_config = ConfluenceCloudConfig(**config) 41 | return Confluence(url=parsed_config.url, username=parsed_config.username, 42 | password=parsed_config.token, cloud=True) 43 | 44 | @staticmethod 45 | def list_locations(config: Dict) -> List[Location]: 46 | confluence = ConfluenceCloudDataSource.confluence_client_from_config(config) 47 | return ConfluenceDataSource.list_all_spaces(confluence=confluence) 48 | 49 | def __init__(self, *args, **kwargs): 50 | super().__init__(*args, **kwargs) 51 | self._confluence = ConfluenceCloudDataSource.confluence_client_from_config(self._raw_config) 52 | 53 | 54 | # if __name__ == '__main__': 55 | # import os 56 | # 57 | # config = {"url": os.environ.get('CONFLUENCE_CLOUD_URL'), "token": os.environ.get('CONFLUENCE_CLOUD_TOKEN'), 58 | # "username": os.environ.get('CONFLUENCE_CLOUD_USER')} 59 | # confluence = ConfluenceCloudDataSource(data_source_id=1, config=config) 60 | # confluence._feed_new_documents() 61 | -------------------------------------------------------------------------------- /app/data_source/sources/gitlab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/gitlab/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/gitlab/gitlab.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import requests 4 | from typing import Dict, List, Optional 5 | import dateutil.parser 6 | 7 | from data_source.api.base_data_source import BaseDataSource, BaseDataSourceConfig, ConfigField, HTMLInputType 8 | from data_source.api.basic_document import BasicDocument, DocumentType, DocumentStatus 9 | from data_source.api.exception import InvalidDataSourceConfig 10 | from queues.index_queue import IndexQueue 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class GitlabConfig(BaseDataSourceConfig): 17 | url: str 18 | access_token: str 19 | 20 | 21 | def gitlab_status_to_doc_status(status: str) -> Optional[DocumentStatus]: 22 | if status == "opened": 23 | return DocumentStatus.OPEN 24 | elif status == "closed": 25 | return DocumentStatus.CLOSED 26 | else: 27 | logger.warning(f"[!] 
Unknown status {status}") 28 | return None 29 | 30 | 31 | class GitlabDataSource(BaseDataSource): 32 | 33 | @staticmethod 34 | def get_config_fields() -> List[ConfigField]: 35 | return [ 36 | ConfigField(label="Gitlab URL", name="url", input_type=HTMLInputType.TEXT, 37 | placeholder="https://gitlab.com"), 38 | ConfigField(label="API Access Token", name="access_token", input_type=HTMLInputType.PASSWORD), 39 | ] 40 | 41 | @staticmethod 42 | async def validate_config(config: Dict) -> None: 43 | try: 44 | parsed_config = GitlabConfig(**config) 45 | session = requests.Session() 46 | session.headers.update({"PRIVATE-TOKEN": parsed_config.access_token}) 47 | projects_response = session.get(f"{parsed_config.url}/api/v4/projects?membership=true") 48 | projects_response.raise_for_status() 49 | except (KeyError, ValueError) as e: 50 | raise InvalidDataSourceConfig from e 51 | 52 | def __init__(self, *args, **kwargs): 53 | super().__init__(*args, **kwargs) 54 | self.gitlab_config = GitlabConfig(**self._raw_config) 55 | self._session = requests.Session() 56 | self._session.headers.update({"PRIVATE-TOKEN": self.gitlab_config.access_token}) 57 | 58 | def _get_all_paginated(self, url: str) -> List[Dict]: 59 | items = [] 60 | page = 1 61 | per_page = 100 62 | 63 | while True: 64 | try: 65 | response = self._session.get(url + f"&per_page={per_page}&page={page}") 66 | response.raise_for_status() 67 | new_items: List[Dict] = response.json() 68 | items.extend(new_items) 69 | 70 | if len(new_items) < per_page: 71 | break 72 | 73 | page += 1 74 | except: 75 | logging.exception("Error while fetching items paginated for url: " + url) 76 | 77 | return items 78 | 79 | def _list_all_projects(self) -> List[Dict]: 80 | return self._get_all_paginated(f"{self.gitlab_config.url}/api/v4/projects?membership=true") 81 | 82 | def _feed_new_documents(self) -> None: 83 | for project in self._list_all_projects(): 84 | logger.info(f"Feeding project {project['name']}") 85 | self.add_task_to_queue(self._feed_project_issues, project=project) 86 | 87 | def _feed_project_issues(self, project: Dict): 88 | project_id = project["id"] 89 | issues_url = f"{self.gitlab_config.url}/api/v4/projects/{project_id}/issues?scope=all" 90 | all_issues = self._get_all_paginated(issues_url) 91 | 92 | for issue in all_issues: 93 | self.add_task_to_queue(self.feed_issue, issue=issue) 94 | 95 | def feed_issue(self, issue: Dict): 96 | updated_at = dateutil.parser.parse(issue["updated_at"]) 97 | if self._is_prior_to_last_index_time(doc_time=updated_at): 98 | logger.info(f"Issue {issue['id']} is too old, skipping") 99 | return 100 | 101 | comments_url = \ 102 | f"{self.gitlab_config.url}/api/v4/projects/{issue['project_id']}/issues/{issue['iid']}/notes?sort=asc" 103 | raw_comments = self._get_all_paginated(comments_url) 104 | comments = [] 105 | issue_url = issue['web_url'] 106 | 107 | for raw_comment in raw_comments: 108 | if raw_comment["system"]: 109 | continue 110 | 111 | comments.append(BasicDocument( 112 | id=raw_comment["id"], 113 | data_source_id=self._data_source_id, 114 | type=DocumentType.COMMENT, 115 | title=raw_comment["author"]["name"], 116 | content=raw_comment["body"], 117 | author=raw_comment["author"]["name"], 118 | author_image_url=raw_comment["author"]["avatar_url"], 119 | location=issue['references']['full'].replace("/", " / "), 120 | url=issue_url, 121 | timestamp=dateutil.parser.parse(raw_comment["updated_at"]) 122 | )) 123 | 124 | status = gitlab_status_to_doc_status(issue["state"]) 125 | is_active = status == DocumentStatus.OPEN 126 
| doc = BasicDocument( 127 | id=issue["id"], 128 | data_source_id=self._data_source_id, 129 | type=DocumentType.ISSUE, 130 | title=issue['title'], 131 | content=issue.get("description") or "", 132 | author=issue['author']['name'], 133 | author_image_url=issue['author']['avatar_url'], 134 | location=issue['references']['full'].replace("/", " / "), 135 | url=issue['web_url'], 136 | timestamp=updated_at, 137 | status=issue["state"], 138 | is_active=is_active, 139 | children=comments 140 | ) 141 | IndexQueue.get_instance().put_single(doc=doc) 142 | -------------------------------------------------------------------------------- /app/data_source/sources/google_drive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/google_drive/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/jira/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/jira/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/jira/jira.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib 4 | from typing import List, Dict 5 | import dateutil.parser 6 | 7 | import dateutil 8 | from atlassian import Jira 9 | from atlassian.errors import ApiError 10 | 11 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, Location, BaseDataSourceConfig 12 | from data_source.api.basic_document import BasicDocument, DocumentType, DocumentStatus 13 | from data_source.api.exception import InvalidDataSourceConfig 14 | from queues.index_queue import IndexQueue 15 | 16 | 17 | class JiraConfig(BaseDataSourceConfig): 18 | url: str 19 | token: str 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class JiraDataSource(BaseDataSource): 26 | 27 | @classmethod 28 | def get_display_name(cls) -> str: 29 | return "Jira Self-Hosted" 30 | 31 | @staticmethod 32 | def get_config_fields() -> List[ConfigField]: 33 | return [ 34 | ConfigField(label="Jira URL", name="url", placeholder="https://self-hosted-jira.com"), 35 | ConfigField(label="Personal Access Token", name="token", input_type=HTMLInputType.PASSWORD) 36 | ] 37 | 38 | @staticmethod 39 | def list_projects(jira: Jira) -> List[Location]: 40 | logger.info('Listing projects') 41 | projects = jira.get_all_projects() 42 | return [Location(label=project['name'], value=project['key']) for project in projects] 43 | 44 | @staticmethod 45 | def list_locations(config: Dict) -> List[Location]: 46 | jira = JiraDataSource.client_from_config(config) 47 | return JiraDataSource.list_projects(jira=jira) 48 | 49 | @staticmethod 50 | def client_from_config(config: Dict) -> Jira: 51 | parsed_config = JiraConfig(**config) 52 | should_verify_ssl = os.environ.get('JIRA_VERIFY_SSL') is not None 53 | return Jira(url=parsed_config.url, token=parsed_config.token, verify_ssl=should_verify_ssl) 54 | 55 | @staticmethod 56 | def has_prerequisites() -> bool: 57 | return True 58 | 59 | @staticmethod 60 | async def validate_config(config: Dict) -> None: 61 | try: 62 | jira = JiraDataSource.client_from_config(config) 63 | jira.get_all_priorities() 64 | except ApiError as e: 65 | 
raise InvalidDataSourceConfig from e 66 | 67 | def __init__(self, *args, **kwargs): 68 | super().__init__(*args, **kwargs) 69 | self._jira = JiraDataSource.client_from_config(self._raw_config) 70 | 71 | def _feed_new_documents(self) -> None: 72 | logger.info('Feeding new documents with Jira') 73 | projects = self._config.locations_to_index or JiraDataSource.list_projects(jira=self._jira) 74 | for project in projects: 75 | self.add_task_to_queue(self._feed_project_issues, project=project) 76 | 77 | def _feed_project_issues(self, project: Location): 78 | logging.info(f'Getting issues from project {project.label} ({project.value})') 79 | 80 | start = 0 81 | limit = 100 82 | last_index_time = self._last_index_time.strftime("%Y-%m-%d %H:%M") 83 | jql_query = f'project = "{project.value}" AND updated >= "{last_index_time}" ORDER BY updated DESC' 84 | logger.info(f'Querying jira with JQL: {jql_query}') 85 | while True: 86 | new_batch = self._jira.jql_get_list_of_tickets(jql_query, start=start, limit=limit, validate_query=True) 87 | len_new_batch = len(new_batch) 88 | logger.info(f'Got {len_new_batch} issues from project {project.label} (total {start + len_new_batch})') 89 | for raw_issue in new_batch: 90 | self.add_task_to_queue(self._feed_issue, raw_issue=raw_issue, project_name=project.label) 91 | 92 | if len(new_batch) < limit: 93 | break 94 | 95 | start += limit 96 | 97 | def _feed_issue(self, raw_issue: Dict, project_name: str): 98 | issue_id = raw_issue['id'] 99 | last_modified = dateutil.parser.parse(raw_issue['fields']['updated']) 100 | 101 | base_url = self._raw_config['url'] 102 | issue_url = urllib.parse.urljoin(base_url, f"/browse/{raw_issue['key']}") 103 | comments = [] 104 | raw_comments = self._jira.issue_get_comments(issue_id) 105 | for raw_comment in raw_comments['comments']: 106 | comments.append(BasicDocument( 107 | id=raw_comment["id"], 108 | data_source_id=self._data_source_id, 109 | type=DocumentType.COMMENT, 110 | title=raw_comment["author"]["displayName"], 111 | content=raw_comment["body"], 112 | author=raw_comment["author"]["displayName"], 113 | author_image_url=raw_comment["author"]["avatarUrls"]["48x48"], 114 | location=raw_issue['key'], 115 | url=issue_url, 116 | timestamp=dateutil.parser.parse(raw_comment["updated"]) 117 | )) 118 | 119 | author = None 120 | if assignee := raw_issue['fields'].get('assignee'): 121 | author = assignee 122 | elif reporter := raw_issue['fields'].get('reporter'): 123 | author = reporter 124 | elif creator := raw_issue['fields'].get('creator'): 125 | author = creator 126 | 127 | if author: 128 | author_name = author['displayName'] 129 | author_image_url = author['avatarUrls']['48x48'] 130 | else: 131 | author_name = 'Unknown' 132 | author_image_url = "" 133 | 134 | content = raw_issue['fields']['description'] 135 | title = raw_issue['fields']['summary'] 136 | doc = BasicDocument(title=title, 137 | content=content, 138 | author=author_name, 139 | author_image_url=author_image_url, 140 | timestamp=last_modified, 141 | id=issue_id, 142 | data_source_id=self._data_source_id, 143 | location=project_name, 144 | url=issue_url, 145 | status=raw_issue['fields']['status']['name'], 146 | type=DocumentType.ISSUE, 147 | children=comments) 148 | IndexQueue.get_instance().put_single(doc=doc) 149 | 150 | 151 | # if __name__ == '__main__': 152 | # import os 153 | # ds = JiraDataSource(config={"url": os.getenv('JIRA_URL'), "token": os.getenv('JIRA_TOKEN')}, data_source_id=5) 154 | # projects = ds.list_projects(ds._jira) 155 | # for project in projects: 156 | 
# ds._feed_project_issues(project=project) 157 | -------------------------------------------------------------------------------- /app/data_source/sources/jira/jira_cloud.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from atlassian import Jira 4 | 5 | from data_source.api.base_data_source import ConfigField, HTMLInputType, Location, BaseDataSourceConfig 6 | from data_source.api.exception import InvalidDataSourceConfig 7 | from data_source.sources.jira.jira import JiraDataSource 8 | 9 | 10 | class JiraCloudConfig(BaseDataSourceConfig): 11 | url: str 12 | token: str 13 | username: str 14 | 15 | 16 | class JiraCloudDataSource(JiraDataSource): 17 | 18 | @staticmethod 19 | def get_config_fields() -> List[ConfigField]: 20 | return [ 21 | ConfigField(label="Jira Cloud URL", name="url", placeholder="https://example.jira.com"), 22 | ConfigField(label="Personal API Token", name="token", input_type=HTMLInputType.PASSWORD), 23 | ConfigField(label="Username", name="username", placeholder="example.user@email.com") 24 | ] 25 | 26 | @staticmethod 27 | async def validate_config(config: Dict) -> None: 28 | try: 29 | client = JiraCloudDataSource.client_from_config(config) 30 | JiraCloudDataSource.list_projects(jira=client) 31 | except Exception as e: 32 | raise InvalidDataSourceConfig from e 33 | 34 | @classmethod 35 | def get_display_name(cls) -> str: 36 | return "Jira Cloud" 37 | 38 | @staticmethod 39 | def client_from_config(config: Dict) -> Jira: 40 | parsed_config = JiraCloudConfig(**config) 41 | return Jira(url=parsed_config.url, username=parsed_config.username, 42 | password=parsed_config.token, cloud=True) 43 | 44 | @staticmethod 45 | def list_locations(config: Dict) -> List[Location]: 46 | jira = JiraCloudDataSource.client_from_config(config) 47 | return JiraDataSource.list_projects(jira=jira) 48 | 49 | def __init__(self, *args, **kwargs): 50 | super().__init__(*args, **kwargs) 51 | self._jira = JiraCloudDataSource.client_from_config(self._raw_config) 52 | -------------------------------------------------------------------------------- /app/data_source/sources/mattermost/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/mattermost/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/mattermost/mattermost.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, asdict 3 | from datetime import datetime 4 | from functools import lru_cache 5 | from typing import Dict, List, Optional 6 | from urllib.parse import urlparse 7 | 8 | from mattermostdriver import Driver 9 | 10 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig, Location 11 | from data_source.api.basic_document import BasicDocument, DocumentType 12 | from data_source.api.exception import InvalidDataSourceConfig 13 | from queues.index_queue import IndexQueue 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @dataclass 19 | class MattermostChannel: 20 | id: str 21 | name: str 22 | team_id: str 23 | 24 | 25 | @dataclass 26 | class MattermostConfig: 27 | url: str 28 | token: str 29 | locations_to_index: Optional[List[Location]] 30 | scheme: Optional[str] = "https" 31 | port: Optional[int] = 443 
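    # Note (added comment): the `url` field holds the full server address as entered in the config
    # (e.g. "https://mattermost.server.com"). __post_init__ below splits it into hostname, port and
    # scheme, since those are passed to the Mattermost Driver as separate options (the whole config
    # is handed over via asdict() in MattermostDataSource).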
32 | 33 | def __post_init__(self): 34 | try: 35 | parsed_url = urlparse(self.url) 36 | except Exception as e: 37 | raise ValueError from e 38 | 39 | self.url = parsed_url.hostname 40 | self.port = parsed_url.port if parsed_url.port is not None else self.port 41 | self.scheme = parsed_url.scheme if parsed_url.scheme != "" else self.scheme 42 | 43 | 44 | class MattermostDataSource(BaseDataSource): 45 | FEED_BATCH_SIZE = 500 46 | 47 | @staticmethod 48 | def get_config_fields() -> List[ConfigField]: 49 | return [ 50 | ConfigField(label="Mattermost Server", name="url", placeholder="https://mattermost.server.com", 51 | input_type=HTMLInputType.TEXT), 52 | ConfigField(label="Access Token", name="token", placeholder="paste-your-access-token-here", 53 | input_type=HTMLInputType.PASSWORD), 54 | ] 55 | 56 | @staticmethod 57 | async def validate_config(config: Dict) -> None: 58 | try: 59 | parsed_config = MattermostConfig(**config) 60 | maattermost = Driver(options=asdict(parsed_config)) 61 | maattermost.login() 62 | except Exception as e: 63 | raise InvalidDataSourceConfig from e 64 | 65 | def __init__(self, *args, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | mattermost_config = MattermostConfig(**self._raw_config) 68 | self._mattermost = Driver(options=asdict(mattermost_config)) 69 | 70 | def _list_channels(self) -> List[MattermostChannel]: 71 | channels = self._mattermost.channels.client.get(f"/users/me/channels") 72 | return [MattermostChannel(id=channel["id"], name=channel["name"], team_id=channel["team_id"]) 73 | for channel in channels] 74 | 75 | def _is_valid_message(self, message: Dict) -> bool: 76 | return message["type"] == "" 77 | 78 | def _is_valid_channel(self, channel: MattermostChannel) -> bool: 79 | return channel.team_id != "" 80 | 81 | def _list_posts_in_channel(self, channel_id: str, page: int) -> Dict: 82 | endpoint = f"/channels/{channel_id}/posts" 83 | params = { 84 | "since": int(self._last_index_time.timestamp()) * 1000, 85 | "page": page 86 | } 87 | 88 | posts = self._mattermost.channels.client.get(endpoint, params=params) 89 | return posts 90 | 91 | def _feed_new_documents(self) -> None: 92 | self._mattermost.login() 93 | 94 | channels = self._list_channels() 95 | logger.info(f'Found {len(channels)} channels') 96 | 97 | for channel in channels: 98 | self.add_task_to_queue(self._feed_channel, channel=channel) 99 | 100 | def _get_mattermost_url(self): 101 | options = self._mattermost.options 102 | return f"{options['scheme']}://{options['url']}:{options['port']}" 103 | 104 | def _get_team_url(self, channel: MattermostChannel): 105 | url = self._get_mattermost_url() 106 | team = self._mattermost.teams.get_team(channel.team_id) 107 | return f"{url}/{team['name']}" 108 | 109 | @lru_cache(maxsize=512) 110 | def _get_mattermost_user(self, user_id: str): 111 | return self._mattermost.users.get_user(user_id)["username"] 112 | 113 | def _feed_channel(self, channel: MattermostChannel): 114 | if not self._is_valid_channel(channel): 115 | return 116 | 117 | logger.info(f'Feeding channel {channel.name}') 118 | 119 | page = 0 120 | team_url = self._get_team_url(channel) 121 | while True: 122 | posts = self._list_posts_in_channel(channel.id, page) 123 | 124 | last_message: Optional[BasicDocument] = None 125 | posts["order"].reverse() 126 | for id in posts["order"]: 127 | post = posts["posts"][id] 128 | 129 | if not self._is_valid_message(post): 130 | if last_message is not None: 131 | IndexQueue.get_instance().put_single(doc=last_message) 132 | last_message = None 133 | continue 
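                # Posts are coalesced below: consecutive posts by the same author are appended to the
                # pending `last_message` document, while a post by a different author (or an invalid
                # post, handled above) flushes the pending document to the index queue before a new
                # one is started.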
134 | 135 | author = self._get_mattermost_user(post["user_id"]) 136 | content = post["message"] 137 | 138 | if last_message is not None: 139 | if last_message.author == author: 140 | last_message.content += f"\n{content}" 141 | continue 142 | else: 143 | IndexQueue.get_instance().put_single(doc=last_message) 144 | last_message = None 145 | 146 | author_image_url = f"{self._get_mattermost_url()}/api/v4/users/{post['user_id']}/image?_=0" 147 | timestamp = datetime.fromtimestamp(post["update_at"] / 1000) 148 | last_message = BasicDocument( 149 | id=id, 150 | data_source_id=self._data_source_id, 151 | title=channel.name, 152 | content=content, 153 | timestamp=timestamp, 154 | author=author, 155 | author_image_url=author_image_url, 156 | location=channel.name, 157 | url=f"{team_url}/pl/{id}", 158 | type=DocumentType.MESSAGE 159 | ) 160 | 161 | if posts["prev_post_id"] == "": 162 | break 163 | page += 1 164 | 165 | if last_message is not None: 166 | IndexQueue.get_instance().put_single(doc=last_message) 167 | -------------------------------------------------------------------------------- /app/data_source/sources/rocketchat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/rocketchat/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/rocketchat/rocketchat.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import Optional, Dict, List 5 | import os 6 | 7 | from pydantic import BaseModel 8 | from rocketchat_API.rocketchat import RocketChat 9 | 10 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig 11 | from data_source.api.basic_document import DocumentType, BasicDocument 12 | from data_source.api.exception import InvalidDataSourceConfig 13 | from queues.index_queue import IndexQueue 14 | 15 | 16 | @dataclass 17 | class RocketchatThread: 18 | id: str 19 | name: str 20 | channel_id: str 21 | 22 | 23 | @dataclass 24 | class RocketchatRoom: 25 | id: str 26 | name: str 27 | type: str 28 | archived: bool 29 | 30 | 31 | @dataclass 32 | class RocketchatAuthor: 33 | name: str 34 | image_url: str 35 | 36 | 37 | class RocketchatConfig(BaseDataSourceConfig): 38 | url: str 39 | token_id: str 40 | token_secret: str 41 | 42 | 43 | class RocketchatDataSource(BaseDataSource): 44 | @staticmethod 45 | def get_config_fields() -> List[ConfigField]: 46 | return [ 47 | ConfigField(label="Rockat.Chat instance URL (including https://)", name="url"), 48 | ConfigField(label="User ID", name="token_id", type=HTMLInputType.PASSWORD), 49 | ConfigField(label="Token", name="token_secret", type=HTMLInputType.PASSWORD) 50 | ] 51 | 52 | @classmethod 53 | def get_display_name(cls) -> str: 54 | return "Rocket.Chat" 55 | 56 | @staticmethod 57 | async def validate_config(config: Dict) -> None: 58 | rocket_chat_config = RocketchatConfig(**config) 59 | should_verify_ssl = os.environ.get('ROCKETCHAT_VERIFY_SSL') is not None 60 | rocket_chat = RocketChat(user_id=rocket_chat_config.token_id, auth_token=rocket_chat_config.token_secret, 61 | server_url=rocket_chat_config.url, ssl_verify=should_verify_ssl) 62 | try: 63 | rocket_chat.me().json() 64 | except Exception as e: 65 | raise InvalidDataSourceConfig from e 66 | 67 | def __init__(self, 
*args, **kwargs): 68 | super().__init__(*args, **kwargs) 69 | rocket_chat_config = RocketchatConfig(**self._raw_config) 70 | self._rocket_chat = RocketChat(user_id=rocket_chat_config.token_id, auth_token=rocket_chat_config.token_secret, 71 | server_url=rocket_chat_config.url) 72 | self._authors_cache: Dict[str, RocketchatAuthor] = {} 73 | 74 | def _list_rooms(self) -> List[RocketchatRoom]: 75 | oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%f%z") 76 | r = self._rocket_chat.call_api_get("rooms.get", updatedSince=oldest) 77 | json = r.json() 78 | data = json.get("update") 79 | 80 | rooms = [] 81 | for r in data: 82 | room_id = r["_id"] 83 | if "fname" in r: 84 | name = r["fname"] 85 | elif "name" in r: 86 | name = r["name"] 87 | elif r["t"] == "d": 88 | my_uid = self._rocket_chat.me().json()["_id"] 89 | uid = next(filter(lambda u: u != my_uid, r["uids"]), None) 90 | if not uid: 91 | uid = my_uid 92 | user = self._get_author_details(uid) 93 | name = user.name 94 | else: 95 | raise Exception("Unknown name") 96 | room_type = r["t"] 97 | archived = r.get("archived", False) 98 | rooms.append(RocketchatRoom(id=room_id, name=name, type=room_type, archived=archived)) 99 | 100 | return rooms 101 | 102 | def _list_threads(self, channel: RocketchatRoom) -> List[RocketchatThread]: 103 | data = [] 104 | records = 0 105 | total = 1 # Set 1 to enter the loop 106 | while records < total: 107 | r = self._rocket_chat.call_api_get("chat.getThreadsList", rid=channel.id, count=20, offset=records) 108 | json = r.json() 109 | data += json.get("threads") 110 | records = len(data) 111 | total = json.get("total") 112 | return [RocketchatThread(id=trds["_id"], name=trds["msg"], channel_id=trds["rid"]) for trds in data] 113 | 114 | def _list_messages(self, channel: RocketchatRoom): 115 | oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%f%z") 116 | data = [] 117 | while oldest: 118 | r = self._rocket_chat.call_api_get("chat.syncMessages", roomId=channel.id, lastUpdate=oldest) 119 | json = r.json() 120 | messages = json["result"].get("updated") 121 | if messages: 122 | data += messages 123 | oldest = messages[0]["_updatedAt"] 124 | else: 125 | oldest = None 126 | return data 127 | 128 | def _list_thread_messages(self, thread: RocketchatThread): 129 | oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%f%z") 130 | data = [] 131 | records = 0 132 | total = 1 # Set 1 to enter the loop 133 | while records < total: 134 | r = self._rocket_chat.call_api_get("chat.getThreadMessages", tmid=thread.id, tlm=oldest, count=20, 135 | offset=records) 136 | json = r.json() 137 | messages = json.get("messages") 138 | if messages: 139 | data += messages 140 | records = len(data) 141 | total = json.get("total") 142 | return data 143 | 144 | def _get_author_details(self, author_id: str) -> RocketchatAuthor: 145 | author = self._authors_cache.get(author_id, None) 146 | if author is None: 147 | author_info = self._rocket_chat.users_info(author_id).json().get("user") 148 | author = RocketchatAuthor(name=author_info.get("name", author_info.get("username")), 149 | image_url=f"{self._raw_config.get('url')}/avatar/{author_info.get('username')}") 150 | self._authors_cache[author_id] = author 151 | 152 | return author 153 | 154 | def _feed_new_documents(self) -> None: 155 | for channel in self._list_rooms(): 156 | self.add_task_to_queue(self._feed_channel, channel=channel) 157 | 158 | def _feed_channel(self, channel): 159 | messages = self._list_messages(channel) 160 | threads = self._list_threads(channel) 161 | for 
thread in threads: 162 | messages += self._list_thread_messages(thread) 163 | 164 | logging.info(f"Got {len(messages)} messages from room {channel.name} ({channel.id})" 165 | f" with {len(threads)} threads") 166 | 167 | last_msg: Optional[BasicDocument] = None 168 | for message in messages: 169 | if "msg" not in message: 170 | if last_msg is not None: 171 | IndexQueue.get_instance().put_single(doc=last_msg) 172 | last_msg = None 173 | continue 174 | 175 | text = message["msg"] 176 | author_id = message["u"]["_id"] 177 | author = self._get_author_details(author_id) 178 | 179 | if last_msg is not None: 180 | if last_msg.author == author.name: 181 | last_msg.content += f"\n{text}" 182 | continue 183 | else: 184 | IndexQueue.get_instance().put_single(doc=last_msg) 185 | last_msg = None 186 | 187 | timestamp = message["ts"] 188 | message_id = message["_id"] 189 | readable_timestamp = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z") 190 | message_url = f"{self._raw_config.get('url')}/{channel.id}?msg={message_id}" 191 | last_msg = BasicDocument(title=channel.name, content=text, author=author.name, 192 | timestamp=readable_timestamp, id=message_id, 193 | data_source_id=self._data_source_id, location=channel.name, 194 | url=message_url, author_image_url=author.image_url, 195 | type=DocumentType.MESSAGE) 196 | 197 | if last_msg is not None: 198 | IndexQueue.get_instance().put_single(doc=last_msg) 199 | 200 | 201 | if __name__ == "__main__": 202 | import os 203 | conf = {"url": os.environ["ROCKETCHAT_URL"], "token_id": os.environ["ROCKETCHAT_TOKEN_ID"], "token_secret": os.environ["ROCKETCHAT_TOKEN_SECRET"]} 204 | rc = RocketchatDataSource(config=conf, data_source_id=0) 205 | rc._feed_new_documents() 206 | -------------------------------------------------------------------------------- /app/data_source/sources/slack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/data_source/sources/slack/__init__.py -------------------------------------------------------------------------------- /app/data_source/sources/slack/slack.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | from dataclasses import dataclass 5 | from http.client import IncompleteRead 6 | from typing import Optional, Dict, List 7 | 8 | from retry import retry 9 | from slack_sdk import WebClient 10 | from slack_sdk.errors import SlackApiError 11 | 12 | from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig 13 | from data_source.api.basic_document import DocumentType, BasicDocument 14 | from queues.index_queue import IndexQueue 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass 20 | class SlackConversation: 21 | id: str 22 | name: str 23 | 24 | 25 | @dataclass 26 | class SlackAuthor: 27 | name: str 28 | image_url: str 29 | 30 | 31 | class SlackConfig(BaseDataSourceConfig): 32 | token: str 33 | 34 | 35 | class SlackDataSource(BaseDataSource): 36 | FEED_BATCH_SIZE = 500 37 | 38 | @staticmethod 39 | def get_config_fields() -> List[ConfigField]: 40 | return [ 41 | ConfigField(label="Bot User OAuth Token", name="token", type=HTMLInputType.PASSWORD) 42 | ] 43 | 44 | @staticmethod 45 | async def validate_config(config: Dict) -> None: 46 | slack_config = SlackConfig(**config) 47 | slack = WebClient(token=slack_config.token) 48 | 
slack.auth_test() 49 | 50 | @staticmethod 51 | def _is_valid_message(message: Dict) -> bool: 52 | return 'client_msg_id' in message or 'bot_id' in message 53 | 54 | def __init__(self, *args, **kwargs): 55 | super().__init__(*args, **kwargs) 56 | slack_config = SlackConfig(**self._raw_config) 57 | self._slack = WebClient(token=slack_config.token) 58 | self._authors_cache: Dict[str, SlackAuthor] = {} 59 | 60 | def _list_conversations(self) -> List[SlackConversation]: 61 | conversations = self._slack.conversations_list(exclude_archived=True, limit=1000) 62 | return [SlackConversation(id=conv['id'], name=conv['name']) 63 | for conv in conversations['channels']] 64 | 65 | def _feed_conversations(self, conversations: List[SlackConversation]) -> List[SlackConversation]: 66 | joined_conversations = [] 67 | 68 | for conv in conversations: 69 | try: 70 | result = self._slack.conversations_join(channel=conv.id) 71 | if result['ok']: 72 | logger.info(f'Joined channel {conv.name}, adding a fetching task...') 73 | self.add_task_to_queue(self._feed_conversation, conv=conv) 74 | except Exception as e: 75 | logger.warning(f'Could not join channel {conv.name}: {e}') 76 | 77 | return joined_conversations 78 | 79 | def _get_author_details(self, author_id: str) -> SlackAuthor: 80 | author = self._authors_cache.get(author_id, None) 81 | if author is None: 82 | author_info = self._slack.users_info(user=author_id) 83 | user = author_info['user'] 84 | name = user.get('real_name') or user.get('name') or user.get('profile', {}).get('display_name') or 'Unknown' 85 | author = SlackAuthor(name=name, 86 | image_url=author_info['user']['profile']['image_72']) 87 | self._authors_cache[author_id] = author 88 | 89 | return author 90 | 91 | def _feed_new_documents(self) -> None: 92 | conversations = self._list_conversations() 93 | logger.info(f'Found {len(conversations)} conversations') 94 | 95 | self._feed_conversations(conversations) 96 | 97 | def _feed_conversation(self, conv: SlackConversation): 98 | logger.info(f'Feeding conversation {conv.name}') 99 | 100 | last_msg: Optional[BasicDocument] = None 101 | 102 | messages = self._fetch_conversation_messages(conv) 103 | for message in messages: 104 | if not self._is_valid_message(message): 105 | if last_msg is not None: 106 | IndexQueue.get_instance().put_single(doc=last_msg) 107 | last_msg = None 108 | continue 109 | 110 | text = message['text'] 111 | if author_id := message.get('user'): 112 | author = self._get_author_details(author_id) 113 | elif message.get('bot_id'): 114 | author = SlackAuthor(name=message.get('username'), image_url=message.get('icons', {}).get('image_48')) 115 | else: 116 | logger.warning(f'Unknown message author: {message}') 117 | continue 118 | 119 | if last_msg is not None: 120 | if last_msg.author == author.name: 121 | last_msg.content += f"\n{text}" 122 | continue 123 | else: 124 | IndexQueue.get_instance().put_single(doc=last_msg) 125 | last_msg = None 126 | 127 | timestamp = message['ts'] 128 | message_id = message.get('client_msg_id') or timestamp 129 | readable_timestamp = datetime.datetime.fromtimestamp(float(timestamp)) 130 | message_url = f"https://slack.com/app_redirect?channel={conv.id}&message_ts={timestamp}" 131 | last_msg = BasicDocument(title=author.name, content=text, author=author.name, 132 | timestamp=readable_timestamp, id=message_id, 133 | data_source_id=self._data_source_id, location=conv.name, 134 | url=message_url, author_image_url=author.image_url, 135 | type=DocumentType.MESSAGE) 136 | 137 | if last_msg is not None: 138 | 
IndexQueue.get_instance().put_single(doc=last_msg) 139 | 140 | @retry(tries=5, delay=1, backoff=2, logger=logger) 141 | def _get_conversation_history(self, conv: SlackConversation, cursor: str, last_index_unix: str): 142 | try: 143 | return self._slack.conversations_history(channel=conv.id, oldest=last_index_unix, 144 | limit=1000, cursor=cursor) 145 | except SlackApiError as e: 146 | logger.warning(f'SlackApi error while fetching messages for conversation {conv.name}: {e}') 147 | response = e.response 148 | if response['error'] == 'ratelimited': 149 | retry_after_seconds = int(response['headers']['Retry-After']) 150 | logger.warning(f'Rate-limited: Slack API rate limit exceeded,' 151 | f' retrying after {retry_after_seconds} seconds') 152 | time.sleep(retry_after_seconds) 153 | raise e 154 | except IncompleteRead as e: 155 | logger.warning(f'IncompleteRead error while fetching messages for conversation {conv.name}') 156 | raise e 157 | 158 | def _fetch_conversation_messages(self, conv: SlackConversation): 159 | messages = [] 160 | cursor = None 161 | has_more = True 162 | last_index_unix = self._last_index_time.timestamp() 163 | logger.info(f'Fetching messages for conversation {conv.name}') 164 | 165 | while has_more: 166 | try: 167 | response = self._get_conversation_history(conv=conv, cursor=cursor, 168 | last_index_unix=str(last_index_unix)) 169 | except Exception as e: 170 | logger.warning(f'Error fetching all messages for conversation {conv.name},' 171 | f' returning {len(messages)} messages. Error: {e}') 172 | return messages 173 | 174 | logger.info(f'Fetched {len(response["messages"])} messages for conversation {conv.name}') 175 | messages.extend(response['messages']) 176 | if has_more := response["has_more"]: 177 | cursor = response["response_metadata"]["next_cursor"] 178 | 179 | return messages 180 | -------------------------------------------------------------------------------- /app/db_engine.py: -------------------------------------------------------------------------------- 1 | from schemas import DataSourceType 2 | from schemas import DataSource 3 | from schemas import Document 4 | from schemas import Paragraph 5 | 6 | from sqlalchemy import create_engine 7 | from sqlalchemy.orm import sessionmaker 8 | # import base document and then register all classes 9 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession 10 | from schemas.base import Base 11 | 12 | from paths import SQLITE_DB_PATH 13 | 14 | db_url = f'sqlite:///{SQLITE_DB_PATH}' 15 | engine = create_engine(db_url) 16 | Base.metadata.create_all(engine) 17 | Session = sessionmaker(bind=engine) 18 | 19 | async_db_url = db_url.replace('sqlite', 'sqlite+aiosqlite', 1) 20 | async_engine = create_async_engine(async_db_url) 21 | async_session = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession) 22 | -------------------------------------------------------------------------------- /app/indexing/__init__.py: -------------------------------------------------------------------------------- 1 | from indexing.faiss_index import FaissIndex 2 | -------------------------------------------------------------------------------- /app/indexing/background_indexer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | from typing import List 4 | 5 | from queues.index_queue import IndexQueue 6 | from indexing.index_documents import Indexer 7 | 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | class BackgroundIndexer: 13 | _thread = 
None 14 | _stop_event = threading.Event() 15 | _currently_indexing_count = 0 16 | _total_indexed_count = 0 17 | 18 | @classmethod 19 | def get_currently_indexing(cls) -> int: 20 | return cls._currently_indexing_count 21 | 22 | @classmethod 23 | def get_indexed_count(cls) -> int: 24 | return cls._total_indexed_count 25 | 26 | @classmethod 27 | def reset_indexed_count(cls): 28 | cls._total_indexed_count = 0 29 | 30 | @classmethod 31 | def start(cls): 32 | cls._thread = threading.Thread(target=cls.run) 33 | cls._thread.start() 34 | 35 | @classmethod 36 | def stop(cls): 37 | cls._stop_event.set() 38 | logging.info('Stop event set, waiting for background indexer to stop...') 39 | 40 | cls._thread.join() 41 | logging.info('Background indexer stopped') 42 | 43 | cls._thread = None 44 | 45 | @staticmethod 46 | def run(): 47 | docs_queue_instance = IndexQueue.get_instance() 48 | logger.info(f'Background indexer started...') 49 | 50 | while not BackgroundIndexer._stop_event.is_set(): 51 | try: 52 | queue_items = docs_queue_instance.consume_all() 53 | if not queue_items: 54 | continue 55 | 56 | BackgroundIndexer._currently_indexing_count = len(queue_items) 57 | logger.info(f'Got chunk of {len(queue_items)} documents') 58 | 59 | docs = [doc.doc for doc in queue_items] 60 | Indexer.index_documents(docs) 61 | BackgroundIndexer._ack_chunk(docs_queue_instance, [doc.queue_item_id for doc in queue_items]) 62 | except Exception as e: 63 | logger.exception(e) 64 | logger.error('Error while indexing documents...') 65 | 66 | @staticmethod 67 | def _ack_chunk(queue: IndexQueue, ids: List[int]): 68 | logger.info(f'Finished indexing chunk of {len(ids)} documents') 69 | for item_id in ids: 70 | queue.ack(id=item_id) 71 | 72 | logger.info(f'Acked {len(ids)} documents.') 73 | BackgroundIndexer._total_indexed_count += len(ids) 74 | BackgroundIndexer._currently_indexing_count = 0 75 | -------------------------------------------------------------------------------- /app/indexing/bm25_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from typing import List 4 | 5 | import nltk 6 | import numpy as np 7 | from rank_bm25 import BM25Okapi 8 | 9 | from db_engine import Session 10 | from paths import BM25_INDEX_PATH 11 | from schemas import Paragraph 12 | 13 | 14 | def _add_metadata_for_indexing(paragraph: Paragraph) -> str: 15 | result = paragraph.content 16 | if paragraph.document.title is not None: 17 | result += ' ' + paragraph.document.title 18 | if paragraph.document.author is not None: 19 | result += ' ' + paragraph.document.author 20 | if data_source_name := paragraph.document.data_source.type.name: 21 | result += ' ' + data_source_name 22 | return result 23 | 24 | 25 | class Bm25Index: 26 | instance = None 27 | 28 | @staticmethod 29 | def create(): 30 | if Bm25Index.instance is not None: 31 | raise RuntimeError("Index is already initialized") 32 | 33 | if os.path.exists(BM25_INDEX_PATH): 34 | with open(BM25_INDEX_PATH, 'rb') as f: 35 | Bm25Index.instance = pickle.load(f) 36 | else: 37 | Bm25Index.instance = Bm25Index() 38 | 39 | @staticmethod 40 | def get() -> 'Bm25Index': 41 | if Bm25Index.instance is None: 42 | raise RuntimeError("Index is not initialized") 43 | return Bm25Index.instance 44 | 45 | def __init__(self) -> None: 46 | self.index = None 47 | self.id_map = [] 48 | 49 | def _update(self, session): 50 | all_paragraphs = session.query(Paragraph).all() 51 | if len(all_paragraphs) == 0: 52 | self.index = None 53 | self.id_map = [] 54 | 
return 55 | 56 | corpus = [nltk.word_tokenize(_add_metadata_for_indexing(paragraph)) for paragraph in all_paragraphs] 57 | id_map = [paragraph.id for paragraph in all_paragraphs] 58 | self.index = BM25Okapi(corpus) 59 | self.id_map = id_map 60 | 61 | def update(self, session=None): 62 | if session is None: 63 | with Session() as session: 64 | self._update(session) 65 | else: 66 | self._update(session) 67 | 68 | self._save() 69 | 70 | def search(self, query: str, top_k: int) -> List[int]: 71 | if self.index is None: 72 | return [] 73 | tokenized_query = nltk.word_tokenize(query) 74 | bm25_scores = self.index.get_scores(tokenized_query) 75 | top_k = min(top_k, len(bm25_scores)) 76 | top_n = np.argpartition(bm25_scores, -top_k)[-top_k:] 77 | bm25_hits = [{'id': self.id_map[idx], 'score': bm25_scores[idx]} for idx in top_n] 78 | bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True) 79 | return [hit['id'] for hit in bm25_hits] 80 | 81 | def clear(self): 82 | self.index = None 83 | self.id_map = [] 84 | self._save() 85 | 86 | def _save(self): 87 | with open(BM25_INDEX_PATH, 'wb') as f: 88 | pickle.dump(self, f) 89 | -------------------------------------------------------------------------------- /app/indexing/faiss_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import faiss 4 | 5 | from paths import FAISS_INDEX_PATH 6 | 7 | MODEL_DIM = 384 8 | 9 | 10 | class FaissIndex: 11 | instance = None 12 | 13 | @staticmethod 14 | def create(): 15 | if FaissIndex.instance is not None: 16 | raise RuntimeError("Index is already initialized") 17 | 18 | FaissIndex.instance = FaissIndex() 19 | 20 | @staticmethod 21 | def get() -> 'FaissIndex': 22 | if FaissIndex.instance is None: 23 | raise RuntimeError("Index is not initialized") 24 | return FaissIndex.instance 25 | 26 | def __init__(self) -> None: 27 | if os.path.exists(FAISS_INDEX_PATH): 28 | index = faiss.read_index(FAISS_INDEX_PATH) 29 | else: 30 | index = faiss.IndexFlatIP(MODEL_DIM) 31 | index = faiss.IndexIDMap(index) 32 | 33 | self.index: faiss.IndexIDMap = index 34 | 35 | def update(self, ids: torch.LongTensor, embeddings: torch.FloatTensor): 36 | self.index.add_with_ids(embeddings.cpu(), ids) 37 | 38 | faiss.write_index(self.index, FAISS_INDEX_PATH) 39 | 40 | def remove(self, ids: torch.LongTensor): 41 | self.index.remove_ids(torch.tensor(ids)) 42 | 43 | faiss.write_index(self.index, FAISS_INDEX_PATH) 44 | 45 | def search(self, queries: torch.FloatTensor, top_k: int, *args, **kwargs): 46 | if queries.ndim == 1: 47 | queries = queries.unsqueeze(0) 48 | _, ids = self.index.search(queries.cpu(), top_k, *args, **kwargs) 49 | return ids 50 | 51 | def clear(self): 52 | self.index.reset() 53 | faiss.write_index(self.index, FAISS_INDEX_PATH) 54 | -------------------------------------------------------------------------------- /app/indexing/index_documents.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from enum import Enum 4 | from typing import List, Optional 5 | 6 | from data_source.api.basic_document import BasicDocument, FileType 7 | from db_engine import Session 8 | from indexing.bm25_index import Bm25Index 9 | from indexing.faiss_index import FaissIndex 10 | from models import bi_encoder 11 | from parsers.pdf import split_PDF_into_paragraphs 12 | from paths import IS_IN_DOCKER 13 | from schemas import Document, Paragraph 14 | from langchain.schema import Document as PDFDocument 15 | 16 
| 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def get_enum_value_or_none(enum: Optional[Enum]) -> Optional[str]: 21 | if enum is None: 22 | return None 23 | 24 | return enum.value 25 | 26 | 27 | class Indexer: 28 | 29 | @staticmethod 30 | def basic_to_document(document: BasicDocument, parent: Document = None) -> Document: 31 | paragraphs = Indexer._split_into_paragraphs(document.content) 32 | 33 | return Document( 34 | data_source_id=document.data_source_id, 35 | id_in_data_source=document.id_in_data_source, 36 | type=document.type.value, 37 | file_type=get_enum_value_or_none(document.file_type), 38 | status=document.status, 39 | is_active=document.is_active, 40 | title=document.title, 41 | author=document.author, 42 | author_image_url=document.author_image_url, 43 | location=document.location, 44 | url=document.url, 45 | timestamp=document.timestamp, 46 | paragraphs=[ 47 | Paragraph(content=content) 48 | for content in paragraphs 49 | ], 50 | parent=parent 51 | ) 52 | 53 | @staticmethod 54 | def index_documents(documents: List[BasicDocument]): 55 | logger.info(f"Indexing {len(documents)} documents") 56 | 57 | ids_in_data_source = [document.id_in_data_source for document in documents] 58 | 59 | with Session() as session: 60 | documents_to_delete = session.query(Document).filter( 61 | Document.id_in_data_source.in_(ids_in_data_source)).all() 62 | if documents_to_delete: 63 | logging.info(f'removing documents that were updated and need to be re-indexed.') 64 | Indexer.remove_documents(documents_to_delete, session) 65 | for document in documents_to_delete: 66 | # Currently bulk deleting doesn't cascade. So we need to delete them one by one. 67 | # See https://stackoverflow.com/a/19245058/3541901 68 | session.delete(document) 69 | session.commit() 70 | 71 | with Session() as session: 72 | db_documents = [] 73 | for document in documents: 74 | # Split the content into paragraphs that fit inside the database 75 | paragraphs = Indexer._split_into_paragraphs(document.content) 76 | # Create a new document in the database 77 | db_document = Indexer.basic_to_document(document) 78 | children = [] 79 | if document.children: 80 | children = [Indexer.basic_to_document(child, db_document) for child in document.children] 81 | db_documents.append(db_document) 82 | db_documents.extend(children) 83 | 84 | # Save the documents to the database 85 | session.add_all(db_documents) 86 | session.commit() 87 | 88 | # Create a list of all the paragraphs in the documents 89 | logger.info(f"Indexing {len(db_documents)} documents => {len(paragraphs)} paragraphs") 90 | paragraphs = [paragraph for document in db_documents for paragraph in document.paragraphs] 91 | if len(paragraphs) == 0: 92 | logger.info(f"No paragraphs to index") 93 | return 94 | 95 | paragraph_ids = [paragraph.id for paragraph in paragraphs] 96 | paragraph_contents = [Indexer._add_metadata_for_indexing(paragraph) for paragraph in paragraphs] 97 | 98 | logger.info(f"Updating BM25 index...") 99 | Bm25Index.get().update() 100 | 101 | if len(paragraph_contents) == 0: 102 | return 103 | 104 | # Encode the paragraphs 105 | show_progress_bar = not IS_IN_DOCKER 106 | logger.info(f"Encoding with bi-encoder...") 107 | embeddings = bi_encoder.encode(paragraph_contents, convert_to_tensor=True, show_progress_bar=show_progress_bar) 108 | 109 | # Add the embeddings to the index 110 | logger.info(f"Updating Faiss index...") 111 | FaissIndex.get().update(paragraph_ids, embeddings) 112 | 113 | logger.info(f"Finished indexing {len(documents)} documents => 
{len(paragraphs)} paragraphs") 114 | 115 | @staticmethod 116 | def _split_into_paragraphs(text, minimum_length=256): 117 | """ 118 | split into paragraphs and batch small paragraphs together into the same paragraph 119 | """ 120 | if text is None: 121 | return [] 122 | paragraphs = [] 123 | current_paragraph = '' 124 | for paragraph in re.split(r'\n\s*\n', text): 125 | if len(current_paragraph) > 0: 126 | current_paragraph += ' ' 127 | current_paragraph += paragraph.strip() 128 | 129 | if len(current_paragraph) > minimum_length: 130 | paragraphs.append(current_paragraph) 131 | current_paragraph = '' 132 | 133 | if len(current_paragraph) > 0: 134 | paragraphs.append(current_paragraph) 135 | return paragraphs 136 | 137 | @staticmethod 138 | def _add_metadata_for_indexing(paragraph: Paragraph) -> str: 139 | result = paragraph.content 140 | if paragraph.document.title is not None: 141 | result += '; ' + paragraph.document.title 142 | return result 143 | 144 | @staticmethod 145 | def remove_documents(documents: List[Document], session=None): 146 | logger.info(f"Removing {len(documents)} documents") 147 | 148 | # Get the paragraphs from the documents 149 | db_paragraphs = [paragraph for document in documents for paragraph in document.paragraphs] 150 | 151 | # Remove the paragraphs from the index 152 | paragraph_ids = [paragraph.id for paragraph in db_paragraphs] 153 | 154 | logger.info(f"Removing documents from faiss index...") 155 | FaissIndex.get().remove(paragraph_ids) 156 | 157 | logger.info(f"Removing documents from BM25 index...") 158 | Bm25Index.get().update(session=session) 159 | 160 | logger.info(f"Finished removing {len(documents)} documents => {len(db_paragraphs)} paragraphs") 161 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import List 4 | import os 5 | import torch 6 | from fastapi import FastAPI, Request, HTTPException 7 | from fastapi.middleware.cors import CORSMiddleware 8 | from fastapi_restful.tasks import repeat_every 9 | from starlette.responses import Response, FileResponse 10 | 11 | from api.data_source import router as data_source_router 12 | from api.search import router as search_router 13 | from data_source.api.exception import KnownException 14 | from data_source.api.context import DataSourceContext 15 | from data_source.api.utils import get_utc_time_now 16 | from db_engine import Session 17 | from indexing.background_indexer import BackgroundIndexer 18 | from indexing.bm25_index import Bm25Index 19 | from indexing.faiss_index import FaissIndex 20 | from queues.index_queue import IndexQueue 21 | from paths import UI_PATH 22 | from queues.task_queue import TaskQueue 23 | from schemas import DataSource 24 | from schemas.document import Document 25 | from schemas.paragraph import Paragraph 26 | from workers import Workers 27 | from telemetry import Posthog 28 | 29 | logging.basicConfig(level=logging.INFO, 30 | format='%(asctime)s | %(levelname)s | %(filename)s:%(lineno)d | %(message)s') 31 | logger = logging.getLogger(__name__) 32 | logging.getLogger("urllib3").propagate = False 33 | 34 | app = FastAPI() 35 | 36 | 37 | async def catch_exceptions_middleware(request: Request, call_next): 38 | try: 39 | return await call_next(request) 40 | except KnownException as e: 41 | logger.exception("Known exception") 42 | return Response(e.message, status_code=501) 43 | except 
Exception: 44 | logger.exception("Server error") 45 | return Response("Oops. Server error...", status_code=500) 46 | 47 | app.middleware('http')(catch_exceptions_middleware) 48 | 49 | 50 | app.add_middleware( 51 | CORSMiddleware, 52 | allow_origins=["*"], 53 | allow_credentials=True, 54 | allow_methods=["*"], 55 | allow_headers=["*"], 56 | ) 57 | app.include_router(search_router, prefix="/api/v1") 58 | app.include_router(data_source_router, prefix="/api/v1") 59 | 60 | 61 | def _check_for_new_documents(force=False): 62 | with Session() as session: 63 | data_sources: List[DataSource] = session.query(DataSource).all() 64 | for data_source in data_sources: 65 | # data source should be checked once every hour 66 | if (get_utc_time_now() - data_source.last_indexed_at).total_seconds() <= 60 * 60 and not force: 67 | continue 68 | 69 | logger.info(f"Checking for new docs in {data_source.type.name} (id: {data_source.id})") 70 | data_source_instance = DataSourceContext.get_data_source_instance(data_source_id=data_source.id) 71 | data_source_instance._last_index_time = data_source.last_indexed_at 72 | data_source_instance.index(force=force) 73 | 74 | 75 | @app.on_event("startup") 76 | @repeat_every(seconds=60) 77 | def check_for_new_documents(): 78 | _check_for_new_documents(force=False) 79 | 80 | 81 | @app.on_event("startup") 82 | def send_startup_telemetry(): 83 | try: 84 | Posthog.send_startup_telemetry() 85 | except: 86 | pass 87 | 88 | 89 | @app.on_event("startup") 90 | @repeat_every(wait_first=60 * 60 * 24, seconds=60 * 60 * 24) 91 | def send_daily_telemetry(): 92 | try: 93 | Posthog.send_daily() 94 | except: 95 | pass 96 | 97 | 98 | @app.on_event("startup") 99 | async def startup_event(): 100 | if not torch.cuda.is_available(): 101 | logger.warning("CUDA is not available, using CPU. 
This will make indexing and search very slow!!!") 102 | FaissIndex.create() 103 | Bm25Index.create() 104 | DataSourceContext.init() 105 | BackgroundIndexer.start() 106 | Workers.start() 107 | 108 | 109 | @app.on_event("shutdown") 110 | async def shutdown_event(): 111 | Workers.stop() 112 | BackgroundIndexer.stop() 113 | 114 | 115 | @app.get("/api/v1/status") 116 | def status(): 117 | @dataclass 118 | class Status: 119 | docs_in_indexing: int 120 | docs_left_to_index: int 121 | docs_indexed: int 122 | 123 | return Status(docs_in_indexing=BackgroundIndexer.get_currently_indexing(), 124 | docs_left_to_index=IndexQueue.get_instance().qsize() + TaskQueue.get_instance().qsize(), 125 | docs_indexed=BackgroundIndexer.get_indexed_count()) 126 | 127 | 128 | @app.post("/clear-index") 129 | async def clear_index(): 130 | FaissIndex.get().clear() 131 | Bm25Index.get().clear() 132 | with Session() as session: 133 | session.query(Document).delete() 134 | session.query(Paragraph).delete() 135 | session.commit() 136 | 137 | 138 | @app.post("/check-for-new-documents") 139 | async def check_for_new_documents_endpoint(): 140 | _check_for_new_documents(force=True) 141 | 142 | 143 | @app.get("/{path:path}", include_in_schema=False) 144 | async def serve_ui(request: Request, path: str): 145 | try: 146 | if path == "" or path.startswith("search"): 147 | file_path = os.path.join(UI_PATH, "index.html") 148 | else: 149 | file_path = os.path.join(UI_PATH, path) 150 | 151 | return FileResponse(file_path, status_code=200) 152 | except Exception as e: 153 | logger.warning(f"Failed to serve UI (you probably need to build it): {e}") 154 | raise HTTPException(status_code=404, detail="File not found") 155 | 156 | 157 | if __name__ == '__main__': 158 | import uvicorn 159 | uvicorn.run("main:app", host="localhost", port=8000) 160 | -------------------------------------------------------------------------------- /app/models.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer, CrossEncoder 2 | from transformers import pipeline 3 | import torch 4 | 5 | 6 | bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') 7 | 8 | cross_encoder_small = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2') 9 | cross_encoder_large = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') 10 | 11 | qa_model = pipeline('question-answering', model='deepset/roberta-base-squad2') 12 | -------------------------------------------------------------------------------- /app/parsers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/parsers/__init__.py -------------------------------------------------------------------------------- /app/parsers/docx.py: -------------------------------------------------------------------------------- 1 | import mammoth 2 | 3 | 4 | def docx_to_html(input_filename: str) -> str: 5 | with open(input_filename, "rb") as docx_file: 6 | result = mammoth.convert_to_html(docx_file) 7 | return result.value 8 | -------------------------------------------------------------------------------- /app/parsers/html.py: -------------------------------------------------------------------------------- 1 | import re 2 | from bs4 import BeautifulSoup 3 | 4 | 5 | def html_to_text(html: str) -> str: 6 | # Becuase documents only contain text, we use a colon to separate subtitles from the text 7 | html = 
re.sub(r'(?=<\/h[1234567]>)', ': ', html) 8 | 9 | soup = BeautifulSoup(html, features='html.parser') 10 | plain_text = soup.get_text(separator="\n\n") 11 | 12 | # When there is a link immidiately followed by a symbol, BeautifulSoup adds whitespace between them. We remove it. 13 | plain_text = re.sub(r'\s+(?=[\.\?\!\:\,])', '', plain_text) 14 | return plain_text -------------------------------------------------------------------------------- /app/parsers/pdf.py: -------------------------------------------------------------------------------- 1 | from PyPDF2 import PdfReader 2 | from typing import List 3 | from langchain.document_loaders import PyPDFLoader 4 | from langchain.schema import Document 5 | from langchain.text_splitter import CharacterTextSplitter 6 | def pdf_to_text(input_filename: str) -> str: 7 | pdf_file = PdfReader(input_filename) 8 | text='' 9 | 10 | for page in pdf_file.pages: 11 | text = text + page.extract_text() 12 | 13 | return text 14 | 15 | 16 | def pdf_to_textV2(input_filename: str) -> str: 17 | loader = PyPDFLoader(input_filename) 18 | documents = loader.load() 19 | text_split = CharacterTextSplitter(chunk_size=256, chunk_overlap=0) 20 | texts = text_split.split_documents(documents) 21 | current_paragraph = '' 22 | for text in texts: 23 | paragraph = text.page_content 24 | if len(current_paragraph) > 0: 25 | current_paragraph += '\n\n' 26 | current_paragraph += paragraph.strip() 27 | 28 | return current_paragraph 29 | 30 | -------------------------------------------------------------------------------- /app/parsers/pptx.py: -------------------------------------------------------------------------------- 1 | from pptx import Presentation 2 | 3 | 4 | def pptx_to_text(input_filename: str, slides_seperator: str = "\n\n") -> str: 5 | presentation = Presentation(input_filename) 6 | presentation_text = "" 7 | 8 | for slide in presentation.slides: 9 | 10 | slide_has_title = slide.shapes.title is not None 11 | 12 | for shape in slide.shapes: 13 | if not hasattr(shape, "text"): 14 | continue 15 | 16 | shape_text = f'\n{shape.text}' 17 | 18 | if slide_has_title and shape.text == slide.shapes.title.text: 19 | shape_text += ":" 20 | 21 | presentation_text += shape_text 22 | 23 | presentation_text += slides_seperator 24 | 25 | return presentation_text 26 | -------------------------------------------------------------------------------- /app/parsers/txt.py: -------------------------------------------------------------------------------- 1 | def txt_to_string(input_filename: str) -> str: 2 | with open(input_filename, 'r', encoding="utf-8") as file: 3 | return file.read() 4 | -------------------------------------------------------------------------------- /app/paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | IS_IN_DOCKER = os.environ.get('DOCKER_DEPLOYMENT', False) 5 | 6 | if os.name == 'nt': 7 | STORAGE_PATH = Path(".gerev\\storage") 8 | else: 9 | STORAGE_PATH = Path('/opt/storage/') if IS_IN_DOCKER else Path(f'/home/{os.getlogin()}/.gerev/storage/') 10 | 11 | if not STORAGE_PATH.exists(): 12 | STORAGE_PATH.mkdir(parents=True) 13 | 14 | UI_PATH = Path('/ui/') if IS_IN_DOCKER else Path('../ui/build/') 15 | SQLITE_DB_PATH = STORAGE_PATH / 'db.sqlite3' 16 | SQLITE_TASKS_PATH = STORAGE_PATH / 'tasks.sqlite3' 17 | SQLITE_INDEXING_PATH = STORAGE_PATH / 'indexing.sqlite3' 18 | FAISS_INDEX_PATH = str(STORAGE_PATH / 'faiss_index.bin') 19 | BM25_INDEX_PATH = str(STORAGE_PATH / 'bm25_index.bin') 20 | 
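
Each parser module above handles a single format. The actual call sites live in the data source implementations (e.g. the Google Drive source), which are outside this section, so the extension-to-parser mapping below is a hypothetical helper for illustration only:

```python
# Hypothetical dispatch helper; the real wiring is in the data source code, not shown here.
from pathlib import Path

from parsers.docx import docx_to_html
from parsers.html import html_to_text
from parsers.pdf import pdf_to_text
from parsers.pptx import pptx_to_text
from parsers.txt import txt_to_string


def file_to_text(path: str) -> str:
    suffix = Path(path).suffix.lower()
    if suffix == ".docx":
        # docx is converted to HTML first, then flattened like any other HTML document
        return html_to_text(docx_to_html(path))
    if suffix in (".html", ".htm"):
        return html_to_text(txt_to_string(path))
    if suffix == ".pdf":
        return pdf_to_text(path)
    if suffix == ".pptx":
        return pptx_to_text(path)
    return txt_to_string(path)
```
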
UUID_PATH = str(STORAGE_PATH / '.uuid') 21 | -------------------------------------------------------------------------------- /app/queues/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/queues/__init__.py -------------------------------------------------------------------------------- /app/queues/index_queue.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | from persistqueue import SQLiteAckQueue 6 | 7 | from data_source.api.basic_document import BasicDocument 8 | from paths import SQLITE_INDEXING_PATH 9 | 10 | 11 | @dataclass 12 | class IndexQueueItem: 13 | queue_item_id: int 14 | doc: BasicDocument 15 | 16 | 17 | class IndexQueue(SQLiteAckQueue): 18 | _instance = None 19 | _lock = threading.Lock() 20 | 21 | @classmethod 22 | def get_instance(cls): 23 | with cls._lock: 24 | if cls._instance is None: 25 | cls._instance = cls() 26 | return cls._instance 27 | 28 | def __init__(self): 29 | if IndexQueue._instance is not None: 30 | raise RuntimeError("Queue is a singleton, use .get() to get the instance") 31 | 32 | self.condition = threading.Condition() 33 | super().__init__(path=SQLITE_INDEXING_PATH, multithreading=True, name="index") 34 | 35 | def put_single(self, doc: BasicDocument): 36 | self.put([doc]) 37 | 38 | def put(self, docs: List[BasicDocument]): 39 | with self.condition: 40 | for doc in docs: 41 | super().put(doc) 42 | 43 | self.condition.notify_all() 44 | 45 | def consume_all(self, max_docs=5000, timeout=1) -> List[IndexQueueItem]: 46 | with self.condition: 47 | self.condition.wait(timeout=timeout) 48 | 49 | queue_items = [] 50 | count = 0 51 | while not super().empty() and count < max_docs: 52 | raw_item = super().get(raw=True) 53 | queue_items.append(IndexQueueItem(queue_item_id=raw_item['pqid'], doc=raw_item['data'])) 54 | count += 1 55 | 56 | return queue_items 57 | -------------------------------------------------------------------------------- /app/queues/task_queue.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from persistqueue import SQLiteAckQueue, Empty 6 | 7 | from paths import SQLITE_TASKS_PATH 8 | 9 | 10 | @dataclass 11 | class Task: 12 | data_source_id: int 13 | function_name: str 14 | kwargs: dict 15 | attempts: int = 3 16 | 17 | 18 | @dataclass 19 | class TaskQueueItem: 20 | queue_item_id: int 21 | task: Task 22 | 23 | 24 | class TaskQueue(SQLiteAckQueue): 25 | _instance = None 26 | _lock = threading.Lock() 27 | 28 | @classmethod 29 | def get_instance(cls): 30 | with cls._lock: 31 | if cls._instance is None: 32 | cls._instance = cls() 33 | return cls._instance 34 | 35 | def __init__(self): 36 | if TaskQueue._instance is not None: 37 | raise RuntimeError("TaskQueue is a singleton, use .get() to get the instance") 38 | 39 | self.condition = threading.Condition() 40 | super().__init__(path=SQLITE_TASKS_PATH, multithreading=True, name="task") 41 | 42 | def add_task(self, task: Task): 43 | self.put(task) 44 | 45 | def get_task(self, timeout=1) -> Optional[TaskQueueItem]: 46 | try: 47 | raw_item = super().get(raw=True, block=True, timeout=timeout) 48 | return TaskQueueItem(queue_item_id=raw_item['pqid'], task=raw_item['data']) 49 | 50 | except Empty: 51 | return None 52 
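
Both queues are persistent, SQLite-backed singletons. A minimal sketch of the task-queue round trip, mirroring what `workers.py` (further down) does on the consumer side; the function name and kwargs are placeholders, not taken from the codebase:

```python
from queues.task_queue import TaskQueue, Task

queue = TaskQueue.get_instance()

# Producer side: a data source schedules work for itself.
queue.add_task(Task(data_source_id=1,
                    function_name="_feed_new_documents",   # hypothetical task name
                    kwargs={"path": "/"}))

# Consumer side (what Workers.run does): pull, execute, then ack or nack.
item = queue.get_task(timeout=1)
if item is not None:
    try:
        # ... look up the data source instance and run item.task.function_name ...
        queue.ack(id=item.queue_item_id)
    except Exception:
        queue.nack(id=item.queue_item_id)   # leaves the task for a retry
```
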
| -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-cpu 2 | transformers 3 | sentence_transformers 4 | sqlalchemy>=2.0.4 5 | fastapi 6 | uvicorn 7 | rank_bm25 8 | atlassian-python-api 9 | beautifulsoup4 10 | python-dotenv 11 | slack_sdk 12 | pydantic 13 | python-multipart 14 | posthog 15 | fastapi-restful 16 | google-api-python-client 17 | google-auth-httplib2 18 | google-auth-oauthlib 19 | oauth2client 20 | mammoth 21 | python-pptx 22 | alembic 23 | rocketchat-API 24 | mattermostdriver 25 | persistqueue 26 | retry 27 | PyPDF2 28 | pytz 29 | aiosqlite 30 | starlette 31 | torch 32 | langchain~=0.0.141 33 | nltk 34 | numpy 35 | requests 36 | python-dateutil 37 | httplib2 38 | pypdf 39 | pycryptodome -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from schemas.data_source_type import DataSourceType 2 | from schemas.data_source import DataSource 3 | from schemas.document import Document 4 | from schemas.paragraph import Paragraph 5 | -------------------------------------------------------------------------------- /app/schemas/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import DeclarativeBase 2 | 3 | 4 | class Base(DeclarativeBase): 5 | pass 6 | -------------------------------------------------------------------------------- /app/schemas/data_source.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | 4 | from schemas.base import Base 5 | from sqlalchemy import ForeignKey, Column, Integer, Connection 6 | from sqlalchemy.orm import Mapped, mapped_column, relationship 7 | from sqlalchemy import String, DateTime 8 | from sqlalchemy import event 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DataSource(Base): 15 | __tablename__ = "data_source" 16 | 17 | id: Mapped[int] = mapped_column(primary_key=True) 18 | type_id = Column(Integer, ForeignKey('data_source_type.id')) 19 | type = relationship("DataSourceType", back_populates="data_sources") 20 | config: Mapped[Optional[str]] = mapped_column(String(512)) 21 | last_indexed_at: Mapped[Optional[DateTime]] = mapped_column(DateTime()) 22 | created_at: Mapped[Optional[DateTime]] = mapped_column(DateTime()) 23 | documents = relationship("Document", back_populates="data_source", cascade='all, delete, delete-orphan') 24 | 25 | 26 | @event.listens_for(DataSource, 'before_delete') 27 | def receive_before_delete(mapper, connection: Connection, target): 28 | # import here to avoid circular imports 29 | from indexing.index_documents import Indexer 30 | from db_engine import Session 31 | 32 | with Session(bind=connection) as session: 33 | logger.info(f"Deleting documents for data source {target.id}...") 34 | Indexer.remove_documents(target.documents, session=session) 35 | -------------------------------------------------------------------------------- /app/schemas/data_source_type.py: -------------------------------------------------------------------------------- 1 | from schemas.base import Base 2 | from sqlalchemy import String 3 | from sqlalchemy.orm import Mapped, mapped_column, relationship 4 | 5 | 6 | class DataSourceType(Base): 7 | __tablename__ = 'data_source_type' 8 | 9 | id: Mapped[int] = 
mapped_column(primary_key=True) 10 | name: Mapped[str] = mapped_column(String(32)) 11 | display_name: Mapped[str] = mapped_column(String(32)) 12 | config_fields: Mapped[str] = mapped_column(String(1024)) 13 | data_sources = relationship("DataSource", back_populates="type", foreign_keys="DataSource.type_id") 14 | -------------------------------------------------------------------------------- /app/schemas/document.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from sqlalchemy import String, DateTime, ForeignKey, Column, Integer, Boolean 5 | from sqlalchemy.orm import Mapped, mapped_column, relationship, backref 6 | 7 | from schemas.base import Base 8 | 9 | 10 | @dataclass 11 | class Document(Base): 12 | __tablename__ = 'document' 13 | 14 | id: Mapped[int] = mapped_column(primary_key=True) 15 | id_in_data_source: Mapped[str] = mapped_column(String(64)) 16 | data_source_id = Column(Integer, ForeignKey('data_source.id')) 17 | data_source = relationship("DataSource", back_populates="documents") 18 | type: Mapped[Optional[str]] = mapped_column(String(32)) 19 | file_type: Mapped[Optional[str]] = mapped_column(String(32)) 20 | status: Mapped[Optional[str]] = mapped_column(String(32)) 21 | is_active: Mapped[Optional[bool]] = mapped_column(Boolean()) 22 | title: Mapped[Optional[str]] = mapped_column(String(128)) 23 | author: Mapped[Optional[str]] = mapped_column(String(64)) 24 | author_image_url: Mapped[Optional[str]] = mapped_column(String(512)) 25 | url: Mapped[Optional[str]] = mapped_column(String(512)) 26 | location: Mapped[Optional[str]] = mapped_column(String(512)) 27 | timestamp: Mapped[Optional[DateTime]] = mapped_column(DateTime()) 28 | paragraphs = relationship("Paragraph", back_populates="document", cascade='all, delete, delete-orphan', 29 | foreign_keys="Paragraph.document_id") 30 | 31 | parent_id = Column(Integer, ForeignKey('document.id')) 32 | children = relationship("Document", foreign_keys=[parent_id], backref=backref("parent", remote_side=[id]), 33 | cascade='all, delete, delete-orphan', single_parent=True) 34 | -------------------------------------------------------------------------------- /app/schemas/paragraph.py: -------------------------------------------------------------------------------- 1 | from schemas.base import Base 2 | from sqlalchemy import String, ForeignKey, Column, Integer 3 | from sqlalchemy.orm import Mapped, mapped_column, relationship 4 | 5 | 6 | class Paragraph(Base): 7 | __tablename__ = "paragraph" 8 | 9 | id: Mapped[int] = mapped_column(primary_key=True) 10 | content: Mapped[str] = mapped_column(String(2048)) 11 | 12 | document_id = Column(Integer, ForeignKey('document.id')) 13 | document = relationship("Document", back_populates="paragraphs") 14 | -------------------------------------------------------------------------------- /app/search_logic.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | import re 5 | import urllib.parse 6 | from concurrent.futures import ThreadPoolExecutor 7 | from dataclasses import dataclass 8 | from typing import List 9 | from typing import Optional 10 | 11 | import nltk 12 | import torch 13 | from sentence_transformers import CrossEncoder 14 | 15 | from data_source.api.basic_document import DocumentType, FileType, DocumentStatus 16 | from data_source.api.utils import get_confluence_user_image 17 | from db_engine import 
Session 18 | from indexing.bm25_index import Bm25Index 19 | from indexing.faiss_index import FaissIndex 20 | from models import bi_encoder, cross_encoder_small, cross_encoder_large, qa_model 21 | from schemas import Paragraph, Document 22 | from util import threaded_method 23 | 24 | BM_25_CANDIDATES = 100 if torch.cuda.is_available() else 20 25 | BI_ENCODER_CANDIDATES = 60 if torch.cuda.is_available() else 20 26 | SMALL_CROSS_ENCODER_CANDIDATES = 30 if torch.cuda.is_available() else 10 27 | 28 | nltk.download('punkt') 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | @dataclass 33 | class TextPart: 34 | content: str 35 | bold: bool 36 | 37 | 38 | @dataclass 39 | class SearchResult: 40 | type: DocumentType 41 | score: float 42 | content: List[TextPart] 43 | author: str 44 | title: str 45 | url: str 46 | location: str 47 | data_source: str 48 | time: datetime 49 | file_type: FileType 50 | status: str 51 | is_active: bool 52 | author_image_url: Optional[str] 53 | author_image_data: Optional[str] 54 | child: Optional['SearchResult'] = None 55 | 56 | 57 | @dataclass 58 | class Candidate: 59 | content: str 60 | score: float = 0.0 61 | document: Document = None 62 | answer_start: int = -1 63 | answer_end: int = -1 64 | parent: 'Candidate' = None 65 | 66 | def _text_anchor(self, url, text) -> str: 67 | if '#' not in url: 68 | url += '#' 69 | text = re.sub(r'\s+', ' ', text) 70 | text = text.strip() 71 | words = text.split() 72 | url += ':~:text=' 73 | if len(words) > 7: 74 | url += urllib.parse.quote(' '.join(words[:3])).replace('-', '%2D') 75 | url += ',' 76 | url += urllib.parse.quote(' '.join(words[-3:])).replace('-', '%2D') 77 | else: 78 | url += urllib.parse.quote(text).replace('-', '%2D') 79 | return url 80 | 81 | @threaded_method 82 | def to_search_result(self) -> SearchResult: 83 | parent_result = None 84 | 85 | if self.parent is not None: 86 | parent_result = self.parent.to_search_result() 87 | parent_result.score = max(parent_result.score, self.score) 88 | elif self.document.parent_id is not None: 89 | parent_result = Candidate(content="", score=self.score, document=self.document.parent).to_search_result() 90 | 91 | answer = TextPart(self.content[self.answer_start: self.answer_end], True) 92 | content = [answer] 93 | 94 | if self.answer_end < len(self.content) - 1: 95 | words = self.content[self.answer_end:].split() 96 | suffix = ' '.join(words[:20]) 97 | content.append(TextPart(suffix, False)) 98 | 99 | data_uri = None 100 | if self.document.data_source.type.name == 'confluence': 101 | config = json.loads(self.document.data_source.config) 102 | data_uri = get_confluence_user_image(self.document.author_image_url, config['token']) 103 | 104 | result = SearchResult(score=(self.score + 12) / 24 * 100, 105 | content=content, 106 | author=self.document.author, 107 | author_image_url=self.document.author_image_url, 108 | author_image_data=data_uri, 109 | title=self.document.title, 110 | url=self._text_anchor(self.document.url, answer.content), 111 | time=self.document.timestamp, 112 | location=self.document.location, 113 | data_source=self.document.data_source.type.name, 114 | type=self.document.type, 115 | file_type=self.document.file_type, 116 | status=self.document.status, 117 | is_active=self.document.is_active) 118 | 119 | if parent_result is not None: 120 | parent_result.child = result 121 | return parent_result 122 | else: 123 | return result 124 | 125 | 126 | def _cross_encode( 127 | cross_encoder: CrossEncoder, 128 | query: str, 129 | candidates: List[Candidate], 130 | top_k: 
int, 131 | use_answer: bool = False, 132 | use_titles: bool = False) -> List[Candidate]: 133 | if use_answer: 134 | contents = [candidate.content[candidate.answer_start:candidate.answer_end] for candidate in candidates] 135 | else: 136 | contents = [candidate.content for candidate in candidates] 137 | 138 | if use_titles: 139 | contents = [ 140 | content + ' [SEP] ' + candidate.document.title 141 | for content, candidate in zip(contents, candidates) 142 | ] 143 | 144 | scores = cross_encoder.predict([(query, content) for content in contents], show_progress_bar=False) 145 | for candidate, score in zip(candidates, scores): 146 | candidate.score = score.item() 147 | candidates.sort(key=lambda c: c.score, reverse=True) 148 | return candidates[:top_k] 149 | 150 | 151 | def _assign_answer_sentence(candidate: Candidate, answer: str): 152 | paragraph_sentences = re.split(r'([\.\!\?\:\-] |[\"“\(\)])', candidate.content) 153 | sentence = None 154 | for i, paragraph_sentence in enumerate(paragraph_sentences): 155 | if answer in paragraph_sentence: 156 | sentence = paragraph_sentence 157 | break 158 | else: 159 | sentence = answer 160 | start = candidate.content.find(sentence) 161 | end = start + len(sentence) 162 | candidate.answer_start = start 163 | candidate.answer_end = end 164 | 165 | 166 | def _find_answers_in_candidates(candidates: List[Candidate], query: str) -> List[Candidate]: 167 | contexts = [candidate.content for candidate in candidates] 168 | answers = qa_model(question=[query] * len(contexts), context=contexts) 169 | 170 | if type(answers) == dict: 171 | answers = [answers] 172 | 173 | for candidate, answer in zip(candidates, answers): 174 | _assign_answer_sentence(candidate, answer['answer']) 175 | 176 | return candidates 177 | 178 | 179 | def search_documents(query: str, top_k: int) -> List[SearchResult]: 180 | # Encode the query 181 | query_embedding = bi_encoder.encode(query, convert_to_tensor=True, show_progress_bar=False) 182 | 183 | # Search the index for 100 candidates 184 | index = FaissIndex.get() 185 | results = index.search(query_embedding, BI_ENCODER_CANDIDATES) 186 | results = results[0] 187 | results = [int(id) for id in results if id != -1] # filter out empty results 188 | 189 | results += Bm25Index.get().search(query, BM_25_CANDIDATES) 190 | # Get the paragraphs from the database 191 | with Session() as session: 192 | paragraphs = session.query(Paragraph).filter(Paragraph.id.in_(results)).all() 193 | if len(paragraphs) == 0: 194 | return [] 195 | candidates = [Candidate(content=paragraph.content, document=paragraph.document, score=0.0) 196 | for paragraph in paragraphs] 197 | 198 | # calculate small cross-encoder scores to leave just a few candidates 199 | logger.info(f'Found {len(candidates)} candidates, filtering...') 200 | candidates = _cross_encode(cross_encoder_small, query, candidates, BI_ENCODER_CANDIDATES, use_titles=True) 201 | # calculate large cross-encoder scores to leave just top_k candidates 202 | candidates = _cross_encode(cross_encoder_large, query, candidates, top_k, use_titles=True) 203 | candidates = _find_answers_in_candidates(candidates, query) 204 | candidates = _cross_encode(cross_encoder_large, query, candidates, top_k, use_answer=True, use_titles=True) 205 | 206 | logger.info(f'Parsing {len(candidates)} candidates to search results...') 207 | 208 | for possible_child in candidates: 209 | if possible_child.document.parent_id is not None: 210 | for possible_parent in candidates: 211 | if possible_parent.document.id == 
possible_child.document.parent_id: 212 | possible_child.parent = possible_parent 213 | candidates.remove(possible_parent) 214 | break 215 | 216 | with ThreadPoolExecutor(max_workers=10) as executor: 217 | result = list(executor.map(lambda c: c.to_search_result(), candidates)) 218 | result.sort(key=lambda r: r.score, reverse=True) 219 | return result 220 | -------------------------------------------------------------------------------- /app/static/data_source_icons/bookstack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/bookstack.png -------------------------------------------------------------------------------- /app/static/data_source_icons/confluence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/confluence.png -------------------------------------------------------------------------------- /app/static/data_source_icons/confluence_cloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/confluence_cloud.png -------------------------------------------------------------------------------- /app/static/data_source_icons/default_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/default_icon.png -------------------------------------------------------------------------------- /app/static/data_source_icons/gitlab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/gitlab.png -------------------------------------------------------------------------------- /app/static/data_source_icons/google_drive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/google_drive.png -------------------------------------------------------------------------------- /app/static/data_source_icons/jira.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/jira.png -------------------------------------------------------------------------------- /app/static/data_source_icons/jira_cloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/jira_cloud.png -------------------------------------------------------------------------------- /app/static/data_source_icons/mattermost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/mattermost.png -------------------------------------------------------------------------------- /app/static/data_source_icons/rocketchat.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/rocketchat.png -------------------------------------------------------------------------------- /app/static/data_source_icons/slack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/app/static/data_source_icons/slack.png -------------------------------------------------------------------------------- /app/telemetry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Optional 4 | from uuid import uuid4 5 | 6 | import posthog 7 | 8 | from paths import UUID_PATH 9 | 10 | logging.basicConfig(level=logging.INFO, 11 | format='%(asctime)s | %(levelname)s | %(filename)s:%(lineno)d | %(message)s') 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Posthog: 16 | API_KEY = "phc_unIQdP9MFUa5bQNIKy5ktoRCPWMPWgqTbRvZr4391Pm" 17 | HOST = 'https://eu.posthog.com' 18 | 19 | RUN_EVENT = "run" 20 | BACKEND_SEARCH_EVENT = "backend_search" 21 | BACKEND_ADDED = "backend_added" 22 | BACKEND_REMOVED = "backend_removed" 23 | BACKEND_LISTED_LOCATIONS = "backend_listed_locations" 24 | _identified_uuid: Optional[str] = None 25 | 26 | @classmethod 27 | def _read_uuid_file(cls) -> Optional[str]: 28 | if os.path.exists(UUID_PATH): 29 | with open(UUID_PATH, 'r') as f: 30 | existing_uuid = f.read().strip() 31 | return existing_uuid 32 | 33 | return [] 34 | 35 | @classmethod 36 | def _create_uuid_file(cls, user_uuid: str): 37 | with open(UUID_PATH, 'w') as f: 38 | f.write(user_uuid) 39 | 40 | @classmethod 41 | def _identify(cls): 42 | user_uuid = cls._read_uuid_file() 43 | if user_uuid is None: 44 | new_uuid = str(uuid4()) 45 | cls._create_uuid_file(new_uuid) 46 | user_uuid = new_uuid 47 | 48 | try: 49 | posthog.api_key = cls.API_KEY 50 | posthog.host = cls.HOST 51 | posthog.identify(user_uuid) 52 | cls._identified_uuid = user_uuid 53 | except Exception as e: 54 | pass 55 | 56 | @classmethod 57 | def _capture(cls, event: str, uuid=None, properties=None): 58 | if cls._identified_uuid is None: 59 | cls._identify() 60 | 61 | try: 62 | posthog.capture(uuid or cls._identified_uuid, event, properties) 63 | except Exception as e: 64 | pass 65 | 66 | @classmethod 67 | def send_daily(cls): 68 | cls._capture(cls.RUN_EVENT) 69 | 70 | @classmethod 71 | def send_startup_telemetry(cls): 72 | cls._capture(cls.RUN_EVENT) 73 | 74 | @classmethod 75 | def increase_search_count(cls, uuid: str): 76 | cls._capture(cls.BACKEND_SEARCH_EVENT, uuid=uuid) 77 | 78 | @classmethod 79 | def added_data_source(cls, uuid: str, name: str): 80 | cls._capture(cls.BACKEND_ADDED, uuid=uuid, properties={"name": name}) 81 | 82 | @classmethod 83 | def removed_data_source(cls, uuid: str, name: str): 84 | cls._capture(cls.BACKEND_REMOVED, uuid=uuid, properties={"name": name}) 85 | 86 | @classmethod 87 | def listed_locations(cls, uuid: str, name: str): 88 | cls._capture(cls.BACKEND_LISTED_LOCATIONS, uuid=uuid, properties={"name": name}) 89 | -------------------------------------------------------------------------------- /app/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | def threaded_method(func): 7 | # so we won't miss exceptions 8 
| def wrapper(*args, **kwargs): 9 | try: 10 | return func(*args, **kwargs) 11 | except Exception as e: 12 | logger.exception(e) 13 | raise e 14 | 15 | return wrapper 16 | -------------------------------------------------------------------------------- /app/workers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | 4 | from data_source.api.context import DataSourceContext 5 | from queues.task_queue import TaskQueue, TaskQueueItem, Task 6 | 7 | logger = logging.getLogger() 8 | 9 | 10 | class Workers: 11 | _threads = [] 12 | _stop_event = threading.Event() 13 | WORKER_AMOUNT = 20 14 | 15 | @classmethod 16 | def start(cls): 17 | for i in range(cls.WORKER_AMOUNT): 18 | cls._threads.append(threading.Thread(target=cls.run)) 19 | for thread in cls._threads: 20 | thread.start() 21 | 22 | @classmethod 23 | def stop(cls): 24 | cls._stop_event.set() 25 | logging.info('Stop event set, waiting for workers to stop...') 26 | 27 | for thread in cls._threads: 28 | thread.join() 29 | logging.info('Workers stopped') 30 | 31 | cls._thread = None 32 | 33 | @staticmethod 34 | def run(): 35 | task_queue = TaskQueue.get_instance() 36 | logger.info(f'Worker started...') 37 | 38 | while not Workers._stop_event.is_set(): 39 | task_item: TaskQueueItem = task_queue.get_task() 40 | if not task_item: 41 | continue 42 | 43 | task_data: Task = task_item.task 44 | try: 45 | data_source = DataSourceContext.get_data_source_instance(task_data.data_source_id) 46 | data_source.run_task(task_data.function_name, **task_data.kwargs) 47 | task_queue.ack(id=task_item.queue_item_id) 48 | except Exception: 49 | logger.exception(f'Failed to ack task {task_data.function_name} ' 50 | f'for data source {task_data.data_source_id}, decrementing remaining attempts') 51 | try: 52 | task_data.attempts -= 1 53 | 54 | if task_data.attempts == 0: 55 | logger.error(f'max attempts reached, dropping') 56 | task_queue.ack_failed(id=task_item.queue_item_id) 57 | else: 58 | task_queue.update(id=task_item.queue_item_id, item=task_data) 59 | task_queue.nack(id=task_item.queue_item_id) 60 | except Exception: 61 | logger.exception('Error while handling task that failed...') 62 | task_queue.ack_failed(id=task_item.queue_item_id) 63 | -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | VERSION=0.0.4 3 | 4 | cd ui || exit 1 5 | 6 | npm install 7 | npm run build 8 | 9 | cd .. 10 | 11 | mkdir -p $HOME/.gerev/.buildx-cache 12 | 13 | sudo docker buildx create --use 14 | sudo docker buildx build --platform linux/amd64,linux/arm64 \ 15 | --cache-from type=local,src=$HOME/.gerev/.buildx-cache \ 16 | --cache-to type=local,dest=$HOME/.gerev/.buildx-cache \ 17 | -t gerev/gerev:$VERSION . \ 18 | -t gerev/gerev:latest --push 19 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | gerev: 3 | image: gerev:latest 4 | ports: 5 | - 80:80 6 | volumes: 7 | - ~/.gerev/storage:/opt/storage 8 | build: . 
9 | deploy: 10 | resources: 11 | reservations: 12 | devices: 13 | - driver: nvidia 14 | count: 1 15 | capabilities: [gpu] -------------------------------------------------------------------------------- /docs/data-sources/confluence/confluence-settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/confluence/confluence-settings.png -------------------------------------------------------------------------------- /docs/data-sources/confluence/confluence.md: -------------------------------------------------------------------------------- 1 | # Setting up Confluence data source 2 | 3 | Please note that all pages you have read access to, in all spaces, will be indexed. 4 | 5 | 1. Click on your profile picture and go to **Settings** 6 | 7 | ![Settings](./settings.png) 8 | 9 | 2. Go to **Personal Access Tokens** 10 | 11 | ![Personal Access Tokens](./personal-access-tokens.png) 12 | 13 | 3. On the right, click **Create Token** 14 | 15 | ![Create token](./create-token.png) 16 | 17 | 4. Give the token a **name** and **uncheck automatic expiry** 18 | 19 | ![Create token screen](./create-token-screen.png) 20 | 21 | 5. **Copy the generated token** into the Gerev settings page, together with the **Confluence URL** 22 | 23 | ![Confluence settings page](./confluence-settings.png) 24 | -------------------------------------------------------------------------------- /docs/data-sources/confluence/create-token-screen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/confluence/create-token-screen.png -------------------------------------------------------------------------------- /docs/data-sources/confluence/create-token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/confluence/create-token.png -------------------------------------------------------------------------------- /docs/data-sources/confluence/personal-access-tokens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/confluence/personal-access-tokens.png -------------------------------------------------------------------------------- /docs/data-sources/confluence/settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/confluence/settings.png -------------------------------------------------------------------------------- /docs/data-sources/gitlab/access_tokens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/gitlab/access_tokens.png -------------------------------------------------------------------------------- /docs/data-sources/gitlab/gitlab.md: -------------------------------------------------------------------------------- 1 | # Setting up Gitlab data source 2 | 3 | Please note that all projects you have member access to, will be indexed. 4 | 5 | 1. 
Click on your profile picture and go to **Preferences** 6 | 7 | 2. Go to **Access Tokens** 8 | 9 | ![Access Tokens](./access_tokens.png) 10 | 11 | 3. Add the token name, remove the expiration date and check the read api checkbox 12 | 13 | ![token](./read_api.png) 14 | 15 | 5. **Copy the generated token** into the Gerev settings page -------------------------------------------------------------------------------- /docs/data-sources/gitlab/read_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/gitlab/read_api.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/copy-email.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/copy-email.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/create-key-json.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/create-key-json.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/create-key.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/create-key.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/create-project.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/create-project.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/create-service-account.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/create-service-account.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/gerev-settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/gerev-settings.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/google-drive-share.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/google-drive-share.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/google-drive.md: -------------------------------------------------------------------------------- 1 | # Google drive 2 | 3 | ## Setting up 4 | 5 | The process involves creating a service account with an email address, then sharing drive folders with that service account. 6 | 7 | 1. 
Go to [Create a project](https://console.cloud.google.com/projectcreate?previousPage=%2Fprojectselector2%2Fiam-admin%2Fserviceaccounts%2Fcreate) on google cloud console. It's important to choose your **organization** if you have one. 8 | 9 | ![Create a project](./create-project.png) 10 | 11 | 2. Enable [Google Drive API](https://console.cloud.google.com/apis/library/drive.googleapis.com) for the project 12 | 13 | 14 | 15 | 3. Go to [Create a service account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account&_ga=2.83190934.1008578598.1678803320-418131787.1678643755#step_index=1) and select your project. 16 | 17 | 4. Select a name for the account. The rest of the fields don't matter, you can press done at the bottom. 18 | 19 | ![Create service account](./create-service-account.png) 20 | 21 | 5. hover over the newly created service account and copy the email address and save it. 22 | 23 | ![Copy email address](./copy-email.png) 24 | 25 | 6. click the newly created service account, then go to the KEYS tab 26 | 27 | ![KEYS tab](./keys-tab.png) 28 | 29 | 7. click ADD KEY -> Create new key and Select JSON. It should download a json file. 30 | 31 | ![Create a key](./create-key.png) 32 | 33 | ![Select JSON](./create-key-json.png) 34 | 35 | 8. Go to [Google Drive](https://drive.google.com/) and share the google drive folders you want to index with the service account email 36 | 37 | ![Share drive folder](./google-drive-share.png) 38 | 39 | ![Share drive folder with the service account](./share-drive-folder.png) 40 | 41 | 9. copy the contents of the json file into the box in the Gerev settings panel 42 | 43 | ![Gerev settings page](./gerev-settings.png) 44 | -------------------------------------------------------------------------------- /docs/data-sources/google-drive/keys-tab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/keys-tab.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/loading-screen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/loading-screen.png -------------------------------------------------------------------------------- /docs/data-sources/google-drive/share-drive-folder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/docs/data-sources/google-drive/share-drive-folder.png -------------------------------------------------------------------------------- /images/CodeCard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/CodeCard.png -------------------------------------------------------------------------------- /images/Everything.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/Everything.png -------------------------------------------------------------------------------- /images/api.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/api.gif -------------------------------------------------------------------------------- /images/bill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/bill.png -------------------------------------------------------------------------------- /images/contact-card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/contact-card.png -------------------------------------------------------------------------------- /images/everything.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/everything.png -------------------------------------------------------------------------------- /images/integ.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/integ.jpeg -------------------------------------------------------------------------------- /images/product-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/product-example.png -------------------------------------------------------------------------------- /images/sql-card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/images/sql-card.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # db migrations 2 | alembic upgrade head 3 | 4 | # run server 5 | uvicorn main:app --host 0.0.0.0 --port 80 -------------------------------------------------------------------------------- /ui/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /ui/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with Create React App 2 | 3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app). 4 | 5 | ## Available Scripts 6 | 7 | In the project directory, you can run: 8 | 9 | ### `npm start` 10 | 11 | Runs the app in the development mode.\ 12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 13 | 14 | The page will reload when you make changes.\ 15 | You may also see any lint errors in the console. 
16 | 17 | ### `npm test` 18 | 19 | Launches the test runner in the interactive watch mode.\ 20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 21 | 22 | ### `npm run build` 23 | 24 | Builds the app for production to the `build` folder.\ 25 | It correctly bundles React in production mode and optimizes the build for the best performance. 26 | 27 | The build is minified and the filenames include the hashes.\ 28 | Your app is ready to be deployed! 29 | 30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information. 31 | 32 | ### `npm run eject` 33 | 34 | **Note: this is a one-way operation. Once you `eject`, you can't go back!** 35 | 36 | If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project. 37 | 38 | Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own. 39 | 40 | You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it. 41 | 42 | ## Learn More 43 | 44 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started). 45 | 46 | To learn React, check out the [React documentation](https://reactjs.org/). 
47 | 48 | ### Code Splitting 49 | 50 | This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting) 51 | 52 | ### Analyzing the Bundle Size 53 | 54 | This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size) 55 | 56 | ### Making a Progressive Web App 57 | 58 | This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app) 59 | 60 | ### Advanced Configuration 61 | 62 | This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration) 63 | 64 | ### Deployment 65 | 66 | This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment) 67 | 68 | ### `npm run build` fails to minify 69 | 70 | This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify) 71 | -------------------------------------------------------------------------------- /ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ui", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@ramonak/react-progress-bar": "^5.0.3", 7 | "@testing-library/jest-dom": "^5.16.5", 8 | "@testing-library/react": "^13.4.0", 9 | "@testing-library/user-event": "^13.5.0", 10 | "@types/react": "^18.0.28", 11 | "@types/react-dom": "^18.0.11", 12 | "axios": "^1.3.4", 13 | "copy-to-clipboard": "^3.3.3", 14 | "posthog-js": "^1.51.2", 15 | "react": "^18.2.0", 16 | "react-confirm-alert": "^3.0.6", 17 | "react-dom": "^18.2.0", 18 | "react-icons": "^4.8.0", 19 | "react-image": "^4.1.0", 20 | "react-modal": "^3.16.1", 21 | "react-router": "^6.10.0", 22 | "react-router-dom": "^5.2.0", 23 | "react-scripts": "5.0.1", 24 | "react-select": "^5.7.0", 25 | "react-spinners": "^0.13.8", 26 | "react-toastify": "^9.1.1", 27 | "react-tooltip": "^5.10.1", 28 | "ts-loader": "^9.4.2", 29 | "typescript": "^4.9.5", 30 | "uuid": "^9.0.0", 31 | "web-vitals": "^2.1.4", 32 | "webpack": "^5.75.0", 33 | "webpack-cli": "^5.0.1" 34 | }, 35 | "scripts": { 36 | "start": "react-scripts start", 37 | "build": "react-scripts build", 38 | "test": "react-scripts test", 39 | "eject": "react-scripts eject", 40 | "magic": "webpack" 41 | }, 42 | "eslintConfig": { 43 | "extends": [ 44 | "react-app", 45 | "react-app/jest" 46 | ] 47 | }, 48 | "browserslist": { 49 | "production": [ 50 | ">0.2%", 51 | "not dead", 52 | "not op_mini all" 53 | ], 54 | "development": [ 55 | "last 1 chrome version", 56 | "last 1 firefox version", 57 | "last 1 safari version" 58 | ] 59 | }, 60 | "devDependencies": { 61 | "@types/node": "^18.15.0", 62 | "tailwindcss": "^3.2.7" 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /ui/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/public/favicon.ico 
-------------------------------------------------------------------------------- /ui/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 17 | 18 | 27 | gerev.ai 28 | 29 | 30 | 31 |
32 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /ui/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/public/logo192.png -------------------------------------------------------------------------------- /ui/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/public/logo512.png -------------------------------------------------------------------------------- /ui/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /ui/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /ui/src/api.ts: -------------------------------------------------------------------------------- 1 | import axios from 'axios'; 2 | 3 | let port = (!process.env.NODE_ENV || process.env.NODE_ENV === 'development') ? 
8000 : window.location.port; 4 | export const api = axios.create({ 5 | baseURL: `${window.location.protocol}//${window.location.hostname}:${port}/api/v1`, 6 | }) 7 | -------------------------------------------------------------------------------- /ui/src/assets/css/App.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/css/App.css -------------------------------------------------------------------------------- /ui/src/assets/css/custom-fonts.css: -------------------------------------------------------------------------------- 1 | @font-face { 2 | font-family: 'Poppins'; 3 | font-weight: 500; 4 | font-style: normal; 5 | font-display: fallback; 6 | src: url('../fonts/Poppins-Medium.ttf') format('truetype'); 7 | } 8 | 9 | @font-face { 10 | font-family: 'Poppins'; 11 | font-weight: 400; 12 | font-style: normal; 13 | font-display: fallback; 14 | src: url('../fonts/Poppins-Regular.ttf') format('truetype'); 15 | } 16 | 17 | @font-face { 18 | font-family: 'Source Sans Pro'; 19 | font-weight: 600; 20 | font-style: normal; 21 | font-display: fallback; 22 | src: url('../fonts/SourceSansPro-SemiBold.ttf') format('truetype'); 23 | } 24 | 25 | @font-face { 26 | font-family: 'Source Sans Pro'; 27 | font-weight: 900; 28 | font-style: normal; 29 | font-display: fallback; 30 | src: url('../fonts/SourceSansPro-Black.ttf') format('truetype'); 31 | } 32 | 33 | @font-face { 34 | font-family: 'Inter'; 35 | font-weight: 600; 36 | font-style: normal; 37 | font-display: fallback; 38 | src: url('../fonts/Inter-SemiBold.ttf') format('truetype'); 39 | } -------------------------------------------------------------------------------- /ui/src/assets/css/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | 6 | @import 'custom-fonts.css'; 7 | 8 | 9 | body { 10 | background-color: #221f2e; 11 | margin: 0; 12 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 13 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 14 | sans-serif; 15 | -webkit-font-smoothing: antialiased; 16 | -moz-osx-font-smoothing: grayscale; 17 | } -------------------------------------------------------------------------------- /ui/src/assets/fonts/Inter-SemiBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/Inter-SemiBold.ttf -------------------------------------------------------------------------------- /ui/src/assets/fonts/Poppins-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/Poppins-Medium.ttf -------------------------------------------------------------------------------- /ui/src/assets/fonts/Poppins-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/Poppins-Regular.ttf -------------------------------------------------------------------------------- /ui/src/assets/fonts/SourceSansPro-Black.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/SourceSansPro-Black.ttf -------------------------------------------------------------------------------- /ui/src/assets/fonts/SourceSansPro-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/SourceSansPro-Bold.ttf -------------------------------------------------------------------------------- /ui/src/assets/fonts/SourceSansPro-SemiBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/fonts/SourceSansPro-SemiBold.ttf -------------------------------------------------------------------------------- /ui/src/assets/images/blue-folder.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /ui/src/assets/images/bookstack.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /ui/src/assets/images/calendar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ui/src/assets/images/copy-this.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/images/copy-this.png -------------------------------------------------------------------------------- /ui/src/assets/images/discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/images/discord.png -------------------------------------------------------------------------------- /ui/src/assets/images/docx.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ui/src/assets/images/enter.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ui/src/assets/images/gitlab.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui/src/assets/images/google-doc.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /ui/src/assets/images/google-drive.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /ui/src/assets/images/left-pane-instructions.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/images/left-pane-instructions.png -------------------------------------------------------------------------------- /ui/src/assets/images/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui/src/assets/images/pdf.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui/src/assets/images/pptx.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /ui/src/assets/images/profile-picture-default.svg: -------------------------------------------------------------------------------- 1 | 4 | 6 | -------------------------------------------------------------------------------- /ui/src/assets/images/pur-dir.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /ui/src/assets/images/usa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/images/usa.png -------------------------------------------------------------------------------- /ui/src/assets/images/user.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GerevAI/gerev/1018e122aae288ca9c77d72b0efbe9ff7a6df750/ui/src/assets/images/user.webp -------------------------------------------------------------------------------- /ui/src/assets/images/warning.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ui/src/autocomplete.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Current implementation is based on last 100 queries saved in local storage. 
3 | */ 4 | 5 | function getSearchHistory() { 6 | const history = localStorage.getItem("searchHistory"); 7 | if (history) { 8 | return JSON.parse(history); 9 | } 10 | return []; 11 | } 12 | 13 | function saveSearchHistory(history: string[]) { 14 | localStorage.setItem("searchHistory", JSON.stringify(history)); 15 | } 16 | 17 | export async function addToSearchHistory(search: string) { 18 | const history = getSearchHistory(); 19 | if (history.length > 0 && history[0] === search) { 20 | return; 21 | } 22 | 23 | const index = history.indexOf(search); 24 | if (index > -1) { 25 | history.splice(index, 1); 26 | } 27 | 28 | history.unshift(search); 29 | if (history.length > 100) { 30 | history.pop(); 31 | } 32 | 33 | saveSearchHistory(history); 34 | } 35 | 36 | export function getSearchHistorySuggestions(search: string) { 37 | const history = getSearchHistory(); 38 | return history 39 | .filter((item) => item.startsWith(search)) 40 | .map((item) => item.slice(search.length)) 41 | .slice(0, 6); 42 | } -------------------------------------------------------------------------------- /ui/src/components/search-bar.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | 3 | import ClipLoader from "react-spinners/ClipLoader"; 4 | import { BsSearch, BsXLg } from "react-icons/bs"; 5 | import { getSearchHistorySuggestions } from "../autocomplete"; 6 | 7 | export interface SearchBarState { 8 | suggestions: string[] 9 | hideSuggestions: boolean 10 | activeSuggestion: number 11 | } 12 | 13 | export interface SearchBarProps { 14 | query: string 15 | widthPercentage: number 16 | isLoading: boolean 17 | isDisabled: boolean 18 | showReset: boolean 19 | showSuggestions: boolean 20 | onSearch: () => void 21 | onQueryChange: (query: string) => void 22 | onClear: () => void 23 | } 24 | 25 | export default class SearchBar extends React.Component { 26 | 27 | constructor(props) { 28 | super(props); 29 | this.state = { 30 | activeSuggestion: 0, 31 | hideSuggestions: false, 32 | suggestions: [] 33 | } 34 | } 35 | 36 | getBorderGradient() { 37 | if (this.props.isDisabled) { 38 | return 'from-[#222222] via-[#333333] to-[#222222]'; 39 | } 40 | return 'from-[#8E59D1] via-[#85a6ec] to-[#b385ec]'; 41 | } 42 | 43 | render() { 44 | return ( 45 |
46 |
47 | 57 | 58 | 60 | { 61 | this.props.showReset && 62 | 63 | } 64 |
65 | { 66 | !this.state.hideSuggestions && this.props.showSuggestions && this.props.query.length > 2 && 67 |
68 | {this.state.suggestions.map((suggestion, index) => { 69 | return ( 70 |
this.onSuggestionClick(suggestion)} key={index} 71 | className={'text-[#C9C9C9] font-poppins font-medium text-lg mt-2 hover:bg-[#37383a] p-1 py-2' 72 | + (this.state.activeSuggestion === index ? ' bg-[#37383a] border-l-[#8E59D1] rounded-l-sm border-l-[3px]' : 73 | '')}> 74 | {this.props.query} 75 | {suggestion} 76 | 77 |
78 | 79 | ) 80 | })} 81 | {this.state.suggestions.length > 1 && 82 |
Use arrows ↑ ↓ to navigate
} 83 |
84 | } 85 |
86 | ); 87 | } 88 | 89 | search = () => { 90 | this.setState({ hideSuggestions: true, activeSuggestion: 0 }); 91 | this.props.onSearch(); 92 | } 93 | 94 | onKeyDown = async (event: React.KeyboardEvent) => { 95 | if (event.key === 'Enter') { 96 | if (this.state.activeSuggestion > 0) { 97 | await this.props.onQueryChange(this.props.query + this.state.suggestions[this.state.activeSuggestion]); 98 | } 99 | this.search() 100 | } else if (event.key === 'Escape') { 101 | this.setState({ hideSuggestions: true, activeSuggestion: 0 }); 102 | } else if (event.key === 'ArrowDown') { 103 | event.preventDefault(); 104 | if (this.state.activeSuggestion < this.state.suggestions.length - 1) { 105 | this.setState({ activeSuggestion: this.state.activeSuggestion + 1 }); 106 | } 107 | } else if (event.key === 'ArrowUp') { 108 | event.preventDefault(); 109 | if (this.state.activeSuggestion > 0) { 110 | this.setState({ activeSuggestion: this.state.activeSuggestion - 1 }); 111 | } 112 | } 113 | } 114 | 115 | onSuggestionClick = (suggestion: string) => { 116 | this.props.onQueryChange(this.props.query + suggestion); 117 | } 118 | 119 | handleChange = (event: React.ChangeEvent) => { 120 | let query = event.target.value 121 | this.props.onQueryChange(query); 122 | this.setState({ hideSuggestions: false }); 123 | this.setState({ suggestions: ['', ...getSearchHistorySuggestions(query)] }) 124 | } 125 | 126 | } 127 | -------------------------------------------------------------------------------- /ui/src/custom.d.ts: -------------------------------------------------------------------------------- 1 | declare module "*.svg" { 2 | const content: React.FunctionComponent>; 3 | export default content; 4 | } 5 | 6 | -------------------------------------------------------------------------------- /ui/src/data-source.ts: -------------------------------------------------------------------------------- 1 | export enum HTMLInputType { 2 | TEXT = "text", 3 | TEXTAREA = "textarea", 4 | PASSWORD = "password" 5 | } 6 | 7 | export interface ConfigField { 8 | name: string 9 | input_type: HTMLInputType 10 | label: string 11 | placeholder: string 12 | value?: string 13 | } 14 | 15 | 16 | export interface DataSourceType { 17 | name: string 18 | display_name: string 19 | config_fields: ConfigField[] 20 | image_base64: string 21 | has_prerequisites: boolean 22 | } 23 | 24 | export interface ConnectedDataSource { 25 | id: number 26 | name: string 27 | } 28 | 29 | export interface IndexLocation { 30 | value: string 31 | label: string 32 | } -------------------------------------------------------------------------------- /ui/src/index.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import * as ReactDOM from "react-dom"; 3 | import './assets/css/index.css'; 4 | import App from './App'; 5 | import { PostHogProvider} from 'posthog-js/react' 6 | import posthog from 'posthog-js'; 7 | import 'react-tooltip/dist/react-tooltip.css' 8 | 9 | posthog.init( 10 | "phc_unIQdP9MFUa5bQNIKy5ktoRCPWMPWgqTbRvZr4391Pm", 11 | { 12 | api_host: "https://eu.posthog.com", 13 | disable_session_recording: true, 14 | autocapture: false, 15 | enable_recording_console_log: false, 16 | } 17 | ); 18 | 19 | ReactDOM.render( 20 | 21 | 22 | , 23 | document.getElementById("root") 24 | ); 25 | -------------------------------------------------------------------------------- /ui/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const 
reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /ui/src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom'; 6 | -------------------------------------------------------------------------------- /ui/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: [ 4 | "./src/**/*.{js,jsx,ts,tsx}", 5 | ], 6 | theme: { 7 | extend: { 8 | fontFamily: { 9 | 'poppins': ['Poppins'], 10 | 'source-sans-pro': ['Source Sans Pro'], 11 | inter: ['Inter', 'sans-serif'], 12 | } 13 | }, 14 | }, 15 | plugins: [], 16 | } 17 | -------------------------------------------------------------------------------- /ui/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "./dist/", 4 | "noImplicitAny": true, 5 | "module": "es6", 6 | "target": "es5", 7 | "jsx": "react", 8 | "allowJs": true, 9 | "moduleResolution": "node", 10 | }, 11 | "include": ["src/custom.d.ts"] 12 | 13 | } -------------------------------------------------------------------------------- /ui/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | module.exports = { 3 | entry: './src/index.tsx', 4 | module: { 5 | rules: [ 6 | { 7 | test: /\.tsx?$/, 8 | use: 'ts-loader', 9 | exclude: /node_modules/, 10 | }, 11 | ], 12 | }, 13 | resolve: { 14 | extensions: ['.tsx', '.ts', '.js'], 15 | }, 16 | output: { 17 | filename: 'bundle.js', 18 | path: path.resolve(__dirname, 'dist'), 19 | }, 20 | }; --------------------------------------------------------------------------------
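A note on `ui/src/reportWebVitals.js`: it is defined above but not imported anywhere in the sources shown here (`ui/src/index.tsx` does not call it). A minimal sketch of how it could be wired up, assuming the standard Create React App pattern; the console logging is illustrative only.

```ts
// Hypothetical addition to ui/src/index.tsx -- not part of this repo.
// reportWebVitals lazily loads web-vitals (v2 API: getCLS, getFID, getFCP, getLCP, getTTFB)
// and invokes the callback once per collected metric.
import reportWebVitals from "./reportWebVitals";

reportWebVitals((metric: { name: string; value: number }) => {
  console.log(metric.name, metric.value);
});
```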