├── .gitignore ├── .pre-commit-config.yaml ├── QA Tool Guideline.md ├── README.md ├── alembic.ini ├── alembic ├── README ├── env.py └── script.py.mako ├── docker-compose.postgresql.dev.yml ├── docker-compose.postgresql.prod.yml ├── docker-compose.sh ├── dump.sh ├── pyproject.toml ├── requirements.txt ├── resources ├── sentence-selection.png └── tts-workflow.png ├── scripts ├── alignment.py ├── paper │ └── check_postedits.ipynb ├── rematch_closest.py ├── run_whisper.py ├── set_delivery_files_to_db.py └── whisper_model.py ├── secrets.env.example ├── src ├── __init__.py ├── enums.py ├── logger.py ├── paths.py ├── service │ ├── __init__.py │ ├── annotations.py │ ├── annotators.py │ ├── api.py │ ├── bases.py │ ├── datasets.py │ ├── models.py │ ├── samples.py │ └── tasks.py ├── utils │ ├── __init__.py │ ├── alignment_utils.py │ ├── audio.py │ ├── db_utils.py │ ├── decorators.py │ ├── experiment_tracking.py │ ├── trim_asr.py │ ├── utils.py │ ├── wer_wo_punct.py │ └── whisper_model.py └── web_app │ ├── admin │ ├── pages │ │ ├── 01_📤_Upload.py │ │ └── 02_📊_Insights.py │ └── 🏠_Intro_admin.py │ └── annotator │ ├── 02_🔍_QA-depreciated.py │ ├── pages │ ├── 01_🎙️_Record │ └── 02_🔍_QA.py │ └── 🏠_Intro_annotator.py ├── test.ipynb └── vars.env /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | **/__pycache__ 3 | **/.DS_Store 4 | **/*.ckpt 5 | **/*.pkl 6 | **/*.csv 7 | **/*.log 8 | **/*.out 9 | **/*.png 10 | **/*.pt 11 | **/*.PNG 12 | **/*.jpg 13 | **/*.jpeg 14 | **/*.json 15 | **/*.png 16 | **/*.tif 17 | **/*.tiff 18 | **/*.pth 19 | **/*.zip 20 | **/exec_spec* 21 | backend/xview2/ 22 | # ignore weights 23 | **/weights/ 24 | **/data/ 25 | 26 | notebooks/batches** 27 | alembic/versions/ 28 | configs/login_config.yaml 29 | # ignore only root data 30 | /data 31 | **/manifest_vad_input.json 32 | **/results/ 33 | **/lightning_logs/ 34 | **/pretrained_models/ 35 | **/vad_frame/ 36 | **/WIP/ 37 | **/wandb/ 38 | /.vscode/ 39 | 
/checkpoints/ 40 | /result/ 41 | /data/**/*.csv 42 | /tests/**/*.json 43 | /secrets.env 44 | ZeroDuration.txt 45 | **/*.wav 46 | 47 | 48 | notebooks/output/ 49 | 50 | # Byte-compiled / optimized / DLL files 51 | __pycache__/ 52 | *.py[cod] 53 | *$py.class 54 | 55 | # C extensions 56 | *.so 57 | 58 | # Distribution / packaging 59 | .Python 60 | build/ 61 | develop-eggs/ 62 | dist/ 63 | downloads/ 64 | eggs/ 65 | .eggs/ 66 | lib/ 67 | lib64/ 68 | parts/ 69 | sdist/ 70 | var/ 71 | wheels/ 72 | share/python-wheels/ 73 | *.egg-info/ 74 | .installed.cfg 75 | *.egg 76 | MANIFEST 77 | 78 | # PyInstaller 79 | *.manifest 80 | *.spec 81 | 82 | # Installer logs 83 | pip-log.txt 84 | pip-delete-this-directory.txt 85 | 86 | # Unit test / coverage reports 87 | htmlcov/ 88 | .tox/ 89 | .nox/ 90 | .coverage 91 | .coverage.* 92 | .cache 93 | nosetests.xml 94 | coverage.xml 95 | *.cover 96 | *.py,cover 97 | .hypothesis/ 98 | .pytest_cache/ 99 | cover/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | db.sqlite3 109 | db.sqlite3-journal 110 | 111 | # Flask stuff: 112 | instance/ 113 | .webassets-cache 114 | 115 | # Scrapy stuff: 116 | .scrapy 117 | 118 | # Sphinx documentation 119 | docs/_build/ 120 | 121 | # PyBuilder 122 | .pybuilder/ 123 | target/ 124 | 125 | # Jupyter Notebook 126 | .ipynb_checkpoints 127 | 128 | # IPython 129 | profile_default/ 130 | ipython_config.py 131 | 132 | # pdm 133 | .pdm.toml 134 | 135 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 136 | __pypackages__/ 137 | 138 | # Celery stuff 139 | celerybeat-schedule 140 | celerybeat.pid 141 | 142 | # SageMath parsed files 143 | *.sage.py 144 | 145 | # Environments 146 | .env 147 | .venv 148 | env/ 149 | venv/ 150 | ENV/ 151 | env.bak/ 152 | venv.bak/ 153 | 154 | # Spyder project settings 155 | .spyderproject 156 | .spyproject 157 | 158 | # Rope project settings 159 | .ropeproject 160 | 161 | # mkdocs documentation 162 | /site 163 | 164 | # mypy 165 | .mypy_cache/ 166 | .dmypy.json 167 | dmypy.json 168 | 169 | # Pyre type checker 170 | .pyre/ 171 | 172 | # pytype static type analyzer 173 | .pytype/ 174 | 175 | # Cython debug symbols 176 | cython_debug/ 177 | 178 | # PyCharm 179 | .idea/ 180 | notebooks/slide-notebooks/error_parameters.json 181 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # # exclude if you don't want to use pre-commit 2 | # exclude: 3 | repos: 4 | # simple, useful checks 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.0.1 7 | hooks: 8 | - id: check-added-large-files 9 | args: [--maxkb=1500] 10 | - id: check-json 11 | - id: check-yaml 12 | args: ['--unsafe'] # this will allow us to use tags like !ref and !new 13 | - id: check-merge-conflict 14 | - id: check-executables-have-shebangs 15 | - id: check-shebang-scripts-are-executable 16 | - id: debug-statements 17 | - id: detect-private-key 18 | - id: end-of-file-fixer 19 | - id: fix-byte-order-marker 20 | - id: pretty-format-json 21 | args: [--autofix, --no-sort-keys] 22 | - id: trailing-whitespace 23 | # sort imports 24 | - repo: https://github.com/pycqa/isort 25 | rev: 5.10.1 26 | hooks: 27 | - id: isort 28 | name: isort 29 | args: 30 | # passing settings here did not work, so we pass them in a pyproject.toml file 31 | - --settings-path=pyproject.toml 32 | # 
format code 33 | - repo: https://github.com/psf/black 34 | rev: 22.3.0 35 | hooks: 36 | - id: black 37 | args: 38 | - --line-length=160 39 | # run type checker 40 | - repo: https://github.com/pre-commit/mirrors-mypy 41 | rev: v0.910 42 | hooks: 43 | - id: mypy 44 | # backend/database/ is generated; both paths share ONE regex because repeating the `exclude:` key is a duplicate YAML key and only one of the values would take effect 45 | exclude: backend/database/|src/data/public/ 46 | additional_dependencies: 47 | - pydantic 48 | - types-all 49 | args: 50 | - --ignore-missing-imports 51 | - --no-strict-optional 52 | - --no-warn-no-return 53 | # run linter 54 | - repo: https://github.com/PyCQA/flake8 55 | rev: 3.9.2 56 | hooks: 57 | - id: flake8 58 | additional_dependencies: 59 | - pep8-naming 60 | - flake8-bugbear 61 | - flake8-comprehensions 62 | - flake8-cognitive-complexity 63 | - flake8-expression-complexity 64 | 65 | args: 66 | - --select=F401,N801,N802,N804,N805,N807 67 | - --max-cognitive-complexity=2 68 | # check for old-style python patterns (e.g. class Foo(object):) and correct them 69 | - repo: https://github.com/asottile/pyupgrade 70 | rev: v3.1.0 71 | hooks: 72 | - id: pyupgrade 73 | args: [--py38-plus] 74 | # clear output of notebooks 75 | - repo: local 76 | hooks: 77 | - id: jupyter-nb-clear-output 78 | name: jupyter-nb-clear-output 79 | files: \.ipynb$ 80 | stages: [commit] 81 | language: system 82 | entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace 83 | -------------------------------------------------------------------------------- /QA Tool Guideline.md: -------------------------------------------------------------------------------- 1 | ## aiXplain QA Tool Guideline 2 | 3 | **Site:** : 4 | **Dataset name:** English 5 | **Username:** `` 6 | **Password:** `` 7 | 8 | Instructions: 9 | 10 | 1. **Accessing the Tool:** 11 | - Open the provided URL in your web browser. 12 | - Navigate to the "Intro Annotator" page. 13 | - Login using the provided username and password. 14 | 15 | 2.
**Dataset Selection:** 16 | - After logging in, go to the QA page. 17 | - In the sidebar, select the dataset that you have been assigned to. 18 | 19 | 3. **Annotation Process:** 20 | - The samples for annotation will be presented to you in a specific order. 21 | - Start with the sample that has the highest Word Error Rate (WER). 22 | - For each sample, follow the steps below: 23 | 24 | a. **Listening to the Audio:** 25 | - Click the play button to listen to the audio recording. 26 | 27 | b. **Selecting the Transcription:** 28 | - Choose the transcription that best matches the audio. 29 | - Ensure that the transcription accurately represents the content of the audio. 30 | - If necessary, you can post-edit the selected transcription. 31 | 32 | c. **Discarding Poor Quality Audio:** 33 | - If the audio is not of good quality, it should be discarded. 34 | - Select one or more reasons for discarding from the following options: 35 | - Repetition: The audio contains repeated sentences. 36 | - Incorrect Accent: The accent in the audio is not appropriate. 37 | - Incorrect Pronunciation: The audio has incorrect pronunciations or includes "hmm" sounds. 38 | - Poor Recording Quality: The audio has background noise, lip sounds, or is otherwise of low quality. 39 | - Incorrect Pauses: The pauses in the audio are not appropriately placed. 40 | - Incorrect Speed: The audio is either too fast or too slow. 41 | - Inconsistent Voice: The voice level in the audio is inconsistent and does not sound natural. 42 | 43 | d. **Providing Feedback (Optional):** 44 | - If you discard an audio sample, you can provide additional feedback explaining why it was discarded. 45 | 46 | e. **Submitting the Annotation:** 47 | - Once you have selected the best transcription or discarded the audio, click the "Submit" button to save your annotation. 48 | 49 | **Note:** Please follow these guidelines whenever you use the QA Tool to check recordings for TTS model training.
50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TTS QA - Quality Assessment Text to Speech Data Annotation Tool 3 | 4 | 5 | 6 | 7 | 8 | 9 | ## **Overview** 10 | 11 | 12 | ![TTS QA - Workflow](resources/tts-workflow.png) 13 | 14 | 15 | 16 | For a quick overview of the project, please watch the following video: 17 | [![TTS QA - Quality Assessment Text to Speech Data Annotation Tool](https://img.youtube.com/vi/TH71pxxMFfE/0.jpg)](https://www.youtube.com/watch?v=TH71pxxMFfE) 18 | 19 | 20 | ## **Step 1: Prerequisites** 21 | 22 | - Python 3.8 or higher (required by the pinned dependencies, e.g. pandas 1.5.3), with the packages in requirements.txt installed. You may install them with `pip install -r requirements.txt` while the environment is active. 23 | - docker-compose 24 | - aiXplain SDK installed and configured as described in the **[aiXplain SDK documentation](https://github.com/aixplain/aiXplain#getting-started)** (more details below) 25 | - AWS S3 credentials (access key and secret access key) to upload and download files from S3 26 | 27 | #### Details: 28 | Before running the code, you will need to set up the following services: 29 | - **[aiXplain](https://platform.aixplain.com/)**: For transcribing audio, ASR models are accessed from the aiXplain platform. This repo relies on the aiXplain platform and its models as an essential component. aiXplain provides easy-to-use, no-code AI/ML solutions that can be integrated into applications such as this one with a single API call. 30 | 31 | To use the aiXplain tools, you first need to create an account on the aiXplain platform. Then, you can choose from a wide range of models to use directly, or create pipelines that use those models in a cascade. Trying or deploying those models requires credits, which can be purchased from the platform.
32 | 33 | After setting up, you need to generate a private TEAM_API_KEY from the Integrations settings. Please store it safely, as the aiXplain SDK uses it to securely access your account and models/pipelines. 34 | 35 | The following short YouTube videos explain the aiXplain platform and how to use it: 36 | * **[aiXplain Teaser](https://www.youtube.com/watch?v=lDIe0kA-DJ8)**: Overview 37 | * **[aiXplain Tools](https://www.youtube.com/watch?v=A7MuD8W_Qkw)**: Overview of tools such as models, pipelines, benchmarking and fine-tuning. 38 | * **[aiXplain Discover](https://www.youtube.com/watch?v=H6_gmsCE4vM)**: Find and try over 38,000 models hosted on the platform 39 | * **[aiXplain Credits](https://www.youtube.com/watch?v=X5EYqXDKb3I)**: How to purchase and use credits 40 | 41 | - AWS S3 bucket: This is used to temporarily store data while it is processed by the pipeline. 42 | 43 | After both are set up, enter the relevant information and credentials in the environment files: 44 | 1. Configure the environment file: **`vars.env`**. 45 | 2. Open the **`vars.env`** file and add the following environment variables: 46 | - **`S3_BUCKET_NAME`**: The name of the S3 bucket where the data will be stored. **THE BUCKET MUST BE PUBLIC.** 47 | - **`S3_DATASET_DIR`**: The folder path inside the S3 bucket where the data will be stored. 48 | - **`AWS_ACCESS_KEY_ID`**: Your AWS access key ID. 49 | - **`AWS_SECRET_ACCESS_KEY`**: Your AWS secret access key. 50 | - **`AWS_DEFAULT_REGION`**: The AWS region where the S3 bucket is located. 51 | 52 | 3. Configure the environment file: **`secrets.env`**. 53 | 4. Open the **`secrets.env.example`** file and save it as **`secrets.env`**. Then, add the following environment variables: 54 | - **`TEAM_API_KEY`**: aiXplain Platform API key (generated on the aiXplain platform under Team Settings > Integrations). 55 | - **`HUGGINGFACE_TOKEN`**: Hugging Face token for accessing the models.
(You may get it from the Hugging Face website after signing up.) 56 | 57 | 58 | ## Start Database 59 | To start the PostgreSQL and Redis databases, run the following command. 60 | ``` 61 | bash docker-compose.sh start 62 | ``` 63 | This will create a Docker container for each. 64 | 65 | ## Start Celery 66 | Use the following command to start the Celery worker, which executes tasks asynchronously. 67 | ``` 68 | celery -A src.service.tasks worker --loglevel=info --pool=threads 69 | ``` 70 | ## Start Backend 71 | The following command starts the backend service that handles the data processing. 72 | ``` 73 | uvicorn src.service.api:app --port 8089 --reload 74 | ``` 75 | 76 | ## Start WebApp Frontend 77 | Please note that the previous services need to be running properly for the web app to work. 78 | 79 | ### 1. Annotator 80 | You may use the following command to run the annotator app from the project root directory. 81 | ``` 82 | python -m streamlit run --server.port 8501 ./src/web_app/annotator/🏠_Intro_annotator.py 83 | ``` 84 | 85 | ### 2. Admin 86 | You may use the following command to run the admin app from the project root directory. 87 | 88 | ``` 89 | python -m streamlit run --server.port 8502 --server.maxUploadSize 8192 ./src/web_app/admin/🏠_Intro_admin.py 90 | ``` 91 | You may use any open ports you like. 92 | 93 | You can upload a CSV file containing the text and a zip file containing the recordings. An example file may be downloaded from the frontend to see the expected format. You may also extract the start and end IDs of the recordings from the file names by providing a regex filter. After the corresponding files are uploaded, processing starts. Once the initial processing ends (visible in the Celery worker logs), start the trimming script:
94 | ``` 95 | python ./src/utils/trim_asr.py 96 | ``` 97 | 98 | 99 | ### How to dump and restore the database 100 | 101 | The first command below writes a backup dump of the database to a file named after the current timestamp (which you may then save to S3); the second restores a dump into the database. 102 | ```bash 103 | docker exec -t postgres_container_dev pg_dump -U postgres dev_tts_db > dump_`date +%Y-%m-%d"_"%H_%M_%S`.sql 104 | 105 | cat dump_2023-08-08_10_16_24.sql | docker exec -i postgres_container_dev psql -U postgres dev_tts_db 106 | ``` 107 | These commands target postgres_container_dev; for the production container, replace dev with prod in both the container name and the database name (i.e. prod_tts_db). 108 | 109 | ## Get Duration script 110 | Here are some SQL queries that give insights into the data. First, create a `ROUND(float, int)` helper, because PostgreSQL's built-in two-argument `ROUND` only accepts `numeric` values: 111 | 112 | ```sql 113 | CREATE FUNCTION ROUND(float,int) RETURNS NUMERIC AS $f$ 114 | SELECT ROUND( CAST($1 AS numeric), $2 ) 115 | $f$ language SQL IMMUTABLE; 116 | ``` 117 | Duration per dataset in minutes and hours, excluding datasets whose name contains `English` or `German`: 118 | ```sql 119 | SELECT dataset.name as dataset_name, ROUND(SUM(sample.trimmed_audio_duration) / 60, 2) AS minutes, ROUND(SUM(sample.trimmed_audio_duration) / 3600, 2) AS hours 120 | FROM sample 121 | JOIN dataset ON sample.dataset_id = dataset.id 122 | WHERE dataset.name NOT LIKE '%English%' AND dataset.name NOT LIKE '%German%' 123 | GROUP BY dataset.name 124 | ORDER BY dataset.name; 125 | ``` 126 | 127 | Total duration (in hours) of all the samples in datasets whose name contains `English`: 128 | 129 | ```sql 130 | SELECT SUM(sample.trimmed_audio_duration) / 60 / 60 as duration_after_trimming 131 | FROM sample 132 | JOIN dataset ON sample.dataset_id = dataset.id 133 | WHERE dataset.name LIKE '%' || 'English' || '%'; 134 | ``` 135 | 136 | Total duration of all the samples for each `dataset.language`, excluding datasets whose name contains `English (A`: 137 | 138 | ```sql 139 | SELECT dataset.language as dataset_name, ROUND(SUM(sample.trimmed_audio_duration) / 60, 2) AS minutes, ROUND(SUM(sample.trimmed_audio_duration) / 3600, 2) AS hours 140 | FROM sample 141 | JOIN dataset ON sample.dataset_id = dataset.id 142 | WHERE dataset.name NOT LIKE '%English (A%' 143 |
GROUP BY dataset.language 144 | ORDER BY dataset.language; 145 | ``` 146 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = alembic 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # max length of characters to apply to the 11 | # "slug" field 12 | #truncate_slug_length = 40 13 | 14 | # set to 'true' to run the environment during 15 | # the 'revision' command, regardless of autogenerate 16 | # revision_environment = false 17 | 18 | # set to 'true' to allow .pyc and .pyo files without 19 | # a source .py file to be detected as revisions in the 20 | # versions/ directory 21 | # sourceless = false 22 | 23 | # version location specification; this defaults 24 | # to alembic/versions. When using multiple version 25 | # directories, initial revisions must be specified with --version-path 26 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions 27 | 28 | # the output encoding used when revision files 29 | # are written from script.py.mako 30 | # output_encoding = utf-8 31 | 32 | sqlalchemy.url = driver://user:pass@localhost/dbname 33 | 34 | 35 | # Logging configuration 36 | [loggers] 37 | keys = root,sqlalchemy,alembic 38 | 39 | [handlers] 40 | keys = console 41 | 42 | [formatters] 43 | keys = generic 44 | 45 | [logger_root] 46 | level = WARN 47 | handlers = console 48 | qualname = 49 | 50 | [logger_sqlalchemy] 51 | level = WARN 52 | handlers = 53 | qualname = sqlalchemy.engine 54 | 55 | [logger_alembic] 56 | level = INFO 57 | handlers = 58 | qualname = alembic 59 | 60 | [handler_console] 61 | class = StreamHandler 62 | args = (sys.stderr,) 63 | level = NOTSET 64 | formatter = generic 65 | 66 | [formatter_generic] 67 | format = %(levelname)-5.5s [%(name)s] 
%(message)s 68 | datefmt = %H:%M:%S 69 | -------------------------------------------------------------------------------- /alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. 2 | 3 | # How to use it 4 | 1. Create a new database and a new user for it. 5 | 2. Copy this file to alembic.ini and edit it. 6 | 3. Run "alembic upgrade head" to create the tables. 7 | 4. Run "alembic revision --autogenerate -m "message"" to generate a new revision. 8 | 5. Run "alembic upgrade head" to apply the new revision. 9 | 6. Modify the models in models.py. 10 | 7. Go to step 4. 11 | -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging.config import fileConfig 3 | 4 | from alembic import context 5 | from dotenv import load_dotenv 6 | from sqlalchemy import engine_from_config, pool 7 | 8 | import sys 9 | 10 | ROOD_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(ROOD_DIR) 12 | 13 | from src.service import models 14 | from src.paths import paths 15 | 16 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 17 | 18 | # load the .env file 19 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 20 | 21 | 22 | # this is the Alembic Config object, which provides 23 | # access to the values within the .ini file in use. 24 | config = context.config 25 | 26 | # ---------------- added code here -------------------------# 27 | # this will overwrite the ini-file sqlalchemy.url path 28 | # with the path given in the config of the main code 29 | 30 | POSTGRES_URL = os.getenv("POSTGRES_URL") 31 | config.set_main_option("sqlalchemy.url", POSTGRES_URL) 32 | 33 | 34 | # Interpret the config file for Python logging. 35 | # This line sets up loggers basically. 
36 | fileConfig(config.config_file_name) 37 | 38 | # add your model's MetaData object here 39 | # for 'autogenerate' support 40 | # from myapp import mymodel 41 | # target_metadata = mymodel.Base.metadata 42 | target_metadata = models.Base.metadata 43 | 44 | # other values from the config, defined by the needs of env.py, 45 | # can be acquired: 46 | # my_important_option = config.get_main_option("my_important_option") 47 | # ... etc. 48 | # Earlier variant of include_object, commented out and inactive: 49 | # def include_object(object, name, type_, reflected, compare_to): 50 | # if type_ == "table": 51 | # if name in target_metadata.tables: 52 | # return True 53 | # else: 54 | # return False 55 | # else: 56 | # return False 57 | 58 | # Autogenerate filter passed to context.configure(): excludes (returns False) tables that were reflected from the database but have no counterpart in target_metadata, so autogenerate does not propose dropping them; every other object is included. 59 | def include_object(object, name, type_, reflected, compare_to) -> bool: 60 | if type_ == "table" and reflected and compare_to is None: 61 | return False 62 | else: 63 | return True 64 | 65 | 66 | def run_migrations_offline() -> None: 67 | """Run migrations in 'offline' mode. 68 | 69 | This configures the context with just a URL 70 | and not an Engine, though an Engine is acceptable 71 | here as well. By skipping the Engine creation 72 | we don't even need a DBAPI to be available. 73 | 74 | Calls to context.execute() here emit the given string to the 75 | script output. 76 | 77 | """ 78 | url = config.get_main_option("sqlalchemy.url") 79 | context.configure(url=url, target_metadata=target_metadata, literal_binds=True, include_object=include_object)  # literal_binds: render parameter values inline in the emitted SQL 80 | 81 | with context.begin_transaction(): 82 | context.run_migrations() 83 | 84 | 85 | def run_migrations_online() -> None: 86 | """Run migrations in 'online' mode. 87 | 88 | In this scenario we need to create an Engine 89 | and associate a connection with the context.
90 | 91 | """ 92 | connectable = engine_from_config(config.get_section(config.config_ini_section), prefix="sqlalchemy.", poolclass=pool.NullPool)  # NullPool: connections are not pooled 93 | 94 | with connectable.connect() as connection: 95 | context.configure(connection=connection, target_metadata=target_metadata, include_object=include_object) 96 | 97 | with context.begin_transaction(): 98 | context.run_migrations() 99 | 100 | # offline mode (e.g. `alembic upgrade head --sql`) emits SQL instead of connecting to the database 101 | if context.is_offline_mode(): 102 | run_migrations_offline() 103 | else: 104 | run_migrations_online() 105 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | 9 | # revision identifiers, used by Alembic. 10 | revision = ${repr(up_revision)} 11 | down_revision = ${repr(down_revision)} 12 | branch_labels = ${repr(branch_labels)} 13 | depends_on = ${repr(depends_on)} 14 | 15 | from alembic import op 16 | import sqlalchemy as sa 17 | ${imports if imports else ""} 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /docker-compose.postgresql.dev.yml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | services: 4 | db: 5 | container_name: postgres_container_dev 6 | image: postgres 7 | environment: 8 | POSTGRES_USER: ${POSTGRES_USER:-postgres} 9 | POSTGRES_PASSWORD: ${POSTGRES_PWD:-postgres} 10 | POSTGRES_DB: ${POSTGRES_DB:-dev_tts_db} 11 | PGDATA: /data/postgres 12 | ports: 13 | - "5432:5432" 14 | 15 | redis: 16 | container_name: redis_container_dev 17 | image: redis 18 | ports: 19 | - "6379:6379" 20 | networks: 21 | default: 22 | name: tts_network 23 |
-------------------------------------------------------------------------------- /docker-compose.postgresql.prod.yml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | services: 4 | db: 5 | container_name: postgres_container_prod 6 | image: postgres 7 | environment: 8 | POSTGRES_USER: ${POSTGRES_USER:-postgres} 9 | POSTGRES_PASSWORD: ${POSTGRES_PWD:-postgres} 10 | POSTGRES_DB: ${POSTGRES_DB:-prod_tts_db} 11 | PGDATA: /data/postgres 12 | ports: 13 | - "5432:5432" 14 | 15 | redis: 16 | container_name: redis_container_prod 17 | image: redis 18 | ports: 19 | - "6379:6379" 20 | 21 | networks: 22 | default: 23 | name: tts_network 24 | -------------------------------------------------------------------------------- /docker-compose.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Load environment variables (ENV, POSTGRES_DB, ...) from vars.env 4 | set -o allexport 5 | source vars.env 6 | set +o allexport 7 | 8 | # read a param to start or stop 9 | 10 | START_STOP=$1 11 | # ENV selects docker-compose.postgresql.${ENV}.yml; fail early instead of looking for docker-compose.postgresql..yml 12 | if [ -z "${ENV}" ]; then echo "ENV is not set in vars.env (expected 'dev' or 'prod')" >&2; exit 1; fi 13 | # print POSTGRES_DB 14 | echo "POSTGRES_DB: ${POSTGRES_DB}" 15 | # Start or stop the PostgreSQL and Redis containers for the specified environment 16 | if [ "$START_STOP" = "start" ]; then 17 | echo "Starting PostgreSQL container for ${ENV} environment" 18 | docker-compose -f "docker-compose.postgresql.${ENV}.yml" up -d 19 | elif [ "$START_STOP" = "stop" ]; then 20 | echo "Stopping PostgreSQL container for ${ENV} environment" 21 | docker-compose -f "docker-compose.postgresql.${ENV}.yml" down 22 | else 23 | echo "Invalid parameter.
Use 'start' or 'stop'" >&2; exit 1 24 | fi 25 | -------------------------------------------------------------------------------- /dump.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd /data/tts-qa/dumps 3 | docker exec -t postgres_container_dev pg_dump -U postgres dev_tts_db > /data/tts-qa/dumps/dump_`date +%Y-%m-%d"_"%H_%M_%S`.sql 4 | aws s3 sync /data/tts-qa/dumps/ s3:///tts-qa-dumps/ 5 | echo "done" 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | line_length = 160 4 | lines_after_imports = 2 5 | force_alphabetical_sort_within_sections = true 6 | known_third_party = ["wandb"] 7 | 8 | [tool.pytest.ini_options] 9 | filterwarnings = [ 10 | "ignore::FutureWarning", 11 | "ignore::UserWarning", 12 | ] 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aixplain==0.1.1 2 | boto3==1.26.127 3 | editdistance==0.6.2 4 | fastapi==0.95.1 5 | FastAPI_SQLAlchemy==0.2.1 6 | jiwer==3.0.1 7 | librosa==0.9.0 8 | openai_whisper==20230314 9 | pandas==1.5.3 10 | pyannote.audio==2.1.1 11 | pyannote.core==4.5 12 | pyannote.database==4.1.3 13 | pyannote.metrics==3.2.1 14 | pyannote.pipeline==2.3 15 | pydantic_sqlalchemy==0.0.9 16 | pydub==0.25.1 17 | python-dotenv==1.0.0 18 | SoundFile==0.10.3.post1 19 | SQLAlchemy==1.4.47 20 | starlette==0.26.1 21 | streamlit==1.19.0 22 | streamlit_authenticator==0.2.1 23 | tqdm==4.64.1 24 | wandb==0.15.2 25 | -------------------------------------------------------------------------------- /resources/sentence-selection.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/aixplain/tts-qa/288a920d37cfd75f5e1aec6d34efe7b624582d0f/resources/sentence-selection.png -------------------------------------------------------------------------------- /resources/tts-workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/tts-qa/288a920d37cfd75f5e1aec6d34efe7b624582d0f/resources/tts-workflow.png -------------------------------------------------------------------------------- /scripts/alignment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import re 5 | import sys 6 | from glob import glob 7 | 8 | 9 | sys.path.append("../") 10 | 11 | import editdistance 12 | import numpy as np 13 | import pandas as pd 14 | from pyannote.audio import Model 15 | from pyannote.audio.pipelines import VoiceActivityDetection 16 | from pydub import AudioSegment 17 | from tqdm import tqdm 18 | 19 | from src.utils.whisper_model import WhisperTimestampedASR 20 | from src.paths import paths 21 | from dotenv import find_dotenv, load_dotenv 22 | 23 | load_dotenv(find_dotenv(paths.PROJECT_ROOT_DIR / "secrets.env"), override=True) 24 | 25 | # Levenshtein (edit) distance between two sequences, computed by the editdistance package 26 | def edit_distance(s1, s2): 27 | return editdistance.eval(s1, s2) 28 | 29 | # Zero-pad str(i) on the left to at least 8 characters, e.g. 3 -> "00000003" 30 | def format_int(i) -> str: 31 | return str(i).zfill(8) 32 | 33 | 34 | # Cut the audio between start and end (in seconds) out of the WAV file at path, export it as WAV to out_path, and return out_path 35 | def trim_audio(path, start, end, out_path): 36 | sound = AudioSegment.from_file(path, format="wav") 37 | trimmed_sound = sound[start * 1000 : end * 1000] 38 | trimmed_sound.export(out_path, format="wav") 39 | return out_path 40 | 41 | 42 | modelPyannote = Model.from_pretrained( 43 | "pyannote/segmentation", use_auth_token=os.getenv("HUGGINGFACE_TOKEN") 44 | ) 45 | 46 | pipeline = VoiceActivityDetection(segmentation=modelPyannote) 47 | HYPER_PARAMETERS = { 48 | # onset/offset activation thresholds 49 | "onset": 0.5, 50 | "offset": 0.5, 51 |
# remove speech regions shorter than that many seconds. 52 | "min_duration_on": 0.0, 53 | # fill non-speech regions shorter than that many seconds. 54 | "min_duration_off": 0.05, 55 | } 56 | pipeline.instantiate(HYPER_PARAMETERS) 57 | 58 | 59 | batches = ["batches_French8_2/"] 60 | padding = 0.25 61 | 62 | 63 | print("Loading Whisper model...") 64 | 65 | lang_map = { 66 | "en": "english", 67 | "fr": "french", 68 | "es": "spanish", 69 | "de": "german", 70 | "it": "italian", 71 | } 72 | 73 | for batch in batches: 74 | if "English" in batch: 75 | language = "en" 76 | elif "French" in batch: 77 | language = "fr" 78 | elif "German" in batch: 79 | language = "de" 80 | whisper_model = WhisperTimestampedASR( 81 | model_size="large-v2", language=lang_map[language] 82 | ) 83 | whisper_model.load() 84 | print(f"Processing batch {batch}") 85 | filenames = glob(batch + "*.wav") 86 | for filename in filenames: 87 | print(f"Processing {filename}") 88 | if os.path.exists(filename + ".vad.bin"): 89 | print(f"Detected VAD for {filename} - loading from file") 90 | vad = pickle.load(open(filename + ".vad.bin", "rb")) 91 | else: 92 | print(f"Running VAD for {filename}") 93 | vad = pipeline(filename) 94 | i = 0 95 | print(f"Saving VAD for {filename}") 96 | with open(filename + ".vad.bin", "wb") as f: 97 | pickle.dump(vad, f) 98 | print(f"Saved VAD for {filename}") 99 | 100 | data = AudioSegment.from_file(filename) 101 | # read start_loc and end_loc from wav file name using regex 102 | if language == "fr": 103 | start_loc = int(re.search(r"From (\d+) -", filename).group(1)) 104 | end_loc = int(re.search(r"- (\d+)", filename).group(1)) 105 | elif language == "en": 106 | # reg sdhould work on EN00000003-EN00000012 107 | start_loc = int(re.search(r"EN(\d+)-", filename).group(1)) 108 | end_loc = int(re.search(r"-EN(\d+)", filename).group(1)) 109 | elif language == "de": 110 | # reg sdhould work on start_1-end_500 111 | start_loc = int(re.search(r"DE(\d+)-", filename).group(1)) 112 | end_loc 
= int(re.search(r"-DE(\d+)", filename).group(1)) 113 | 114 | print(f"start_loc: {start_loc}, end_loc: {end_loc}") 115 | 116 | sentences = {} 117 | inverseSentences = {} 118 | df_sentences = pd.read_csv(f"{language} - {language}.csv") 119 | id_int = df_sentences["unique_identifier"].apply(lambda x: int(x[2:])) 120 | df_sentences["id_int"] = id_int 121 | df_sentences.set_index("id_int", inplace=True) 122 | # include only ids in between start_loc and end_loc 123 | df_sentences = df_sentences.loc[start_loc:end_loc] 124 | 125 | print(f"There are {len(df_sentences)} sentences in this range") 126 | for index, row in df_sentences.iterrows(): 127 | sentenceNum = int(index) 128 | sentence = row["text"] 129 | sentences[sentenceNum] = sentence 130 | if sentence not in inverseSentences: 131 | inverseSentences[sentence] = sentenceNum 132 | else: 133 | tmp = sentence 134 | while tmp in inverseSentences: 135 | tmp += " _" 136 | inverseSentences[tmp] = sentenceNum 137 | 138 | sentenceNumber = -1 139 | segments = {} 140 | if os.path.exists(filename + ".segments.json"): 141 | print(f"Detected segments for {filename} - loading from file") 142 | segments = json.load(open(filename + ".segments.json")) 143 | else: 144 | print(f"Running ASR for {filename}") 145 | timeline = vad.get_timeline().support() 146 | for segment in tqdm(timeline): 147 | start, end = list(segment) 148 | start = max(0, start - padding) 149 | end = min(end + padding, len(data) / 1000) 150 | seg = {} 151 | seg["SegmentStart"] = start 152 | seg["SegmentEnd"] = end 153 | outputAudio = AudioSegment.empty() 154 | outputAudio += data[ 155 | seg["SegmentStart"] * 1000 : seg["SegmentEnd"] * 1000 156 | ] 157 | outputAudio.export(batch + "TMP/tmp.wav", format="wav") 158 | # run ASR 159 | try: 160 | result = whisper_model.predict( 161 | {"instances": [{"url": batch + "TMP/tmp.wav"}]} 162 | ) 163 | asr = result["predictions"][0] 164 | seg["asr"] = asr 165 | except: 166 | seg["asr"] = "" 167 | pass 168 | segments[start] = seg 169 
| # save segments 170 | print(f"Saving segments for {filename}") 171 | with open(filename + ".segments.json", "w") as fout: 172 | json.dump(segments, fout, indent=4) 173 | print(f"Saved segments for {filename}") 174 | 175 | print(f"Matching segments to sentences for {filename}") 176 | segments_list = [v for k, v in segments.items()] 177 | sentences_list = [v for k, v in sentences.items()] 178 | distances_matrix = np.ones((len(segments_list), len(sentences))) * 1000 179 | 180 | for ik in range(len(segments_list)): 181 | for jk, sentence in enumerate(sentences_list): 182 | try: 183 | distances_matrix[ik, jk] = edit_distance( 184 | segments_list[ik]["asr"], sentence 185 | ) / min(len(segments_list[ik]["asr"]), len(sentence)) 186 | except: 187 | distances_matrix[ik, jk] = np.inf 188 | 189 | # get the best match for each segment 190 | best_matches = np.argmin(distances_matrix, axis=1) 191 | 192 | # # make a dataframe 193 | columns = [ 194 | "status", 195 | "filename", 196 | "sentenceNumber", 197 | "sentence", 198 | "asr", 199 | "start", 200 | "end", 201 | "ed_dist", 202 | "len_dif", 203 | ] 204 | df = pd.DataFrame(columns=columns) 205 | best_matched_sentences = [sentences_list[k] for k in best_matches] 206 | 207 | # print the results 208 | for ik in range(len(segments_list)): 209 | asr = segments_list[ik]["asr"] 210 | sentence = best_matched_sentences[ik] 211 | ed_dist = distances_matrix[ik, best_matches[ik]] 212 | try: 213 | len_dif = abs(len(asr) - len(sentence)) / min(len(asr), len(sentence)) 214 | except: 215 | len_dif = np.inf 216 | start = segments_list[ik]["SegmentStart"] 217 | end = segments_list[ik]["SegmentEnd"] 218 | sentenceNumber = inverseSentences[sentence] 219 | if ed_dist < 0.25 and len_dif < 0.15: 220 | status = "assigned" 221 | else: 222 | status = "not_assigned" 223 | 224 | row = { 225 | "status": status, 226 | "filename": filename, 227 | "sentenceNumber": sentenceNumber, 228 | "sentence": sentence, 229 | "asr": asr, 230 | "start": start, 231 | "end": 
end, 232 | "ed_dist": ed_dist, 233 | "len_dif": len_dif, 234 | } 235 | df = df.append(row, ignore_index=True) 236 | # if there is inf drop it 237 | df = df.replace([np.inf, -np.inf], np.nan) 238 | df.dropna(inplace=True) 239 | 240 | print(f"Assigned {len(df[df.status=='assigned'])} segments") 241 | print(f"Not assigned {len(df[df.status=='not_assigned'])} segments") 242 | 243 | # if there are multiple rows with same sentenceNumber take the last one and drop the rest 244 | df = df.sort_values(by=["sentenceNumber"]) 245 | df = df.drop_duplicates(subset=["sentenceNumber"], keep="last") 246 | 247 | print(f"Status counts for {filename}:") 248 | print(df.status.value_counts()) 249 | df.to_csv(filename + ".csv", index=False) 250 | 251 | # create a folder for wav files 252 | wav_folder = os.path.join(batch, os.path.basename(filename).replace(".wav", "")) 253 | if os.path.exists(wav_folder): 254 | print(f"Folder {wav_folder} already exists, skipping") 255 | continue 256 | os.makedirs(wav_folder, exist_ok=True) 257 | 258 | # create "assigned" and "not_assigned" folders 259 | os.makedirs(os.path.join(wav_folder, "assigned"), exist_ok=True) 260 | os.makedirs(os.path.join(wav_folder, "not_assigned"), exist_ok=True) 261 | 262 | # for each row in the dataframe if the status is assigned, create a wav file with the start and end times of the segment 263 | # if not assigned, create a wav file with the start and end times of the segment 264 | print(f"Trimming audio for {filename}, it will be saved in {wav_folder}") 265 | for index, row in tqdm(df.iterrows(), total=len(df)): 266 | start = row["start"] 267 | end = row["end"] 268 | asr = row["asr"] 269 | sentence = row["sentence"] 270 | status = row["status"] 271 | if status == "assigned": 272 | wav_path = os.path.join( 273 | wav_folder, 274 | "assigned", 275 | f"{language.upper()}" + format_int(row["sentenceNumber"]) + ".wav", 276 | ) 277 | else: 278 | wav_path = os.path.join( 279 | wav_folder, 280 | "not_assigned", 281 | 
f"{language.upper()}" + format_int(row["sentenceNumber"]) + ".wav", 282 | ) 283 | 284 | outpath, start, end = trim_audio(filename, start, end, wav_path) 285 | -------------------------------------------------------------------------------- /scripts/paper/check_postedits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# connect to postgresql db usingenvironment variable read from vars.env \n", 10 | "\n", 11 | "import warnings\n", 12 | "warnings.filterwarnings(\"ignore\")\n", 13 | "\n", 14 | "import os\n", 15 | "import psycopg2\n", 16 | "import pandas as pd\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import seaborn as sns\n", 20 | "\n", 21 | "# read environment variables from vars.env\n", 22 | "from dotenv import load_dotenv\n", 23 | "load_dotenv(\"../../vars.env\")\n", 24 | "\n", 25 | "# connect to postgresql db on localhost, post 5432, using user and password from vars.env\n", 26 | "\n", 27 | "import psycopg2\n", 28 | "import os\n", 29 | "\n", 30 | "# Define the database credentials\n", 31 | "db_host = os.getenv(\"POSTGRES_HOST\")\n", 32 | "db_name = os.getenv(\"POSTGRES_DB\")\n", 33 | "db_user = os.getenv(\"POSTGRES_USER\")\n", 34 | "db_password = os.getenv(\"POSTGRES_PWD\")\n", 35 | "\n", 36 | "# Create the connection\n", 37 | "connection = psycopg2.connect(\n", 38 | " host=db_host,\n", 39 | " database=db_name,\n", 40 | " user=db_user,\n", 41 | " password=db_password\n", 42 | ")\n", 43 | "\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "dataset_str = \"English\"\n", 53 | "query = f\"\"\"\n", 54 | "SELECT sample.* \n", 55 | "FROM sample\n", 56 | "JOIN dataset ON sample.dataset_id = dataset.id\n", 57 | "WHERE dataset.name='{dataset_str}' \n", 58 | " AND 
sample.trimmed_audio_duration > 0 \n", 59 | " AND sample.trimmed_audio_duration < 5 \n", 60 | " AND sample.longest_pause < 0.5\n", 61 | " AND sample.is_selected_for_delivery = TRUE\n", 62 | "\"\"\"\n", 63 | "\n", 64 | "\n", 65 | "\n", 66 | "# Assume you've loaded the sample and annotation tables as DataFrames\n", 67 | "sample_df = pd.read_sql(query, connection)\n", 68 | "annotation_df = pd.read_sql(\"SELECT * FROM annotation\", connection)\n", 69 | "\n", 70 | "# Merge the DataFrames based on the sample.id and annotation.sample_id\n", 71 | "merged_df = pd.merge(sample_df, annotation_df, left_on='id', right_on='sample_id', how='left')\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "merged_df = merged_df[~merged_df.status.isnull()]\n", 81 | "\n", 82 | "# drop duplicates in id_x and select first row\n", 83 | "merged_df = merged_df.drop_duplicates(subset=['id_x'], keep='first')\n", 84 | "\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# merged_df = merged_df[[\"filename\", \"original_text\", \"final_text\"]]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "merged_df[\"is_edited\"] = False\n", 103 | "# merged_df[merged_df.original_text != merged_df.final_text][\"is_edited\"] = True\n", 104 | "merged_df.loc[merged_df.original_text != merged_df.final_text, \"is_edited\"] = True\n", 105 | "# if status = Discarded then is_edited = False\n", 106 | "merged_df.loc[merged_df.status == \"Discarded\", \"is_edited\"] = False\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "merged_df.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 
122 | "outputs": [], 123 | "source": [ 124 | "merged_df[merged_df.is_edited == True][[\"filename\", \"original_text\", \"final_text\"]].head(10)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "merged_df[\"is_edited\"].value_counts(normalize=True)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "merged_df.status.value_counts(normalize=True)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# calculate the duration of the not discarded samples\n", 152 | "merged_df[merged_df.status != \"Discarded\"].trimmed_audio_duration.sum() / 60 / 60" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "corpus-insight", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.8.8" 180 | }, 181 | "orig_nbformat": 4 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /scripts/rematch_closest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import editdistance 4 | import numpy as np 5 | import pandas as pd 6 | import psycopg2 7 | 8 | # read environment variables from vars.env 9 | from dotenv import load_dotenv 10 | from tqdm import tqdm 11 | 12 | 13 | 
load_dotenv("../vars.env") 14 | 15 | # connect to postgresql db on localhost, post 5432, using user and password from vars.env 16 | 17 | import os 18 | 19 | import psycopg2 20 | 21 | 22 | # Define the database credentials 23 | db_host = os.getenv("POSTGRES_HOST") 24 | db_name = os.getenv("POSTGRES_DB") 25 | db_user = os.getenv("POSTGRES_USER") 26 | db_password = os.getenv("POSTGRES_PWD") 27 | 28 | 29 | def edit_distance(s1, s2): 30 | return editdistance.eval(s1, s2) 31 | 32 | 33 | for dataset in ["German(Dorothee)"]: 34 | print(f"Processing {dataset}") 35 | 36 | sql_script = f""" 37 | SELECT dataset.name, sample.id, sample.filename, sample.local_trimmed_path, sample.original_text, sample.asr_text, sample.wer, sample.trimmed_audio_duration as duration 38 | FROM sample 39 | JOIN dataset ON sample.dataset_id = dataset.id 40 | WHERE dataset.name LIKE '%' || '{dataset}' || '%'; 41 | """ 42 | 43 | # Connect to the database 44 | conn = psycopg2.connect(host=db_host, database=db_name, user=db_user, password=db_password) 45 | 46 | # Execute the SQL script into pandas dataframe with column names 47 | df = pd.read_sql_query(sql_script, conn) 48 | # group by name and then create a dict of the grouped dataframes 49 | 50 | df_dict = {k: v for k, v in df.groupby("name")} 51 | df_matched_list = [] 52 | for df_name, df_sentences in df_dict.items(): 53 | print(f"Processing {df_name}") 54 | df_sentences = df_sentences.reset_index(drop=True) 55 | 56 | sentences = {} 57 | inverseSentences = {} 58 | 59 | segments = {} 60 | 61 | print(f"There are {len(df_sentences)} sentences in this range") 62 | for index, row in df_sentences.iterrows(): 63 | sentenceNum = int(index) 64 | sentence = row["original_text"] 65 | sentences[sentenceNum] = sentence 66 | 67 | segments[sentenceNum] = row 68 | if sentence not in inverseSentences: 69 | inverseSentences[sentence] = sentenceNum 70 | else: 71 | tmp = sentence 72 | while tmp in inverseSentences: 73 | tmp += " _" 74 | inverseSentences[tmp] = 
sentenceNum 75 | 76 | sentenceNumber = -1 77 | 78 | segments_list = [v for k, v in segments.items()] 79 | sentences_list = [v for k, v in sentences.items()] 80 | distances_matrix = np.ones((len(segments_list), len(sentences))) * 1000 81 | 82 | for ik in tqdm(range(len(segments_list))): 83 | # for jk in range(ik-500, min(len(segments_list), ik+500)): 84 | for jk in range(len(sentences_list)): 85 | try: 86 | distances_matrix[ik, jk] = edit_distance(segments_list[ik]["asr_text"], sentences_list[jk]) / min( 87 | len(segments_list[ik]["asr_text"]), len(sentences_list[jk]) 88 | ) 89 | except: 90 | distances_matrix[ik, jk] = np.inf 91 | 92 | # get the best match for each segment 93 | best_matches = np.argmin(distances_matrix, axis=1) 94 | best_matched_sentences = [sentences_list[k] for k in best_matches] 95 | 96 | # # make a dataframe 97 | rows = [] 98 | best_matched_sentences = [sentences_list[k] for k in best_matches] 99 | 100 | # print the results 101 | for ik in tqdm(range(len(segments_list))): 102 | asr = segments_list[ik]["asr_text"] 103 | sentence = best_matched_sentences[ik] 104 | ed_dist = distances_matrix[ik, best_matches[ik]] 105 | try: 106 | len_dif = abs(len(asr) - len(sentence)) / min(len(asr), len(sentence)) 107 | except: 108 | len_dif = np.inf 109 | sentenceNumber = inverseSentences[sentence] 110 | if ed_dist < 0.25 and len_dif < 0.15: 111 | status = "assigned" 112 | else: 113 | status = "not_assigned" 114 | 115 | row = { 116 | "status": status, 117 | "originalNumber": ik, 118 | "original_id": segments_list[ik]["id"], 119 | "assigned_id": segments[sentenceNumber]["id"], 120 | "original_sentence": sentences_list[ik], 121 | "assigned_sentence": sentence, 122 | "ed_dist": ed_dist, 123 | "len_dif": len_dif, 124 | } 125 | 126 | row.update(segments_list[ik]) 127 | rows.append(row) 128 | # if there is inf drop it 129 | df_matched_ = pd.DataFrame(rows) 130 | df_matched_ = df_matched_[df_matched_["ed_dist"] != np.inf] 131 | 132 | diff = 
df_matched_[df_matched_.original_id != df_matched_.assigned_id] 133 | diff = diff[diff.status == "assigned"] 134 | diff = diff.sort_values("ed_dist").drop_duplicates("assigned_id", keep="first") 135 | 136 | if len(diff) > 0: 137 | # diff.to_csv(f"diff_{df_name}.csv", index=False) 138 | print(f"Found {len(diff)} differences") 139 | df_matched_list.append(diff) 140 | 141 | df_matched = pd.concat(df_matched_list) 142 | df_matched = df_matched.sort_values("ed_dist").drop_duplicates("assigned_id", keep="first") 143 | 144 | df_matched.to_csv(f"matched-{dataset}.csv", index=False) 145 | 146 | print(f"Matched {len(df_matched)} sentences") 147 | -------------------------------------------------------------------------------- /scripts/run_whisper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | import os 5 | from glob import glob 6 | from pathlib import Path 7 | 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | from src.logger import root_logger 12 | from src.paths import paths 13 | from src.utils.whisper_model import WhisperTimestampedASR 14 | 15 | 16 | logger = root_logger.getChild(__name__) 17 | 18 | 19 | if __name__ == "__main__": 20 | files = glob(str(paths.PROCESSED_DATASETS_DIR.resolve()) + "/*.csv") 21 | for file in files: 22 | df_trans = pd.read_csv(file, index_col=None) 23 | language = df_trans["language"][0] 24 | 25 | for model_size in ["tiny", "base", "small", "medium"]: 26 | if language == "English": 27 | model_size = model_size + ".en" 28 | 29 | logger.info(f"Loading model {model_size} for {language}") 30 | model = WhisperTimestampedASR(model_size=model_size, language=language) 31 | model.load() 32 | save_dir = paths.FEATURES_DATASETS_DIR / f"whisper_{language}_{model_size}" 33 | if not save_dir.exists(): 34 | save_dir.mkdir() 35 | 36 | # TQDM DF 37 | for i, row in tqdm(df_trans.iterrows(), total=df_trans.shape[0]): 38 | filepath = row["filename"] 39 | filepath = 
Path(filepath) 40 | 41 | filename = filepath.name 42 | save_path = str((save_dir / filename).resolve()) + ".json" 43 | if os.path.exists(save_path): 44 | continue 45 | try: 46 | request = {"instances": [{"url": row["filename"]}]} 47 | response = model.predict(request) 48 | except: 49 | logger.error(f"Error processing {filepath}") 50 | continue 51 | 52 | with open(save_path, "w") as f: 53 | json.dump(response, f, ensure_ascii=False) 54 | -------------------------------------------------------------------------------- /scripts/set_delivery_files_to_db.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | import os 7 | 8 | import pandas as pd 9 | import psycopg2 10 | 11 | # read environment variables from vars.env 12 | from dotenv import load_dotenv 13 | 14 | 15 | load_dotenv("../vars.env") 16 | 17 | # connect to postgresql db on localhost, post 5432, using user and password from vars.env 18 | 19 | import os 20 | 21 | import psycopg2 22 | 23 | 24 | # Define the database credentials 25 | db_host = os.getenv("POSTGRES_HOST") 26 | db_name = os.getenv("POSTGRES_DB") 27 | db_user = os.getenv("POSTGRES_USER") 28 | db_password = os.getenv("POSTGRES_PWD") 29 | 30 | total_hours = 30 31 | include_extras = False 32 | 33 | from tqdm import tqdm 34 | 35 | 36 | for dataset_str in [ 37 | "German(Dorothee)" 38 | ]: # ["English(Melynda)", "French(Dorsaf)"]: # "Italian(Martina)", "Spanish(Violeta)"]: # "English(Melynda)", "French(Dorsaf)", 39 | if "English" in dataset_str: 40 | dataset = "English" 41 | elif "Spanish" in dataset_str: 42 | dataset = "Spanish" 43 | elif "German" in dataset_str: 44 | dataset = "German" 45 | elif "French" in dataset_str: 46 | dataset = "French" 47 | elif "Italian" in dataset_str: 48 | dataset = "Italian" 49 | print(f"Processing {dataset}...") 50 | df_wav = pd.read_csv(f"/data/tts-qa/share_{total_hours}h/{dataset}/{dataset}.csv") 51 | if include_extras: 
52 | df_extras = pd.read_csv(f"/data/tts-qa/share_{total_hours}h/{dataset}/{dataset}-extras.csv") 53 | df = pd.concat([df_wav, df_extras], axis=0) 54 | else: 55 | df = df_wav 56 | # set all samples is_selected_for_delivery to true by matching filename in the postgres database 57 | filenames = df.filename.to_list() 58 | 59 | # connect to postgres 60 | conn = psycopg2.connect(host=db_host, database=db_name, user=db_user, password=db_password) 61 | 62 | # create cursor 63 | cur = conn.cursor() 64 | 65 | sql_script = f""" 66 | UPDATE sample 67 | SET is_selected_for_delivery = FALSE 68 | WHERE dataset_id IN (SELECT id FROM dataset WHERE name = '{dataset_str}'); 69 | """ 70 | 71 | cur.execute(sql_script) 72 | conn.commit() 73 | 74 | # update all samples to is_seslected_for_delivery = true 75 | for filename in tqdm(filenames): 76 | cur.execute( 77 | f""" 78 | UPDATE sample 79 | SET is_selected_for_delivery = true 80 | WHERE dataset_id IN ( 81 | SELECT id 82 | FROM dataset 83 | WHERE name= '{dataset_str}' 84 | ) 85 | AND filename = '{filename}'; 86 | """ 87 | ) 88 | 89 | conn.commit() 90 | -------------------------------------------------------------------------------- /scripts/whisper_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import traceback 4 | import urllib.request 5 | from typing import Dict 6 | from uuid import uuid4 7 | 8 | import whisper 9 | 10 | 11 | MODEL_DIR = os.getenv("MODEL_DIR", "/mnt/models") 12 | MODEL_NAME = os.getenv("MODEL_NAME", "whisper_model") 13 | 14 | 15 | class WhisperASR: 16 | def __init__(self, model_size="tiny", language="English"): 17 | self.model = None 18 | self.ready = False 19 | options = dict(language=language) 20 | self.transcribe_options = dict(task="transcribe", **options) 21 | self.model_size = model_size 22 | 23 | def load(self): 24 | model_path = os.path.join(MODEL_DIR, MODEL_NAME) 25 | self.model = whisper.load_model(self.model_size, 
download_root=model_path) 26 | self.ready = True 27 | 28 | def get_file_from_url(self, url, tempdir): 29 | _, extension = os.path.splitext(url) 30 | extension = extension[1:] 31 | extension = extension.split("?")[0] 32 | input_path = os.path.join(tempdir, f"{str(uuid4())}.{extension}") 33 | urllib.request.urlretrieve(url, input_path) 34 | return input_path 35 | 36 | def predict(self, request: Dict) -> Dict: 37 | try: 38 | transcriptions = [] 39 | inputs = request["instances"] 40 | with tempfile.TemporaryDirectory(prefix="whisper-asr-") as tempdir: 41 | for request in inputs: 42 | # check if url is s3 link or local 43 | if request["url"].startswith("s3://"): 44 | audio_file = self.get_file_from_url(request["url"], tempdir) 45 | else: 46 | audio_file = request["url"] 47 | 48 | transcriptions.append(self.model.transcribe(audio_file, **self.transcribe_options)["text"]) 49 | 50 | return {"predictions": transcriptions} 51 | except ValueError as e: 52 | print(traceback.format_exc()) 53 | raise ValueError(f"Failed to process request: {e}") 54 | -------------------------------------------------------------------------------- /secrets.env.example: -------------------------------------------------------------------------------- 1 | # Environment variables go here, can be read by `python-dotenv` package: 2 | # 3 | # `src/script.py` 4 | # ---------------------------------------------------------------- 5 | # import dotenv 6 | # 7 | # project_dir = os.path.join(os.path.dirname(__file__), os.pardir) 8 | # dotenv_path = os.path.join(project_dir, '.env') 9 | # dotenv.load_dotenv(dotenv_path) 10 | # ---------------------------------------------------------------- 11 | 12 | TEAM_API_KEY= 13 | HUGGINGFACE_TOKEN= -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aixplain/tts-qa/288a920d37cfd75f5e1aec6d34efe7b624582d0f/src/__init__.py -------------------------------------------------------------------------------- /src/enums.py: -------------------------------------------------------------------------------- 1 | # We use this file to define all the enums used in the project 2 | 3 | # from enum import Enum 4 | # Example: 5 | # class Color(Enum): 6 | # RED = 1 7 | # GREEN = 2 8 | # BLUE = 3 9 | 10 | from enum import Enum 11 | 12 | 13 | class RunType(Enum): 14 | DATASET_ANALYSIS = "dataset_analysis" 15 | EVALUATION = "evaluation" 16 | TRAINING = "training" 17 | 18 | 19 | class DatasetType(Enum): 20 | TRAIN = "train" 21 | TEST = "test" 22 | VALIDATION = "validation" 23 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import warnings 4 | 5 | import click 6 | 7 | 8 | handler = logging.StreamHandler() 9 | formatter = logging.Formatter("%(asctime)s %(levelname)-8s %(filename)s:%(lineno)d %(message)s", "%m-%d %H:%M:%S") 10 | handler.setFormatter(formatter) 11 | logging.addLevelName(logging.DEBUG, click.style(str(logging.getLevelName(logging.DEBUG)), fg="cyan")) 12 | logging.addLevelName(logging.INFO, click.style(str(logging.getLevelName(logging.INFO)), fg="green")) 13 | logging.addLevelName(logging.WARNING, click.style(str(logging.getLevelName(logging.WARNING)), fg="yellow")) 14 | logging.addLevelName(logging.ERROR, click.style(str(logging.getLevelName(logging.ERROR)), fg="red")) 15 | logging.addLevelName(logging.CRITICAL, click.style(str(logging.getLevelName(logging.CRITICAL)), fg="bright_red")) 16 | logging.basicConfig(handlers=[handler]) 17 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 18 | logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING) 19 | 
logging.getLogger("urllib3.util.retry").setLevel(logging.WARNING) 20 | root_logger = logging.getLogger() 21 | root_logger.setLevel(os.environ.get("LOG_LEVEL", "ERROR")) 22 | 23 | warnings.simplefilter(action="ignore", category=FutureWarning) 24 | warnings.simplefilter(action="ignore", category=UserWarning) 25 | -------------------------------------------------------------------------------- /src/paths.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | 4 | 5 | now_str = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M-%S") 6 | 7 | 8 | class Paths: 9 | PROJECT_ROOT_DIR = Path(__file__).parent.parent 10 | DATASETS_DIR: Path = Path("/data") 11 | LOCAL_BUCKET_DIR: Path = DATASETS_DIR / "tts-qa" 12 | OUTPUTS_DIR: Path = PROJECT_ROOT_DIR / "outputs" 13 | REPORTS_DIR: Path = PROJECT_ROOT_DIR / "reports" 14 | SRC_DIR: Path = PROJECT_ROOT_DIR / "src" 15 | TESTS_DIR: Path = PROJECT_ROOT_DIR / "tests" 16 | 17 | RAW_DATASETS_DIR: Path = DATASETS_DIR / "raw" 18 | PROCESSED_DATASETS_DIR: Path = DATASETS_DIR / "processed" 19 | EXTERNAL_DATASETS_DIR: Path = DATASETS_DIR / "external" 20 | FEATURES_DATASETS_DIR: Path = DATASETS_DIR / "features" 21 | 22 | DATASET_SCRIPTS_DIR: Path = SRC_DIR / "scripts" 23 | 24 | MODELS_DIR: Path = OUTPUTS_DIR / "models" 25 | PREDICTIONS_DIR: Path = OUTPUTS_DIR / "predictions" 26 | 27 | FIGURES_DIR: Path = REPORTS_DIR / "figures" 28 | 29 | PIPELINE_PATH: Path = MODELS_DIR / "pipeline.pkl" 30 | BEST_MODEL_PATH: Path = MODELS_DIR / "model_best.pkl" 31 | 32 | CONFIGS_DIR: Path = PROJECT_ROOT_DIR / "configs" 33 | LOGIN_CONFIG_PATH: Path = CONFIGS_DIR / "login_config.yaml" 34 | 35 | OUTPUT_DIR_PATTERN: str = f"{OUTPUTS_DIR}/" + "{run_type}/{model_name}/{run_tag}/" + now_str 36 | 37 | 38 | paths = Paths() 39 | -------------------------------------------------------------------------------- /src/service/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/tts-qa/288a920d37cfd75f5e1aec6d34efe7b624582d0f/src/service/__init__.py -------------------------------------------------------------------------------- /src/service/annotations.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter 4 | 5 | from src.logger import root_logger 6 | from src.service.bases import AnnotationModel # noqa: F401 7 | from src.utils import db_utils 8 | 9 | 10 | router = APIRouter(prefix="/annotations", tags=["annotations"]) 11 | 12 | app_logger = root_logger.getChild("api/annotations") 13 | 14 | # list annotations 15 | @router.get("/") 16 | def list_annotators() -> List[AnnotationModel]: 17 | annotations = db_utils.list_annotations() 18 | # map the annotations to the AnnotationModel 19 | return [AnnotationModel(**annotation.to_dict()) for annotation in annotations] 20 | -------------------------------------------------------------------------------- /src/service/annotators.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from fastapi import APIRouter 4 | 5 | from src.logger import root_logger 6 | from src.service.bases import AnnotatorModel, DatasetModel, InfoModel, SampleModel # noqa: F401 7 | from src.utils import db_utils 8 | 9 | 10 | router = APIRouter(prefix="/annotators", tags=["annotators"]) 11 | 12 | app_logger = root_logger.getChild("api/annotators") 13 | 14 | 15 | # list annotators 16 | @router.get("/") 17 | def list_annotators() -> List[AnnotatorModel]: 18 | annotators = db_utils.list_annotators() 19 | # map the annotators to the AnnotatorModel 20 | return [AnnotatorModel(**annotator.to_dict()) for annotator in annotators] 21 | 22 | 23 | # create an annotator 24 | @router.post("/{username}") 25 | def create_annotator( 26 | username: str, 
name: str, email: str, password: str, ispreauthorized: bool = True, isadmin: bool = False 27 | ) -> Union[AnnotatorModel, InfoModel]: 28 | try: 29 | annotator = db_utils.create_annotator(username=username, name=name, email=email, password=password, ispreauthorized=ispreauthorized, isadmin=isadmin) 30 | return AnnotatorModel(**annotator.to_dict()) 31 | except Exception as e: 32 | return InfoModel(**{"message": "Failed", "error": str(e)}) 33 | 34 | 35 | # get annotator 36 | @router.get("/{id}") 37 | def get_annotator_by_id(id: int) -> Union[AnnotatorModel, InfoModel]: 38 | try: 39 | annotator = db_utils.get_annotator_by_id(id) 40 | return AnnotatorModel(**annotator.to_dict()) 41 | except Exception as e: 42 | return InfoModel(**{"message": "Failed", "error": str(e)}) 43 | 44 | 45 | # get annotator 46 | @router.get("/username/{username}") 47 | def get_annotator_by_username(username: str) -> Union[AnnotatorModel, InfoModel]: 48 | try: 49 | annotator = db_utils.get_annotator_by_username(username) 50 | return AnnotatorModel(**annotator.to_dict()) 51 | except Exception as e: 52 | return InfoModel(**{"message": "Failed", "error": str(e)}) 53 | 54 | 55 | # delete an annotator 56 | @router.delete("/{id}") 57 | def delete_annotator(id: int) -> Union[AnnotatorModel, InfoModel]: 58 | try: 59 | db_utils.delete_annotator(id) 60 | return InfoModel(**{"message": "Success"}) 61 | except Exception as e: 62 | return InfoModel(**{"message": "Failed", "error": str(e)}) 63 | 64 | 65 | # assign dataset that the annotator permitted to annotate 66 | @router.post("/{id}/datasets/{dataset_id}") 67 | def assign_annotator_to_dataset(id: int, dataset_id: int) -> InfoModel: 68 | try: 69 | db_utils.assign_annotator_to_dataset(id, dataset_id) 70 | return InfoModel(**{"message": "Success"}) 71 | except Exception as e: 72 | return InfoModel(**{"message": "Failed", "error": str(e)}) 73 | 74 | 75 | # get datasets that the annotator permitted to annotate 76 | @router.get("/{id}/datasets") 77 | def 
get_datasets_of_annotator(id: int) -> Union[List[DatasetModel], InfoModel]: 78 | try: 79 | datasets = db_utils.get_datasets_of_annotator(id) 80 | return [DatasetModel(**dataset.to_dict()) for dataset in datasets] 81 | except Exception as e: 82 | return InfoModel(**{"message": "Failed", "error": str(e)}) 83 | 84 | 85 | # get latest sample of the annotator 86 | @router.get("/{id}/samples/{dataset_id}/latest") 87 | def get_latest_sample_of_annotator(id: int, dataset_id: int) -> Union[SampleModel, InfoModel]: 88 | try: 89 | sample = db_utils.get_latest_sample_of_annotator(id, dataset_id) 90 | if sample: 91 | return SampleModel(**sample.to_dict()) 92 | return InfoModel(**{"message": "No sample found for this annotator in this dataset"}) 93 | except Exception as e: 94 | return InfoModel(**{"message": "Failed", "error": str(e)}) 95 | 96 | 97 | # # update an annotator 98 | # @router.put("/{id}") 99 | # def update_annotator(id: int, username: str, email: str) -> Union[AnnotatorModel, InfoModel]: 100 | # try: 101 | # annotator = {"username": username, "email": email } 102 | 103 | # annotator = db_utils.update_annotator(id, **annotator) # type: ignore 104 | # return AnnotatorModel(**annotator.to_dict()) # type: ignore # noqa: F821 105 | # except Exception as e: 106 | # return InfoModel(**{"message": "Failed", "error": str(e)}) 107 | -------------------------------------------------------------------------------- /src/service/api.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dotenv import load_dotenv 4 | from fastapi import FastAPI 5 | from fastapi_sqlalchemy import DBSessionMiddleware 6 | from starlette.middleware.cors import CORSMiddleware 7 | 8 | from src.logger import root_logger 9 | from src.paths import paths 10 | from src.service.annotators import router as annotators_router 11 | from src.service.datasets import router as datasets_router 12 | from src.service.samples import router as samples_router 13 | 
14 | 15 | app_logger = root_logger.getChild("api") 16 | 17 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 18 | # load the .env file 19 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 20 | 21 | 22 | app = FastAPI(title="TTS QA", openapi_url="/api/v1/openapi.json") 23 | 24 | app.add_middleware( 25 | CORSMiddleware, 26 | allow_origins=["*"], 27 | allow_credentials=True, 28 | allow_methods=["*"], 29 | allow_headers=["*"], 30 | ) 31 | 32 | app.add_middleware(DBSessionMiddleware, db_url=os.getenv("POSTGRES_URL")) 33 | 34 | app.logger = app_logger 35 | 36 | app.include_router(datasets_router) 37 | app.include_router(samples_router) 38 | app.include_router(annotators_router) 39 | 40 | 41 | @app.get("/") 42 | def read_root(): 43 | return {"message": "Welcome to the TTS QA API"} 44 | -------------------------------------------------------------------------------- /src/service/bases.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, Field 4 | from pydantic_sqlalchemy import sqlalchemy_to_pydantic 5 | 6 | from src.service.models import Annotation, Annotator, Dataset, Sample 7 | 8 | 9 | BaseAnnotatorModel = sqlalchemy_to_pydantic(Annotator) 10 | BaseAnnotationModel = sqlalchemy_to_pydantic(Annotation) 11 | BaseSampleModel = sqlalchemy_to_pydantic(Sample) 12 | BaseDatasetModel = sqlalchemy_to_pydantic(Dataset) 13 | 14 | 15 | class AnnotatorModel(BaseAnnotatorModel): # type: ignore 16 | """The annotation model.""" 17 | 18 | pass 19 | 20 | 21 | class AnnotationModel(BaseAnnotationModel): # type: ignore 22 | """The annotation model.""" 23 | 24 | pass 25 | 26 | 27 | class SampleModel(BaseSampleModel): # type: ignore 28 | """The sample model.""" 29 | 30 | pass 31 | 32 | 33 | class DatasetModel(BaseDatasetModel): # type: ignore 34 | """The dataset model.""" 35 | 36 | pass 37 | 38 | 39 | class InfoModel(BaseModel): 40 | """The error model.""" 41 | 42 | message: str = 
Field(..., description="The error message") 43 | # error field might be empty or contain the error message 44 | error: Optional[str] = Field(default=None, description="The error message") 45 | 46 | class Config: 47 | schema_extra = { 48 | "example": { 49 | "message": "Failed", 50 | "error": "The dataset does not exist", 51 | } 52 | } 53 | 54 | 55 | class InputAnnotationModel(BaseModel): 56 | """The input annotation model.""" 57 | 58 | annotator_id: int = Field(..., description="The annotator id") 59 | final_text: str = Field(..., description="The final text") 60 | final_sentence_type: str = Field(..., description="The final sentence type") 61 | isRepeated: bool = Field(..., description="The sample is repeated") 62 | # isAccentRight: bool = Field(..., description="The accent is right") 63 | # isPronunciationRight: bool = Field(..., description="The pronunciation is right") 64 | # isClean: bool = Field(..., description="The sample is clean") 65 | # isPausesRight: bool = Field(..., description="The pauses are right") 66 | # isSpeedRight: bool = Field(..., description="The speed is right") 67 | # isConsisent: bool = Field(..., description="The sample is consistent") 68 | incorrectProsody: bool = Field(..., description="The sample has incorrect prosody") 69 | inconsistentTextAudio: bool = Field(..., description="The sample has inconsistent text and audio") 70 | incorrectTrancuation: bool = Field(..., description="The sample has incorrect trancuation") 71 | soundArtifacts: bool = Field(..., description="The sample has sound artifacts") 72 | feedback: str = Field(default=None, description="The feedback") 73 | status: str = Field(default="NotReviewed", description="The status") 74 | -------------------------------------------------------------------------------- /src/service/datasets.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import traceback 3 | from typing import List, Union 4 | 5 | from fastapi import 
APIRouter 6 | 7 | from src.logger import root_logger 8 | from src.service.bases import AnnotationModel, AnnotatorModel, DatasetModel, InfoModel, SampleModel # noqa: F401 9 | from src.utils import db_utils 10 | 11 | 12 | router = APIRouter(prefix="/datasets", tags=["datasets"]) 13 | 14 | app_logger = root_logger.getChild("api/datasets") 15 | 16 | 17 | # list datasets 18 | @router.get("/") 19 | def list_datasets() -> List[DatasetModel]: 20 | datasets = db_utils.list_datasets() 21 | # map the datasets to the DatasetModel 22 | return [DatasetModel(**dataset.to_dict()) for dataset in datasets] 23 | 24 | 25 | # create a dataset 26 | @router.post("/{name}") 27 | def create_dataset(name: str, language: str, description: str = None) -> Union[DatasetModel, InfoModel]: 28 | try: 29 | dataset = db_utils.create_dataset(name=name, language=language, description=description) 30 | return DatasetModel(**dataset.to_dict()) 31 | except Exception as e: 32 | return InfoModel(**{"message": "Failed", "error": str(e)}) 33 | 34 | 35 | # get dataset 36 | @router.get("/{id}") 37 | def get_dataset_by_id(id: int) -> Union[DatasetModel, InfoModel]: 38 | try: 39 | dataset = db_utils.get_dataset_by_id(id) 40 | return DatasetModel(**dataset.to_dict()) 41 | except Exception as e: 42 | return InfoModel(**{"message": "Failed", "error": str(e)}) 43 | 44 | 45 | # delete a dataset 46 | @router.delete("/{id}") 47 | def delete_dataset(id: int) -> Union[DatasetModel, InfoModel]: 48 | try: 49 | db_utils.delete_dataset(id) 50 | return InfoModel(**{"message": "Success"}) 51 | except Exception as e: 52 | return InfoModel(**{"message": "Failed", "error": str(e)}) 53 | 54 | 55 | # update a dataset 56 | @router.put("/{id}") 57 | def update_dataset(id: int, name: str = None, language: str = None, description: str = None) -> Union[DatasetModel, InfoModel]: 58 | try: 59 | dataset = {"name": name, "language": language, "description": description} 60 | dataset = db_utils.update_dataset(id, **dataset) # type: ignore 61 
| return DatasetModel(**dataset.to_dict()) # type: ignore 62 | except Exception as e: # type: ignore 63 | return InfoModel(**{"message": "Failed", "error": str(e)}) 64 | 65 | 66 | # list all samples 67 | @router.get("/{id}/samples") 68 | def list_samples(id: int, top_k=50) -> List[SampleModel]: 69 | samples = db_utils.list_samples(id, top_k) 70 | # map the samples to the SampleModel 71 | return [SampleModel(**sample.to_dict()) for sample in samples] 72 | 73 | 74 | # insert a sample 75 | @router.post("/{id}/samples") 76 | def insert_sample( 77 | id: int, text: str, audio_path: str, sentence_length: int = None, sentence_type: str = "statement", deliverable: str = None 78 | ) -> Union[SampleModel, InfoModel]: 79 | if sentence_length is None: 80 | sentence_length = len(text.split()) 81 | 82 | try: 83 | sample = db_utils.insert_sample(id, text, audio_path, sentence_type, sentence_length, deliverable=deliverable) 84 | return SampleModel(**sample.to_dict()) 85 | except Exception as e: 86 | return InfoModel(**{"message": "Failed", "error": str(e)}) 87 | 88 | 89 | # get annotators allowed to annotate this dataset 90 | @router.get("/{id}/annotators") 91 | def get_annotators_of_dataset(id: int) -> Union[List[AnnotatorModel], InfoModel]: 92 | try: 93 | annotators = db_utils.get_annotators_of_dataset(id) 94 | return [AnnotatorModel(**annotator.to_dict()) for annotator in annotators] 95 | except Exception as e: 96 | return InfoModel(**{"message": "Failed", "error": str(e)}) 97 | 98 | 99 | # query next sample 100 | @router.get("/{id}/next_sample") 101 | def query_next_sample(id: int) -> dict: 102 | try: 103 | sample, stats = db_utils.query_next_sample(id) 104 | if sample is None: 105 | return {"sample": None, "stats": stats} 106 | return {"sample": SampleModel(**sample.to_dict()), "stats": stats} # type: ignore 107 | except Exception as e: 108 | return {"message": "Failed", "error": str(e)} 109 | 110 | 111 | # get the annotations of dataset samples 112 | 
@router.get("/{id}/annotations") 113 | def get_annotations_of_dataset(id: int) -> Union[List[dict], InfoModel]: 114 | try: 115 | annotations = db_utils.get_annotations_of_dataset(id) 116 | return [annotation for annotation in annotations] 117 | except Exception as e: 118 | return InfoModel(**{"message": "Failed", "error": str(e)}) 119 | 120 | 121 | def handle_exceptions(task: asyncio.Task): 122 | if task.exception(): 123 | print(f"An error occurred in the task: {task.exception()}") 124 | 125 | 126 | from src.service.tasks import segmented_onboarding_job, unsegmented_onboarding_job, unsegmented_onboarding_job_sync 127 | 128 | 129 | @router.get("/{id}/upload_segmented_async") 130 | def upload_segmented(id, csv_path: str, deliverable: str = None): 131 | job = segmented_onboarding_job.delay(dataset_id=id, csv_path=csv_path, deliverable=deliverable) 132 | return {"job_id": job.id} 133 | 134 | 135 | @router.get("/check_job_status/{job_id}") 136 | def check_job_status(job_id: str): 137 | job = segmented_onboarding_job.AsyncResult(job_id) 138 | if job.state == "SUCCESS": 139 | progress = 100 140 | elif job.state == "PENDING": 141 | progress = 0 142 | else: 143 | progress = job.info.get("progress", 0) 144 | if job.info is None: 145 | return {"status": job.status, "progress": progress, "onboarded_samples": 0, "failed_samples": []} 146 | return { 147 | "status": job.status, 148 | "progress": progress, 149 | "onboarded_samples": job.info.get("onboarded_samples", 0), 150 | "failed_samples": job.info.get("failed_samples", []), 151 | } 152 | 153 | 154 | @router.get("/{id}/upload_unsegmented_async") 155 | def upload_unsegmented_async(id, wavs_path: str, csv_path: str, start_id_regex: str, end_id_regex: str, deliverable: str = None): 156 | # get dataset language 157 | dataset = db_utils.get_dataset_by_id(id) 158 | job = unsegmented_onboarding_job.delay( 159 | dataset_id=id, 160 | language=dataset.language, 161 | wavs_path=wavs_path, 162 | csv_path=csv_path, 163 | 
start_id_regex=start_id_regex, 164 | end_id_regex=end_id_regex, 165 | deliverable=deliverable, 166 | ) 167 | return {"job_id": job.id} 168 | 169 | 170 | @router.get("/{id}/upload_unsegmented_sync") 171 | def upload_unsegmented_sync(id, wavs_path: str, csv_path: str, start_id_regex: str, end_id_regex: str, deliverable: str = None): 172 | # get dataset language 173 | 174 | dataset = db_utils.get_dataset_by_id(id) 175 | try: 176 | unsegmented_onboarding_job_sync( 177 | dataset_id=id, 178 | language=dataset.language, 179 | wavs_path=wavs_path, 180 | csv_path=csv_path, 181 | start_id_regex=start_id_regex, 182 | end_id_regex=end_id_regex, 183 | deliverable=deliverable, 184 | ) 185 | return {"message": "Success"} 186 | except Exception as e: 187 | app_logger.error(f"{traceback.format_exc()}") 188 | return {"message": "Failed", "error": str(e)} 189 | -------------------------------------------------------------------------------- /src/service/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, func, Integer, MetaData, String, Table, UniqueConstraint 4 | from sqlalchemy.ext.declarative import declarative_base 5 | from sqlalchemy.orm import backref, relationship 6 | 7 | 8 | Base = declarative_base(metadata=MetaData()) 9 | 10 | 11 | class Status(enum.Enum): 12 | Reviewed = "Reviewed" 13 | Discarded = "Discarded" 14 | NotReviewed = "NotReviewed" 15 | 16 | 17 | annotator_dataset = Table( 18 | "annotator_dataset", 19 | Base.metadata, 20 | Column("annotator_id", Integer, ForeignKey("annotator.id"), nullable=True), 21 | Column("dataset_id", Integer, ForeignKey("dataset.id"), nullable=True), 22 | ) 23 | 24 | 25 | # Define a Annotator model in qhich we will store the annotators's username and email address 26 | class Annotator(Base): # type: ignore 27 | __tablename__ = "annotator" 28 | id = Column(Integer, primary_key=True, index=True) 29 | 
username = Column(String(50), unique=True, nullable=False) 30 | name = Column(String(50), unique=False, nullable=False) 31 | email = Column(String(120), unique=True, nullable=False) 32 | hashed_password = Column(String(120), unique=False, nullable=False) 33 | ispreauthorized = Column(Boolean, default=True) 34 | isadmin = Column(Boolean, default=False) 35 | # defines 36 | datasets = relationship("Dataset", secondary=annotator_dataset, backref=backref("assigned_annotators", passive_deletes=True)) 37 | 38 | # add unique constraint to username 39 | __table_args__ = ( 40 | UniqueConstraint("username", name="_username_uc"), 41 | UniqueConstraint("email", name="_email_uc"), 42 | ) 43 | 44 | def __repr__(self): 45 | return f"{self.to_dict()}" 46 | 47 | def to_dict(self): 48 | return { 49 | "id": self.id, 50 | "username": self.username, 51 | "name": self.name, 52 | "email": self.email, 53 | "hashed_password": "********", 54 | "ispreauthorized": self.ispreauthorized, 55 | "isadmin": self.isadmin, 56 | } 57 | 58 | 59 | # Define a Sample model in which we will store following nformation for an tts recording sample: 60 | # id, unique filename, s3RawPath, original text, asr text, the duration of the recording, sentence_tyoe, 61 | class Sample(Base): # type: ignore 62 | __tablename__ = "sample" 63 | id = Column(Integer, primary_key=True, index=True) 64 | dataset_id = Column(Integer, ForeignKey("dataset.id")) 65 | deliverable = Column(String(50), unique=False, nullable=True) 66 | created_at = Column(DateTime, default=func.now()) 67 | filename = Column(String(50), unique=False, nullable=False) 68 | local_path = Column(String(120), unique=False, nullable=False) 69 | local_trimmed_path = Column(String(120), unique=False, nullable=True) 70 | s3RawPath = Column(String(120), unique=True, nullable=False) 71 | s3TrimmedPath = Column(String(120), unique=True, nullable=True) 72 | original_text = Column(String(250), unique=False, nullable=False) 73 | asr_text = Column(String(250), 
unique=False, nullable=True) 74 | duration = Column(Float, unique=False, nullable=False) 75 | trimmed_audio_duration = Column(Float, unique=False, nullable=True) 76 | sentence_type = Column(String(50), unique=False, nullable=False) 77 | sentence_length = Column(Integer, unique=False, nullable=False) 78 | sampling_rate = Column(Integer, unique=False, nullable=False) 79 | sample_format = Column(String(10), unique=False, nullable=False) 80 | isPCM = Column(Boolean, unique=False, nullable=False) 81 | n_channel = Column(Integer, unique=False, nullable=False) 82 | format = Column(String(10), unique=False, nullable=False) 83 | peak_volume_db = Column(Float, unique=False, nullable=False) 84 | size = Column(Integer, unique=False, nullable=False) 85 | isValid = Column(Boolean, unique=False, nullable=False) 86 | trim_start = Column(Float, unique=False, nullable=True) 87 | trim_end = Column(Float, unique=False, nullable=True) 88 | longest_pause = Column(Float, unique=False, nullable=True) 89 | wer = Column(Float, unique=False, nullable=True) 90 | uncased_unpunctuated_wer = Column(Float, unique=False, nullable=True) 91 | islocked = Column(Boolean, default=False, nullable=False) # this is for locking sample that is being annotated 92 | locked_at = Column(DateTime, default=None, nullable=True) 93 | is_selected_for_delivery = Column(Boolean, default=False, nullable=True) 94 | 95 | annotation = relationship("Annotation", cascade="all, delete-orphan", backref="sample") 96 | __table_args__ = ( 97 | UniqueConstraint("s3TrimmedPath", name="_s3TrimmedPath_uc"), 98 | UniqueConstraint("s3RawPath", name="_s3RawPath_uc"), 99 | # UniqueConstraint("filename", name="_filename_uc"), 100 | # dataset id and sample filename should be unique 101 | UniqueConstraint("dataset_id", "filename", name="_dataset_id_filename_uc"), 102 | ) # Example for such cases combination of filename and s3RawPath should be unique 103 | 104 | def __repr__(self): 105 | return f"{self.to_dict()}" 106 | 107 | def 
to_dict(self): 108 | return { 109 | "id": self.id, 110 | "dataset_id": self.dataset_id, 111 | "deliverable": self.deliverable, 112 | "created_at": self.created_at, 113 | "filename": self.filename, 114 | "local_path": self.local_path, 115 | "local_trimmed_path": self.local_trimmed_path, 116 | "s3RawPath": self.s3RawPath, 117 | "s3TrimmedPath": self.s3TrimmedPath, 118 | "original_text": self.original_text, 119 | "duration": self.duration, 120 | "trimmed_audio_duration": self.trimmed_audio_duration, 121 | "sentence_type": self.sentence_type, 122 | "sentence_length": self.sentence_length, 123 | "sampling_rate": self.sampling_rate, 124 | "sample_format": self.sample_format, 125 | "isPCM": self.isPCM, 126 | "n_channel": self.n_channel, 127 | "format": self.format, 128 | "peak_volume_db": self.peak_volume_db, 129 | "size": self.size, 130 | "isValid": self.isValid, 131 | "asr_text": self.asr_text, 132 | "trim_start": self.trim_start, 133 | "trim_end": self.trim_end, 134 | "longest_pause": self.longest_pause, 135 | "wer": self.wer, 136 | } 137 | 138 | 139 | # Define a Annotation Model in which we will store the following information for an annotation: 140 | # id, annotator_id, sample_id, the date and time when the annotation was created and annotation fields 141 | # status Enumeration y defauld it is NULL, Approved, Rejected 142 | # isAccentRight bool default is NULL, True, False 143 | # isPronunciationRight bool default is NULL, True, False 144 | # isTypeRight bool default is NULL, True, False 145 | # isClean bool default is NULL, True, False 146 | # isPausesRight bool default is NULL, True, False 147 | # isSpeedRight bool default is NULL, True, False 148 | # isConsisent bool default is NULL, True, False 149 | # feedback text 150 | class Annotation(Base): # type: ignore 151 | __tablename__ = "annotation" 152 | id = Column(Integer, primary_key=True) 153 | annotator_id = Column(Integer, ForeignKey("annotator.id"), nullable=True) 154 | sample_id = Column(Integer, 
ForeignKey("sample.id"), nullable=False) 155 | created_at = Column(DateTime, default=func.now()) 156 | updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) 157 | status = Column(Enum(Status), default=Status.NotReviewed) 158 | final_text = Column(String(250), unique=False, nullable=True) 159 | final_sentence_type = Column(String(50), unique=False, nullable=True) 160 | isRepeated = Column(Boolean, default=None, nullable=True) 161 | isAccentRight = Column(Boolean, default=None, nullable=True) 162 | isPronunciationRight = Column(Boolean, default=None, nullable=True) 163 | isClean = Column(Boolean, default=None, nullable=True) 164 | isPausesRight = Column(Boolean, default=None, nullable=True) 165 | isSpeedRight = Column(Boolean, default=None, nullable=True) 166 | isConsisent = Column(Boolean, default=None, nullable=True) 167 | feedback = Column(String(250), unique=False, nullable=True) 168 | 169 | # Additional fields 170 | incorrectProsody = Column(Boolean, default=None, nullable=True) 171 | inconsistentTextAudio = Column(Boolean, default=None, nullable=True) 172 | incorrectTrancuation = Column(Boolean, default=None, nullable=True) 173 | soundArtifacts = Column(Boolean, default=None, nullable=True) 174 | 175 | annotator = relationship("Annotator", backref=backref("annotations", passive_deletes=True)) 176 | 177 | __table_args__ = (UniqueConstraint("annotator_id", "sample_id", name="_annotator_sample_uc"),) 178 | 179 | def __repr__(self): 180 | return f"{self.to_dict()}" 181 | 182 | def to_dict(self): 183 | return { 184 | "id": self.id, 185 | "annotator_id": self.annotator_id, 186 | "sample_id": self.sample_id, 187 | "created_at": self.created_at, 188 | "status": self.status, 189 | "final_text": self.final_text, 190 | "final_sentence_type": self.final_sentence_type, 191 | "isRepeated": self.isRepeated, 192 | "isAccentRight": self.isAccentRight, 193 | "isPronunciationRight": self.isPronunciationRight, 194 | "isClean": self.isClean, 195 | "isPausesRight": 
self.isPausesRight, 196 | "isSpeedRight": self.isSpeedRight, 197 | "isConsisent": self.isConsisent, 198 | "incorrectProsody": self.incorrectProsody, 199 | "inconsistentTextAudio": self.inconsistentTextAudio, 200 | "incorrectTrancuation": self.incorrectTrancuation, 201 | "soundArtifacts": self.soundArtifacts, 202 | "feedback": self.feedback, 203 | } 204 | 205 | 206 | # Define a Dataset model in which we will store the following information for a dataset: 207 | # id, name, description, the date and time when the dataset was created, list of sampes in the dataset 208 | class Dataset(Base): # type: ignore 209 | __tablename__ = "dataset" 210 | id = Column(Integer, primary_key=True, index=True) 211 | name = Column(String(50), unique=True, nullable=False) 212 | recorder = Column(String(50), unique=False, nullable=True) 213 | language = Column(String(5), unique=False, nullable=False) 214 | description = Column(String(250), unique=False, nullable=True) 215 | created_at = Column(DateTime, default=func.now()) 216 | 217 | samples = relationship("Sample", cascade="all, delete", backref="dataset") 218 | annotators = relationship("Annotator", secondary=annotator_dataset, backref=backref("assigned_datasets", passive_deletes=True)) 219 | 220 | __table_args__ = (UniqueConstraint("name", name="_name_uc"),) 221 | 222 | def __repr__(self): 223 | return f"{self.to_dict()}" 224 | 225 | def to_dict(self): 226 | return {"id": self.id, "name": self.name, "description": self.description, "language": self.language, "created_at": self.created_at} 227 | -------------------------------------------------------------------------------- /src/service/samples.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter 4 | 5 | 6 | router = APIRouter(prefix="/samples", tags=["samples"]) 7 | 8 | from src.logger import root_logger 9 | from src.service.bases import InfoModel, InputAnnotationModel, SampleModel # 
noqa: F401 10 | from src.utils import db_utils 11 | 12 | 13 | app_logger = root_logger.getChild("api/samples") 14 | 15 | 16 | # delete a sample 17 | @router.delete("{id}") 18 | def delete_sample(id: int) -> InfoModel: 19 | try: 20 | db_utils.delete_sample(id) 21 | return InfoModel(**{"message": "Success"}) 22 | except Exception as e: 23 | return InfoModel(**{"message": "Failed", "error": str(e)}) 24 | 25 | 26 | # list samples 27 | @router.get("/") 28 | def list_samples() -> List[SampleModel]: 29 | samples = db_utils.list_samples() # type: ignore 30 | # map the samples to the SampleModel 31 | return [SampleModel(**sample.to_dict()) for sample in samples] 32 | 33 | 34 | # get a sample 35 | @router.get("/{id}") 36 | def get_sample_by_id(id: int) -> SampleModel: 37 | sample = db_utils.get_sample_by_id(id) 38 | return SampleModel(**sample.to_dict()) 39 | 40 | 41 | # annotate a sample 42 | @router.put("/{id}") 43 | def annotate_sample(id: int, annotation: InputAnnotationModel) -> InfoModel: 44 | try: 45 | db_utils.annotate_sample(sample_id=id, **dict(annotation)) 46 | return InfoModel(**{"message": "Success"}) 47 | except Exception as e: 48 | return InfoModel(**{"message": "Failed", "error": str(e)}) 49 | 50 | 51 | # lock sample for annotation 52 | @router.put("/{id}/lock") 53 | def lock_sample(id: int) -> InfoModel: 54 | try: 55 | db_utils.lock_sample(id) 56 | return InfoModel(**{"message": "Success"}) 57 | except Exception as e: 58 | return InfoModel(**{"message": "Failed", "error": str(e)}) 59 | 60 | 61 | # unlock sample for annotation 62 | @router.put("/{id}/unlock") 63 | def unlock_sample(id: int) -> InfoModel: 64 | try: 65 | db_utils.unlock_sample(id) 66 | return InfoModel(**{"message": "Success"}) 67 | except Exception as e: 68 | return InfoModel(**{"message": "Failed", "error": str(e)}) 69 | -------------------------------------------------------------------------------- /src/service/tasks.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from celery import Celery, Task 5 | from dotenv import load_dotenv 6 | 7 | from src.logger import root_logger 8 | from src.paths import paths 9 | from src.utils.alignment_utils import align_wavs_vad, align_wavs_whisper # noqa F401 10 | from src.utils.db_utils import upload_wav_samples 11 | 12 | 13 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 14 | # load the .env file 15 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 16 | 17 | app_logger = root_logger.getChild("celery") 18 | s3_bucket_name = os.environ.get("S3_BUCKET_NAME") 19 | s3_dataset_dir = os.environ.get("S3_DATASET_DIR") 20 | 21 | 22 | # get engine from url 23 | POSTGRES_URL = os.getenv("POSTGRES_URL") 24 | 25 | 26 | from sqlalchemy import create_engine 27 | from sqlalchemy.orm import sessionmaker 28 | 29 | 30 | engine = create_engine(POSTGRES_URL) 31 | SessionObject = sessionmaker(bind=engine) 32 | session = SessionObject() 33 | 34 | 35 | # Create the Celery app 36 | app = Celery("TTS-QA") 37 | 38 | # Configure the broker and result backend 39 | app.conf.broker_url = "redis://localhost:6379/0" 40 | app.conf.result_backend = "redis://localhost:6379/0" 41 | 42 | 43 | @app.task(bind=True) 44 | def segmented_onboarding_job(self: Task, dataset_id: int, csv_path: str, deliverable: str = None): 45 | # Simulate a long-running job 46 | app_logger.info("Starting segmented onboarding job") 47 | upload_wav_samples(self, session, dataset_id, csv_path, deliverable=deliverable) 48 | 49 | 50 | @app.task(bind=True) 51 | def unsegmented_onboarding_job( 52 | self: Task, dataset_id: int, language: str, wavs_path: str, csv_path: str, start_id_regex: str, end_id_regex: str, deliverable: str = None 53 | ): 54 | 55 | app_logger.info("Starting unsegmented onboarding job") 56 | 57 | app_logger.info(f"dataset_id: {dataset_id}") 58 | app_logger.info(f"wavs_path: {wavs_path}") 59 | app_logger.info(f"csv_path: 
{csv_path}") 60 | app_logger.info(f"deliverable: {deliverable}") 61 | if language == "en": 62 | # do alignment first and then upload 63 | # aligned_wavs_dir, aligned_csv_path = align_wavs_whisper(self, wavs_path, csv_path, language, start_id_regex, end_id_regex, assigned_only=True) 64 | aligned_wavs_dir, aligned_csv_path = align_wavs_vad(self, wavs_path, csv_path, language, start_id_regex, end_id_regex, assigned_only=True) 65 | else: 66 | # do alignment first and then upload 67 | aligned_wavs_dir, aligned_csv_path = align_wavs_vad(self, wavs_path, csv_path, language, start_id_regex, end_id_regex, assigned_only=True) 68 | 69 | app_logger.debug(f"aligned_wavs_dir: {aligned_wavs_dir}") 70 | 71 | # Simulate a long-running job 72 | upload_wav_samples(self, session, dataset_id, aligned_csv_path, deliverable=deliverable) 73 | 74 | # TODO: make sure that you keep the aligned csv 75 | shutil.rmtree(wavs_path, ignore_errors=True) 76 | shutil.rmtree(csv_path, ignore_errors=True) 77 | 78 | 79 | def unsegmented_onboarding_job_sync( 80 | dataset_id: int, language: str, wavs_path: str, csv_path: str, start_id_regex: str, end_id_regex: str, deliverable: str = None 81 | ): 82 | 83 | app_logger.info("Starting unsegmented onboarding job") 84 | 85 | app_logger.info(f"dataset_id: {dataset_id}") 86 | app_logger.info(f"wavs_path: {wavs_path}") 87 | app_logger.info(f"csv_path: {csv_path}") 88 | app_logger.info(f"deliverable: {deliverable}") 89 | 90 | if language == "en": 91 | # do alignment first and then upload 92 | # aligned_wavs_dir, aligned_csv_path = align_wavs_whisper(None, wavs_path, csv_path, language, start_id_regex, end_id_regex, assigned_only=True) 93 | aligned_wavs_dir, aligned_csv_path = align_wavs_vad(None, wavs_path, csv_path, language, start_id_regex, end_id_regex, assigned_only=True) 94 | else: 95 | # do alignment first and then upload 96 | aligned_wavs_dir, aligned_csv_path = align_wavs_vad(None, wavs_path, csv_path, language, start_id_regex, end_id_regex, 
assigned_only=True) 97 | 98 | # TODO: make sure that you keep the aligned csv 99 | shutil.rmtree(wavs_path, ignore_errors=True) 100 | shutil.rmtree(csv_path, ignore_errors=True) 101 | 102 | app_logger.debug(f"aligned_wavs_dirh: {aligned_wavs_dir}") 103 | 104 | # Simulate a long-running job 105 | upload_wav_samples(None, session, dataset_id, aligned_csv_path, deliverable=deliverable) 106 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aixplain/tts-qa/288a920d37cfd75f5e1aec6d34efe7b624582d0f/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.paths import paths 3 | from dotenv import find_dotenv, load_dotenv 4 | 5 | load_dotenv(find_dotenv(paths.PROJECT_ROOT_DIR / "vars.env"), override=True) 6 | load_dotenv(find_dotenv(paths.PROJECT_ROOT_DIR / "secrets.env"), override=True) 7 | 8 | import time 9 | 10 | import librosa 11 | import pandas as pd 12 | import soundfile as sf 13 | from aixplain.factories.model_factory import ModelFactory 14 | from pyannote.audio import Model 15 | from pyannote.audio.pipelines import VoiceActivityDetection 16 | from pydub import AudioSegment 17 | from pydub.utils import mediainfo 18 | 19 | 20 | modelPyannote = Model.from_pretrained( 21 | "pyannote/segmentation", use_auth_token=os.getenv("HUGGINGFACE_TOKEN") 22 | ) 23 | vad_pipeline = VoiceActivityDetection(segmentation=modelPyannote) 24 | HYPER_PARAMETERS = { 25 | # onset/offset activation thresholds 26 | "onset": 0.5, 27 | "offset": 0.5, 28 | # remove speech regions shorter than that many seconds. 29 | "min_duration_on": 0.0, 30 | # fill non-speech regions shorter than that many seconds. 
31 | "min_duration_off": 0.05, 32 | } 33 | vad_pipeline.instantiate(HYPER_PARAMETERS) 34 | 35 | 36 | api_keys_azure = { 37 | "en": { 38 | "id": "62fab6ecb39cca09ca5bc378", 39 | "model": ModelFactory.create_asset_from_id("62fab6ecb39cca09ca5bc378"), 40 | }, 41 | "es": { 42 | "id": "62fab6ecb39cca09ca5bc375", 43 | "model": ModelFactory.create_asset_from_id("62fab6ecb39cca09ca5bc375"), 44 | }, 45 | "fr": { 46 | "id": "62fab6ecb39cca09ca5bc389", 47 | "model": ModelFactory.create_asset_from_id("62fab6ecb39cca09ca5bc389"), 48 | }, 49 | "it": { 50 | "id": "62fab6ecb39cca09ca5bc353", 51 | "model": ModelFactory.create_asset_from_id("62fab6ecb39cca09ca5bc353"), 52 | }, 53 | "de": { 54 | "id": "62fab6ecb39cca09ca5bc334", 55 | "model": ModelFactory.create_asset_from_id("62fab6ecb39cca09ca5bc334"), 56 | }, 57 | } 58 | 59 | 60 | api_keys_aws = { 61 | "en": { 62 | "id": "60ddef908d38c51c5885dd1e", 63 | "model": ModelFactory.create_asset_from_id("60ddef908d38c51c5885dd1e"), 64 | }, 65 | "es": { 66 | "id": "60ddefd68d38c51c588608c6", 67 | "model": ModelFactory.create_asset_from_id("60ddefd68d38c51c588608c6"), 68 | }, 69 | "fr": { 70 | "id": "60ddefde8d38c51c58860d8d", 71 | "model": ModelFactory.create_asset_from_id("60ddefde8d38c51c58860d8d"), 72 | }, 73 | "it": { 74 | "id": "60ddefa38d38c51c5885e979", 75 | "model": ModelFactory.create_asset_from_id("60ddefa38d38c51c5885e979"), 76 | }, 77 | "de": { 78 | "id": "60ddefc48d38c51c5885fd69", 79 | "model": ModelFactory.create_asset_from_id("60ddefc48d38c51c5885fd69"), 80 | }, 81 | } 82 | 83 | 84 | def asr_and_trim_azure(s3path, language="en"): 85 | model = api_keys_azure[language]["model"] 86 | response = {} 87 | count = 0 88 | while count < 3 and response == {}: 89 | try: 90 | model_response = model.run(data=s3path, name=f"ASR model ({language})") 91 | if model_response["status"] == "SUCCESS": 92 | details = model_response["details"] 93 | 94 | df_details = pd.DataFrame(details) 95 | df_details.head() 96 | 97 | df_details["pauses"] = ( 
98 | df_details["start_time"].shift(-1) - df_details["end_time"] 99 | ) 100 | df_details["pauses"] = df_details["pauses"].fillna(0) 101 | transcription = " ".join(df_details["text"]) 102 | 103 | start_time = df_details.loc[0, "start_time"] 104 | end_time = df_details.loc[len(df_details) - 1, "end_time"] 105 | audio_duration = end_time - start_time 106 | 107 | response["asr_text"] = transcription 108 | response["trim_start"] = start_time 109 | response["trim_end"] = end_time 110 | response["trimmed_audio_duration"] = audio_duration 111 | response["longest_pause"] = df_details["pauses"].max() 112 | return response 113 | except Exception as e: 114 | print(e) 115 | count += 1 116 | time.sleep(1) 117 | continue 118 | return { 119 | "asr_text": "", 120 | "trim_start": 0, 121 | "trim_end": 0, 122 | "trimmed_audio_duration": 0, 123 | "longest_pause": 0, 124 | } 125 | 126 | 127 | def trim_only(path): 128 | vad = vad_pipeline(path) 129 | timeline = vad.get_timeline().support() 130 | longest_pause = 0 131 | previous_end = 0 132 | # for timelines that has diff lover than 0.2 sec remove segment 133 | timeline_temp = [ 134 | segment for segment in timeline if list(segment)[1] - list(segment)[0] > 0.25 135 | ] 136 | if len(timeline_temp) > 0: 137 | timeline = timeline_temp 138 | for i, segment in enumerate(timeline): 139 | start, end = list(segment) 140 | if i == 0: 141 | start_time = start 142 | if i == len(timeline) - 1: 143 | end_time = end 144 | if i > 0: 145 | pause = start - previous_end 146 | if pause > longest_pause: 147 | longest_pause = pause 148 | previous_end = end 149 | try: 150 | audio_duration = end_time - start_time 151 | except: 152 | print(f"Error in audio duration calculation, do not triming file {path}") 153 | # get audio dur in secs 154 | audio_duration = librosa.get_duration(filename=path) 155 | end_time = audio_duration 156 | start_time = 0 157 | 158 | return { 159 | "trim_start": start_time, 160 | "trim_end": end_time, 161 | "trimmed_audio_duration": 
audio_duration, 162 | "longest_pause": longest_pause, 163 | } 164 | 165 | 166 | def asr_aws(s3path, language="en"): 167 | model = api_keys_aws[language]["model"] 168 | response = {} 169 | count = 0 170 | while count < 3 and response == {}: 171 | try: 172 | model_response = model.run(data=s3path, name=f"ASR model ({language})") 173 | if model_response["status"] == "SUCCESS": 174 | details = model_response["details"]["segments"] 175 | 176 | df_details = pd.DataFrame(details) 177 | df_details.dropna(inplace=True) 178 | 179 | transcription = " ".join(df_details["text"]) 180 | return transcription 181 | except Exception as e: 182 | print(e) 183 | count += 1 184 | time.sleep(1) 185 | continue 186 | return "" 187 | 188 | 189 | def asr_and_trim_aws(s3path, language="en"): 190 | model = api_keys_aws[language]["model"] 191 | response = {} 192 | count = 0 193 | while count < 3 and response == {}: 194 | try: 195 | model_response = model.run(data=s3path, name=f"ASR model ({language})") 196 | if model_response["status"] == "SUCCESS": 197 | details = model_response["details"]["segments"] 198 | 199 | df_details = pd.DataFrame(details) 200 | df_details.dropna(inplace=True) 201 | 202 | df_details["pauses"] = ( 203 | df_details["start_time"].shift(-1) - df_details["end_time"] 204 | ) 205 | df_details["pauses"] = df_details["pauses"].fillna(0) 206 | transcription = " ".join(df_details["text"]) 207 | 208 | start_time = df_details["start_time"].values[0] 209 | end_time = df_details["end_time"].values[-1] 210 | audio_duration = end_time - start_time 211 | 212 | response["asr_text"] = transcription 213 | response["trim_start"] = start_time 214 | response["trim_end"] = end_time 215 | response["trimmed_audio_duration"] = audio_duration 216 | response["longest_pause"] = df_details["pauses"].max() 217 | return response 218 | except Exception as e: 219 | print(e) 220 | count += 1 221 | time.sleep(1) 222 | continue 223 | return { 224 | "asr_text": "", 225 | "trim_start": 0, 226 | "trim_end": 0, 
227 | "trimmed_audio_duration": 0, 228 | "longest_pause": 0, 229 | } 230 | 231 | 232 | def evaluate_audio(path): 233 | response = {} 234 | info = mediainfo(path) 235 | sound = AudioSegment.from_file(path) 236 | response["Filepath"] = path 237 | response["file_type"] = os.path.splitext(path)[1] 238 | response["file_name"] = os.path.basename(path) 239 | response["sampling_rate"] = int(info["sample_rate"]) 240 | response["sample_format"] = info["sample_fmt"] 241 | response["format"] = info["format_name"] 242 | response["n_channel"] = int(info["channels"]) 243 | response["bit_rate"] = int(info["bit_rate"]) 244 | response["codec"] = info["codec_name"] 245 | response["peak_volume_db"] = sound.max_dBFS 246 | response["duration"] = float(info["duration"]) 247 | response["size"] = os.path.getsize(path) 248 | response["is_wav"] = True if info["format_name"] == "wav" else False 249 | response["is_mono"] = True if info["channels"] == "1" else False 250 | response["isPCM"] = True if response["codec"] == "pcm_s16le" else False 251 | response["is_16bit"] = True if response["sample_format"] == "s16" else False 252 | response["is_88khz"] = True if response["sampling_rate"] == 88000 else False 253 | 254 | is_valid = False 255 | if ( 256 | response["is_wav"] 257 | and response["is_mono"] 258 | and response["is_16bit"] 259 | and response["isPCM"] 260 | and response["is_88khz"] 261 | and response["peak_volume_db"] >= -6 262 | and response["peak_volume_db"] <= -3 263 | and response["duration"] > 0 264 | ): 265 | is_valid = True 266 | response["isValid"] = is_valid 267 | return response 268 | 269 | 270 | # convert the sampling rate to 88kHz 271 | def convert_to_88k(path, out_path): 272 | y, sr = librosa.load(path, sr=None) 273 | y_88k = librosa.resample(y, orig_sr=sr, target_sr=88000) 274 | sf.write(out_path, y_88k, 88000) 275 | return out_path 276 | 277 | 278 | # normalize the audio peak_volume_db to be between -6 and -3 db 279 | def normalize_audio(path, out_path): 280 | sound = 
AudioSegment.from_file(path, format="wav") 281 | if sound.max_dBFS > -3: 282 | normalized_sound = sound - (sound.max_dBFS + 3.5) 283 | elif sound.max_dBFS < -6: 284 | normalized_sound = sound + (-5.5 - sound.max_dBFS) 285 | else: 286 | normalized_sound = sound 287 | normalized_sound.export(out_path, format="wav") 288 | return out_path 289 | 290 | 291 | # trim the audio using start end end time in secs 292 | def trim_audio(path, start, end, out_path): 293 | sound = AudioSegment.from_file(path, format="wav") 294 | # make sure that the start and end are in between the audio duration 295 | start_time = max(0, start) 296 | end_time = min(end, len(sound) / 1000) 297 | trimmed_sound = sound[start_time * 1000 : end_time * 1000] 298 | trimmed_sound.export(out_path, format="wav") 299 | return out_path, start_time, end_time 300 | 301 | 302 | # convert the audio to mono 303 | def convert_to_mono(path, out_path): 304 | sound = AudioSegment.from_file(path, format="wav") 305 | mono_sound = sound.set_channels(1) 306 | mono_sound.export(out_path, format="wav") 307 | return out_path 308 | 309 | 310 | def convert_to_s16le(path, out_path): 311 | sound = AudioSegment.from_file(path, format="wav") 312 | s16le_sound = sound.set_sample_width(2) 313 | s16le_sound.export(out_path, format="wav") 314 | return out_path 315 | -------------------------------------------------------------------------------- /src/utils/decorators.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import wraps 3 | 4 | from src.logger import root_logger 5 | 6 | 7 | log = root_logger.getChild("timing") 8 | 9 | 10 | def log_timing(func): 11 | # we use a decorator to log the time it takes to execute a function 12 | # this saves boilerplate code in each function 13 | @wraps(func) 14 | def wrapper(*args, **kwargs): 15 | start = time.time() 16 | result = func(*args, **kwargs) 17 | duration = time.time() - start 18 | log.info(f"{func.__name__} took 
{duration} seconds ({duration / 60:.2f} minutes)") 19 | return result 20 | 21 | return wrapper 22 | -------------------------------------------------------------------------------- /src/utils/experiment_tracking.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | from pathlib import Path 5 | from typing import Optional, Tuple 6 | 7 | import wandb 8 | from dotenv import load_dotenv 9 | 10 | from src.enums import RunType 11 | from src.logger import root_logger 12 | 13 | 14 | log = root_logger.getChild("utils") 15 | # we read the WANDB_API_KEY from the secrets.env file 16 | load_dotenv("secrets.env") 17 | # and the WANDB_PROJECT and WANDB_ENTITY from the vars.env file 18 | load_dotenv("vars.env") 19 | 20 | 21 | root_path = Path(__file__).parent.parent 22 | 23 | 24 | with open(root_path / "models" / "configs" / "sweep_config.json") as f: 25 | sweep_config = json.load(f) 26 | 27 | 28 | def init_wandb_run( 29 | model_name: str, 30 | run_type: RunType, 31 | run_tag: str, 32 | run_dir: Path = None, 33 | collaborator: str = os.getlogin(), 34 | existing_wandb_run_id: str = None, 35 | config: dict = None, 36 | ) -> str: 37 | # with an existing wandb run id, we resume the run 38 | if existing_wandb_run_id: 39 | resume_wandb_run(existing_wandb_run_id) 40 | return existing_wandb_run_id 41 | 42 | # get variables from the environment 43 | wandb_project, wandb_entity = get_wandb_variables() 44 | # skip logging to wandb if the variables are not set 45 | if not wandb_project or not wandb_entity: 46 | log.warning("Skipping W&B logging.") 47 | return None 48 | 49 | # create the run and group names 50 | run_type = RunType(run_type) 51 | group_name, run_name = create_wandb_run_and_group_names( 52 | run_type=run_type, 53 | model_name=model_name, 54 | run_tag=run_tag, 55 | collaborator=collaborator, 56 | ) 57 | 58 | # initialize the run 59 | log.info("Initializing new W&B run") 60 | 
wandb_run = wandb.init( 61 | entity=wandb_entity, 62 | project=wandb_project, 63 | group=group_name, 64 | job_type=run_type.value, 65 | name=run_name, 66 | resume="allow", 67 | config=config, 68 | ) 69 | 70 | # if the run dir is not None, we save the run id to the run dir for easier resuming later 71 | if run_dir: 72 | run_dir = Path(run_dir) 73 | run_dir.mkdir(parents=True, exist_ok=True) 74 | with open(run_dir / "wandb_run_id", "w") as f: 75 | f.write(wandb_run.id) 76 | 77 | log.info(f"W&B run initialized: {wandb_run.id=}, {wandb_run.name=}") 78 | return wandb_run.id 79 | 80 | 81 | def init_wandb_sweep() -> str: 82 | # get variables from the environment 83 | wandb_project, wandb_entity = get_wandb_variables() 84 | 85 | sweep_id = wandb.sweep(sweep=sweep_config, project=wandb_project, entity=wandb_entity) 86 | return sweep_id 87 | 88 | 89 | def create_wandb_run_and_group_names( 90 | run_type: RunType, 91 | model_name: str, 92 | run_tag: str, 93 | collaborator: str = os.getlogin(), 94 | ) -> Tuple[str, str]: 95 | # if the tag is a model path, use the file stem 96 | if model_name and Path(model_name).exists(): 97 | model_name = Path(model_name).stem 98 | 99 | log.info(f"Creating W&B run and group names for {run_type=}, {model_name=}, {collaborator=}, {run_tag=}") 100 | run_type_str = RunType(run_type).value 101 | group_name = "-".join([collaborator, run_tag]) 102 | run_name = "-".join([run_type_str, model_name, datetime.now().strftime("%Y%m%d-%H%M%S")]) 103 | return group_name, run_name 104 | 105 | 106 | def resume_wandb_run(wandb_run_id: str) -> None: 107 | log.info(f"Resuming W&B run {wandb_run_id}") 108 | wandb_project = os.environ.get("WANDB_PROJECT", None) 109 | wandb_entity = os.environ.get("WANDB_ENTITY", None) 110 | wandb.init(id=wandb_run_id, project=wandb_project, entity=wandb_entity) 111 | 112 | 113 | def get_wandb_variables() -> Tuple[Optional[str], Optional[str]]: 114 | if os.environ.get("WANDB_DISABLED", False): 115 | log.warning("W&B logging is 
disabled as WANDB_DISABLED is set (run `unset WANDB_DISABLED` to re-enable)") 116 | return None, None 117 | 118 | if not os.environ.get("WANDB_API_KEY"): 119 | log.warning("WANDB_API_KEY not found in ENV. Make sure you have a secrets.env files with a value for it if you want to log to W&B.") 120 | 121 | wandb_project = os.environ.get("WANDB_PROJECT", None) 122 | if not wandb_project: 123 | log.warning("WANDB_PROJECT not found in ENV. Make sure you have a vars.env files with a value for it if you want to log to W&B.") 124 | 125 | wandb_entity = os.environ.get("WANDB_ENTITY", None) 126 | if not wandb_entity: 127 | log.warning("WANDB_ENTITY not found in ENV. Make sure you have a vars.env files with a value for it if you want to log to W&B.") 128 | 129 | return wandb_project, wandb_entity 130 | -------------------------------------------------------------------------------- /src/utils/trim_asr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | 4 | 5 | logging.basicConfig(level=logging.DEBUG) 6 | import os 7 | 8 | import librosa 9 | from tqdm import tqdm 10 | 11 | from src.paths import paths 12 | from dotenv import find_dotenv, load_dotenv 13 | load_dotenv(find_dotenv(paths.PROJECT_ROOT_DIR / "vars.env"), override=True) 14 | 15 | import sys 16 | from concurrent.futures import as_completed, ThreadPoolExecutor # noqa: F401 17 | from pathlib import Path 18 | 19 | import boto3 20 | import botocore 21 | from dotenv import load_dotenv 22 | from sqlalchemy import create_engine 23 | from sqlalchemy.orm import sessionmaker 24 | 25 | 26 | sys.path.append(str(Path(__file__).resolve().parents[2])) 27 | 28 | 29 | from src.logger import root_logger 30 | from src.paths import paths 31 | from src.service.models import Annotation, Annotator, Base, Dataset, Sample # noqa: F401 32 | from src.utils import utils 33 | from src.utils.audio import asr_and_trim_aws, asr_and_trim_azure, asr_aws, trim_audio, trim_only 
34 | 35 | 36 | app_logger = root_logger.getChild("trimmer") 37 | 38 | 39 | BASE_DIR = paths.PROJECT_ROOT_DIR 40 | 41 | 42 | if load_dotenv(os.path.join(BASE_DIR, "vars.env")): 43 | app_logger.info("Loaded env vars from vars.env") 44 | else: 45 | app_logger.error("Failed to load env vars from vars.env") 46 | exit(1) 47 | 48 | 49 | # get engine from url 50 | POSTGRES_URL = os.getenv("POSTGRES_URL") 51 | 52 | engine = create_engine(POSTGRES_URL) 53 | Session = sessionmaker(bind=engine) 54 | session = Session() 55 | 56 | 57 | client_config = botocore.config.Config(max_pool_connections=50) 58 | s3 = boto3.client( 59 | "s3", aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), config=client_config 60 | ) 61 | bucket_name = os.environ.get("S3_BUCKET_NAME") 62 | dataset_dir = os.environ.get("S3_DATASET_DIR") 63 | 64 | offset = 0.2 65 | 66 | # from src.utils.whisper_model import WhisperTimestampedASR 67 | 68 | 69 | # whisper_model = WhisperTimestampedASR(model_size="medium", language="english", device="cuda") 70 | 71 | lang_map = { 72 | "en": "english", 73 | "fr": "french", 74 | "es": "spanish", 75 | "de": "german", 76 | "it": "italian", 77 | } 78 | 79 | import string 80 | 81 | 82 | def remove_punctuation(input_string): 83 | # Make a translator object to remove all punctuation 84 | translator = str.maketrans("", "", string.punctuation) 85 | return input_string.translate(translator) 86 | 87 | 88 | def wer_wo_punctuation(reference, hypothesis): 89 | reference = remove_punctuation(reference) 90 | hypothesis = remove_punctuation(hypothesis) 91 | 92 | uncased_unpunctuated_wer = float(utils.calculate_wer(reference, hypothesis)) 93 | return uncased_unpunctuated_wer 94 | 95 | 96 | def asr_and_trim_(session_, sample, language, use="azure"): 97 | if use == "azure": 98 | response = asr_and_trim_azure(sample.s3RawPath, language) 99 | elif use == "aws": 100 | response = asr_and_trim_aws(sample.s3RawPath, language) 101 | 
102 | start = float(response["trim_start"]) - offset 103 | end = float(response["trim_end"]) + offset 104 | out_path, start, end = trim_audio(sample.local_path, start, end, sample.local_path.replace("raw", "trimmed")) 105 | 106 | sample.trim_start = round(float(start), 2) 107 | sample.trim_end = round(float(end), 2) 108 | sample.trimmed_audio_duration = round(float(end - start), 2) 109 | sample.longest_pause = round(float(response["longest_pause"]), 2) 110 | sample.asr_text = str(response["asr_text"]) 111 | sample.wer = round(float(utils.calculate_wer(sample.original_text.lower(), sample.asr_text.lower())), 2) 112 | 113 | # update sample 114 | object_key = out_path.split(f"{str(paths.LOCAL_BUCKET_DIR)}/")[1] 115 | s3TrimmedPath = f"s3://{bucket_name}/{object_key}" 116 | 117 | sample.local_trimmed_path = out_path 118 | sample.s3TrimmedPath = str(s3TrimmedPath) 119 | session_.add(sample) 120 | s3.upload_file(out_path, bucket_name, object_key) 121 | session_.commit() 122 | 123 | 124 | def trim_only_(session_, sample, language): 125 | response = trim_only(sample.local_path) 126 | 127 | start = float(response["trim_start"]) - offset 128 | end = float(response["trim_end"]) + offset 129 | out_path, start, end = trim_audio(sample.local_path, start, end, sample.local_path.replace("raw", "trimmed")) 130 | 131 | sample.trim_start = round(float(start), 2) 132 | sample.trim_end = round(float(end), 2) 133 | sample.trimmed_audio_duration = round(float(end - start), 2) 134 | sample.longest_pause = round(float(response["longest_pause"]), 2) 135 | sample.wer = round(float(utils.calculate_wer(sample.original_text.lower(), sample.asr_text.lower())), 2) 136 | # update sample 137 | object_key = out_path.split(f"{str(paths.LOCAL_BUCKET_DIR)}/")[1] 138 | s3TrimmedPath = f"s3://{bucket_name}/{object_key}" 139 | 140 | sample.local_trimmed_path = out_path 141 | sample.s3TrimmedPath = str(s3TrimmedPath) 142 | session_.add(sample) 143 | s3.upload_file(out_path, bucket_name, object_key) 144 | 
session_.commit() 145 | 146 | 147 | def trim_and_asr_(sample, language): 148 | response = trim_only(sample.local_path) 149 | 150 | start = float(response["trim_start"]) - offset 151 | end = float(response["trim_end"]) + offset 152 | out_path, start, end = trim_audio(sample.local_path, start, end, sample.local_path.replace("raw", "trimmed")) 153 | 154 | # update sample 155 | object_key = out_path.split(f"{str(paths.LOCAL_BUCKET_DIR)}/")[1] 156 | s3TrimmedPath = f"s3://{bucket_name}/{object_key}" 157 | 158 | sample.local_trimmed_path = out_path 159 | sample.s3TrimmedPath = str(s3TrimmedPath) 160 | 161 | s3.upload_file(out_path, bucket_name, object_key) 162 | 163 | asr = asr_aws(str(s3TrimmedPath), language) 164 | sample.asr_text = str(asr) 165 | sample.trim_start = round(float(start), 2) 166 | sample.trim_end = round(float(end), 2) 167 | sample.trimmed_audio_duration = round(float(end - start), 2) 168 | sample.longest_pause = round(float(response["longest_pause"]), 2) 169 | sample.wer = round(float(utils.calculate_wer(sample.original_text.lower(), str(asr).lower())), 2) 170 | sample.uncased_unpunctuated_wer = round(float(wer_wo_punctuation(sample.original_text.lower(), str(asr).lower())), 2) 171 | return sample 172 | 173 | 174 | def asr_only_(sample, language): 175 | asr = asr_aws(sample.s3TrimmedPath, language) 176 | sample.asr_text = str(asr) 177 | # get duration 178 | sample.trimmed_audio_duration = librosa.get_duration(filename=sample.local_trimmed_path) 179 | sample.wer = round(float(utils.calculate_wer(sample.original_text.lower(), str(asr).lower())), 2) 180 | sample.uncased_unpunctuated_wer = round(float(wer_wo_punctuation(sample.original_text.lower(), str(asr).lower())), 2) 181 | return sample 182 | 183 | 184 | def process_datasets(): 185 | datasets = session.query(Dataset).all() 186 | for dataset in datasets: 187 | if "English (Alyssa)" in dataset.name: 188 | continue 189 | print(f"Processing dataset: {dataset.name}") 190 | app_logger.info(f"Processing 
dataset: {dataset.name}") 191 | 192 | language = dataset.language 193 | 194 | samples = ( 195 | session.query(Sample) 196 | .filter(Sample.dataset_id == dataset.id) 197 | .filter( 198 | (Sample.local_trimmed_path == None) 199 | | (Sample.local_path == None) 200 | | (Sample.s3TrimmedPath == None) 201 | | (Sample.s3RawPath == None) 202 | | (Sample.asr_text == None) 203 | | (Sample.trim_start == None) 204 | | (Sample.trim_end == None) 205 | | (Sample.trimmed_audio_duration == None) 206 | | (Sample.trimmed_audio_duration == 0) 207 | | (Sample.longest_pause == None) 208 | | (Sample.wer == None) 209 | ) 210 | .all() 211 | ) 212 | # if len(samples) > 0: 213 | # whisper_model.unload() 214 | # whisper_model.load(language=lang_map[language]) 215 | while samples: 216 | with ThreadPoolExecutor(max_workers=32) as executor: 217 | futures = [executor.submit(trim_and_asr_, sample, language) for sample in samples] 218 | 219 | # Use tqdm to display the progress of processing finished samples 220 | for future in tqdm(as_completed(futures), total=len(futures), desc="Processing samples"): 221 | try: 222 | sample = future.result() 223 | # Detach the sample from the main session 224 | session.expunge(sample) 225 | 226 | # Create a new temporary session 227 | Session = sessionmaker(bind=engine) 228 | tmp_session = Session() 229 | 230 | # Add the sample to the temporary session and commit the changes 231 | tmp_session.add(sample) 232 | tmp_session.commit() 233 | 234 | # Close the temporary session 235 | tmp_session.close() 236 | except Exception as e: 237 | app_logger.error(f"Error processing sample: traceback: {traceback.format_exc()}") 238 | 239 | # get samples with asr_text = null 240 | samples = ( 241 | session.query(Sample) 242 | .filter(Sample.dataset_id == dataset.id) 243 | .filter( 244 | (Sample.local_trimmed_path == None) 245 | | (Sample.local_path == None) 246 | | (Sample.s3TrimmedPath == None) 247 | | (Sample.s3RawPath == None) 248 | | (Sample.asr_text == None) 249 | | 
(Sample.trim_start == None) 250 | | (Sample.trim_end == None) 251 | | (Sample.trimmed_audio_duration == None) 252 | | (Sample.trimmed_audio_duration == 0) 253 | | (Sample.longest_pause == None) 254 | | (Sample.wer == None) 255 | ) 256 | .all() 257 | ) 258 | 259 | app_logger.info(f"Finished processing dataset: {dataset.name}") 260 | 261 | 262 | if __name__ == "__main__": 263 | app_logger.info("Starting to process all datasets") 264 | process_datasets() 265 | session.close() 266 | app_logger.info("Finished processing all datasets") 267 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import boto3 4 | import jiwer 5 | 6 | 7 | # define a class map 8 | def s3_link_handler(s3_link): 9 | # check if s3 10 | if s3_link.startswith("s3://"): 11 | # get bucket name 12 | bucket_name = s3_link.split("/")[2] 13 | # get object name 14 | object_path = "/".join(s3_link.split("/")[3:]) 15 | return bucket_name, object_path 16 | 17 | 18 | def upload_to_s3(bucket_name, object_path, file_path): 19 | # upload to s3 20 | s3 = boto3.client("s3", aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"]) 21 | 22 | s3.upload_file(file_path, bucket_name, object_path) 23 | return f"s3://{bucket_name}/{object_path}" 24 | 25 | 26 | def calculate_wer(reference, hypothesis): 27 | return jiwer.wer(reference, hypothesis) 28 | -------------------------------------------------------------------------------- /src/utils/wer_wo_punct.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from src.paths import paths 4 | from dotenv import find_dotenv, load_dotenv 5 | load_dotenv(find_dotenv(paths.PROJECT_ROOT_DIR / "vars.env"), override=True) 6 | 7 | import sys 8 | from concurrent.futures import ThreadPoolExecutor # noqa: F401 9 | from pathlib import 
Path 10 | 11 | import boto3 12 | from dotenv import load_dotenv 13 | from sqlalchemy import create_engine 14 | from sqlalchemy.orm import sessionmaker 15 | 16 | 17 | sys.path.append(str(Path(__file__).resolve().parents[2])) 18 | 19 | 20 | from src.logger import root_logger 21 | from src.paths import paths 22 | from src.service.models import Annotation, Annotator, Base, Dataset, Sample # noqa: F401 23 | from src.utils import utils 24 | from src.utils.audio import asr_and_trim_aws, asr_and_trim_azure, trim_audio, trim_only # noqa: F401 25 | 26 | 27 | app_logger = root_logger.getChild("trimmer") 28 | 29 | BASE_DIR = paths.PROJECT_ROOT_DIR 30 | 31 | 32 | if load_dotenv(os.path.join(BASE_DIR, "vars.env")): 33 | app_logger.info("Loaded env vars from vars.env") 34 | else: 35 | app_logger.error("Failed to load env vars from vars.env") 36 | exit(1) 37 | 38 | 39 | # get engine from url 40 | POSTGRES_URL = os.getenv("POSTGRES_URL") 41 | 42 | engine = create_engine(POSTGRES_URL) 43 | Session = sessionmaker(bind=engine) 44 | session = Session() 45 | 46 | s3 = boto3.client("s3", aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")) 47 | bucket_name = os.environ.get("S3_BUCKET_NAME") 48 | dataset_dir = os.environ.get("S3_DATASET_DIR") 49 | 50 | offset = 0.2 51 | 52 | lang_map = { 53 | "en": "english", 54 | "fr": "french", 55 | "es": "spanish", 56 | "de": "german", 57 | "it": "italian", 58 | } 59 | 60 | 61 | import string 62 | 63 | 64 | def remove_punctuation(input_string): 65 | # Make a translator object to remove all punctuation 66 | translator = str.maketrans("", "", string.punctuation) 67 | return input_string.translate(translator) 68 | 69 | 70 | def wer_wo_punctuation(session_, sample): 71 | reference = remove_punctuation(sample.original_text.lower()) 72 | hypothesis = remove_punctuation(sample.asr_text.lower()) 73 | 74 | sample.uncased_unpunctuated_wer = round(float(utils.calculate_wer(reference, hypothesis)), 
2) 75 | session_.add(sample) 76 | session_.commit() 77 | 78 | 79 | def process_datasets(): 80 | datasets = session.query(Dataset).all() 81 | 82 | for dataset in datasets: 83 | 84 | print(f"Processing dataset: {dataset.name}") 85 | app_logger.info(f"Processing dataset: {dataset.name}") 86 | 87 | samples = session.query(Sample).filter(Sample.dataset_id == dataset.id).filter(Sample.uncased_unpunctuated_wer == None).all() 88 | 89 | while samples: # Continue while there are samples left 90 | with ThreadPoolExecutor(max_workers=8) as executor: 91 | # Remove sample from list and process 92 | while samples: 93 | sample = samples.pop(0) 94 | executor.submit(wer_wo_punctuation, session, sample) 95 | 96 | app_logger.info(f"Finished processing dataset: {dataset.name}") 97 | 98 | 99 | if __name__ == "__main__": 100 | process_datasets() 101 | app_logger.info("Finished processing all datasets") 102 | -------------------------------------------------------------------------------- /src/utils/whisper_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import traceback 4 | import urllib.request 5 | from typing import Dict 6 | from uuid import uuid4 7 | 8 | import whisper 9 | import whisper_timestamped as whisperts 10 | 11 | from src.logger import root_logger 12 | 13 | 14 | app_logger = root_logger.getChild("whisper-asr") 15 | 16 | MODEL_DIR = os.getenv("MODEL_DIR", "/mnt/models") 17 | MODEL_NAME = os.getenv("MODEL_NAME", "whisper_model") 18 | 19 | lang_map = { 20 | "en": "english", 21 | "fr": "french", 22 | "es": "spanish", 23 | "de": "german", 24 | "it": "italian", 25 | } 26 | 27 | inverse_lang_map = {v: k for k, v in lang_map.items()} 28 | 29 | 30 | class WhisperTimestampedASR: 31 | def __init__(self, model_size="tiny", language="english", device="cpu"): 32 | app_logger.info(f"Initializeing Whisper model: {model_size}") 33 | self.model = None 34 | self.ready = False 35 | self.device = device 36 | 
self.transcribe_options = dict(detect_disfluencies=True, vad=True, verbose=None, language=inverse_lang_map[language]) 37 | self.model_size = model_size 38 | 39 | def load(self, language: str = None): 40 | app_logger.info(f"Loading Whisper model: {self.model_size}") 41 | if self.ready: 42 | app_logger.warning("Whisper model already loaded need to unload first") 43 | self.unload() 44 | if language: 45 | app_logger.info(f"Setting language to {language}") 46 | self.transcribe_options["language"] = language 47 | self.model = whisperts.load_model(self.model_size, device=self.device) 48 | self.ready = True 49 | app_logger.info(f"Whisper model loaded") 50 | 51 | def unload(self): 52 | self.model = None 53 | self.ready = False 54 | app_logger.info(f"Whisper model unloaded") 55 | 56 | def get_file_from_url(self, url, tempdir): 57 | _, extension = os.path.splitext(url) 58 | extension = extension[1:] 59 | extension = extension.split("?")[0] 60 | input_path = os.path.join(tempdir, f"{str(uuid4())}.{extension}") 61 | urllib.request.urlretrieve(url, input_path) 62 | return input_path 63 | 64 | def predict(self, request: Dict) -> Dict: 65 | try: 66 | transcriptions = [] 67 | segments = [] 68 | disfluencies = [] 69 | inputs = request["instances"] 70 | with tempfile.TemporaryDirectory(prefix="whisper-asr-") as tempdir: 71 | for request in inputs: 72 | # check if url is s3 link or local 73 | if request["url"].startswith("s3://"): 74 | audio_file = self.get_file_from_url(request["url"], tempdir) 75 | else: 76 | audio_file = request["url"] 77 | audio = whisperts.load_audio(audio_file) 78 | results = whisperts.transcribe(self.model, audio, **self.transcribe_options) 79 | text = results["text"] 80 | segments = results["segments"] 81 | # remove spaces at the beginning and end of the string 82 | text = text.strip() 83 | 84 | raw_words = [w["text"] for s in segments for w in s["words"]] 85 | is_disfluent = False 86 | if "[*]" in raw_words: 87 | is_disfluent = True 88 | 
disfluencies.append(is_disfluent) 89 | transcriptions.append(text) 90 | segments.append(segments) 91 | 92 | return {"predictions": transcriptions, "segments": segments, "disfluencies": disfluencies} 93 | except ValueError as e: 94 | print(traceback.format_exc()) 95 | raise ValueError(f"Failed to process request: {e}") 96 | 97 | 98 | class WhisperASR: 99 | def __init__(self, model_size="tiny", language="English"): 100 | app_logger.info(f"Initializeing Whisper model: {model_size}") 101 | self.model = None 102 | self.ready = False 103 | options = dict(language=language) 104 | self.transcribe_options = dict(task="transcribe", **options) 105 | self.model_size = model_size 106 | 107 | def load(self, language: str = None): 108 | app_logger.info(f"Loading Whisper model: {self.model_size}") 109 | if self.ready: 110 | app_logger.warning("Whisper model already loaded need to unload first") 111 | return True 112 | if language: 113 | app_logger.info(f"Setting language to {language}") 114 | self.transcribe_options["language"] = language 115 | model_path = os.path.join(MODEL_DIR, MODEL_NAME) 116 | self.model = whisper.load_model(self.model_size, download_root=model_path) 117 | self.ready = True 118 | app_logger.info(f"Whisper model loaded") 119 | 120 | def unload(self): 121 | self.model = None 122 | self.ready = False 123 | app_logger.info(f"Whisper model unloaded") 124 | 125 | def get_file_from_url(self, url, tempdir): 126 | _, extension = os.path.splitext(url) 127 | extension = extension[1:] 128 | extension = extension.split("?")[0] 129 | input_path = os.path.join(tempdir, f"{str(uuid4())}.{extension}") 130 | urllib.request.urlretrieve(url, input_path) 131 | return input_path 132 | 133 | def predict(self, request: Dict) -> Dict: 134 | try: 135 | transcriptions = [] 136 | inputs = request["instances"] 137 | with tempfile.TemporaryDirectory(prefix="whisper-asr-") as tempdir: 138 | for request in inputs: 139 | # check if url is s3 link or local 140 | if 
request["url"].startswith("s3://"): 141 | audio_file = self.get_file_from_url(request["url"], tempdir) 142 | else: 143 | audio_file = request["url"] 144 | 145 | transcriptions.append(self.model.transcribe(audio_file, **self.transcribe_options)["text"]) 146 | 147 | return {"predictions": transcriptions} 148 | except ValueError as e: 149 | print(traceback.format_exc()) 150 | raise ValueError(f"Failed to process request: {e}") 151 | -------------------------------------------------------------------------------- /src/web_app/admin/pages/01_📤_Upload.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import os 4 | import re 5 | import shutil 6 | import sys 7 | import tempfile 8 | import zipfile 9 | from glob import glob 10 | 11 | import pandas as pd 12 | import requests 13 | import streamlit as st 14 | 15 | 16 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 17 | # aapedn 3 parent directories to the path 18 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 19 | 20 | from dotenv import load_dotenv 21 | 22 | from src.logger import root_logger 23 | from src.paths import paths 24 | 25 | 26 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 27 | # load the .env file 28 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 29 | 30 | 31 | def run_regex(wav_files, start_id_regex, end_id_regex): 32 | failed_files = [] 33 | for wav_file in wav_files: 34 | start_id = re.findall(start_id_regex, wav_file) 35 | end_id = re.findall(end_id_regex, wav_file) 36 | if start_id and end_id: 37 | start_id = int(start_id[0]) 38 | end_id = int(end_id[0]) 39 | else: 40 | failed_files.append(wav_file) 41 | return wav_files, failed_files 42 | 43 | 44 | lang_map = { 45 | "English": "en", 46 | "German": "de", 47 | "French": "fr", 48 | "Spanish": "es", 49 | "Italian": "it", 50 | } 51 | 52 | app_logger = root_logger.getChild("web_app::create_dataset") 53 | BACKEND_URL = 
"http://{}:{}".format(os.environ.get("SERVER_HOST"), os.environ.get("SERVER_PORT")) 54 | 55 | 56 | def app(): 57 | if "authentication_status" not in st.session_state: 58 | # forward to the page where the user can login 59 | st.warning("Please login first") 60 | st.stop() 61 | else: 62 | user = requests.get(f"{BACKEND_URL}/annotators/username/{st.session_state['username']}").json() 63 | if not user["isadmin"]: 64 | st.error("You are not authorized to access this page") 65 | st.stop() 66 | 67 | with st.sidebar: 68 | if st.session_state["authentication_status"]: 69 | st.write(f'Welcome *{st.session_state["name"]}*') 70 | 71 | sample_df = pd.read_csv(os.path.join(BASE_DIR, "src", "web_app", "admin", "data", "sample_csv.csv")) 72 | sample_zip_path = os.path.join(BASE_DIR, "src", "web_app", "admin", "data", "sample_zip.zip") 73 | 74 | st.title("TTS Datasets") 75 | st.write("Create a new TTS dataset or select an existing one") 76 | if "dataset" not in st.session_state: 77 | st.session_state["dataset"] = {} 78 | 79 | if "failed_files" not in st.session_state: 80 | st.session_state["failed_files"] = [] 81 | 82 | if "job_id" not in st.session_state: 83 | st.session_state["job_id"] = None 84 | 85 | def get_datasets(): 86 | return requests.get(BACKEND_URL + "/datasets").json() 87 | 88 | datasets = get_datasets() 89 | 90 | # either select a dataset or create a new one 91 | selected_dataset_name = st.selectbox("Dataset", [dataset["name"] for dataset in datasets] + ["Create New TTS Dataset"]) 92 | 93 | if selected_dataset_name == "Create New TTS Dataset": 94 | dataset_name = st.text_input("Dataset Name") 95 | dataset_description = st.text_input("Dataset Description") 96 | dataset_language = st.selectbox("Language of Dataset", ["English", "German", "French", "Spanish", "Italian"]) 97 | 98 | if st.button("Create Dataset"): 99 | # create new dataset object 100 | params = { 101 | "language": lang_map[dataset_language], 102 | "description": dataset_description, 103 | } 104 | dataset = 
requests.post(BACKEND_URL + f"/datasets/{dataset_name}", params=params).json() 105 | st.success("Dataset created successfully") 106 | st.session_state["dataset"] = dataset 107 | st.experimental_rerun() 108 | else: 109 | # get selected dataset object 110 | selected_dataset = [dataset for dataset in datasets if dataset["name"] == selected_dataset_name][0] 111 | st.session_state["dataset"] = selected_dataset 112 | 113 | # after selctiong or creating a dataset, upload a csv file 114 | if st.session_state["dataset"] != {}: 115 | # Add a selectox for deleting, updating dataset 116 | dataset_options = ["Upload Recordings", "Update Dataset", "Delete Dataset"] 117 | selected_dataset_option = st.selectbox("Actions", dataset_options) 118 | if selected_dataset_option == "Delete Dataset": 119 | if st.button("Delete Dataset"): 120 | r = requests.delete(BACKEND_URL + "/datasets/{}".format(st.session_state["dataset"]["id"])) 121 | st.session_state["dataset"] = {} 122 | st.success("Dataset deleted successfully") 123 | 124 | elif selected_dataset_option == "Update Dataset": 125 | dataset_name = st.text_input("Dataset Name") 126 | dataset_description = st.text_input("Dataset Description") 127 | dataset_language = st.selectbox("Language of Dataset", ["English", "German", "French", "Spanish", "Italian"]) 128 | 129 | if st.button("Update Dataset"): 130 | r = requests.put( 131 | BACKEND_URL + "/datasets/{}".format(st.session_state["dataset"]["id"]), 132 | json={"dataset_name": dataset_name, "description": dataset_description, "language": lang_map[dataset_language]}, 133 | ) 134 | st.write(r.json()) 135 | st.session_state["dataset"] = r.json() 136 | elif selected_dataset_option == "Upload Recordings": 137 | deliverable = st.text_input("Deliverable Name") 138 | col1, col2 = st.columns(2) 139 | with col1: 140 | # show a sample csv 141 | csv = sample_df.to_csv(index=False) 142 | b64 = base64.b64encode(csv.encode()).decode() 143 | href = f'Download example csv file' 144 | st.markdown(href, 
unsafe_allow_html=True) 145 | 146 | # upload csv file 147 | uploaded_file = st.file_uploader("Upload CSV File", type=["csv"]) 148 | 149 | with col2: 150 | 151 | # show a sample zip 152 | b64 = base64.b64encode(open(sample_zip_path, "rb").read()).decode() 153 | href = f'Download example zip file' 154 | st.markdown(href, unsafe_allow_html=True) 155 | # upload zip file 156 | uploaded_zip_file = st.file_uploader("Upload WAVs as zip", type=["zip"]) 157 | # checkif the recordings already segmented and zipped or not using a checkbox 158 | segmented = st.checkbox("Check if recordings are already segmented", value=True) 159 | 160 | if not segmented: 161 | # ask for a regex term to define start id and end id 162 | st.markdown("Define a regex term to extract start and end id from the file names") 163 | col2_1, col2_2 = st.columns(2) 164 | start_id_regex = col2_1.text_input("Start ID Regex", value=r"From (\d+) -") 165 | end_id_regex = col2_2.text_input("End ID Regex", value=r"- (\d+)") 166 | 167 | test = st.checkbox("Test Regex") 168 | if test: 169 | # get an example name for wav file and test the regex 170 | example_name = st.text_input("Example File Name", value="From 1 - 10.wav") 171 | if st.button("Test"): 172 | start_id = re.findall(start_id_regex, example_name) 173 | end_id = re.findall(end_id_regex, example_name) 174 | if start_id and end_id: 175 | st.success("Regex is valid") 176 | st.write(f"Start ID: {int(start_id[0])} - End ID: {int(end_id[0])}") 177 | else: 178 | st.error("Regex is invalid") 179 | 180 | if uploaded_file is not None and uploaded_zip_file is not None: 181 | button_col1, button_col2 = st.columns((10, 1)) 182 | if button_col2.button("Upload"): 183 | st.session_state["job_id"] = None 184 | 185 | st.session_state["failed_files"] = [] 186 | # read csv file 187 | csv = pd.read_csv(uploaded_file, delimiter=",", usecols=["unique_identifier", "text", "sentence_length", "sentence_type"]) 188 | if csv[csv.isnull().any(axis=1)].shape[0] > 0: 189 | st.error("CSV 
file contains NaN values") 190 | else: 191 | csv["file_name"] = csv["unique_identifier"].apply(lambda x: x + ".wav" if not x.endswith(".wav") else x) 192 | csv["file_name"] = csv["file_name"].apply(lambda x: x.upper().replace(".WAV", ".wav")) 193 | # read and unzip zip file 194 | zip_bytes = uploaded_zip_file.read() 195 | zip_file = io.BytesIO(zip_bytes) 196 | # extract zip file to temp directory which will not delete until the program is running 197 | temp_dir = tempfile.mkdtemp() 198 | with zipfile.ZipFile(zip_file, "r") as zip_ref: 199 | zip_ref.extractall(temp_dir) 200 | # get all wav files in temp directory 201 | # pdb.set_trace() 202 | wav_files = glob(os.path.join(temp_dir, "**", "*.wav"), recursive=True) 203 | 204 | # run regex for detection of start and end id for each wav file and the failed ones will be added to a list 205 | if not segmented: 206 | wav_files, st.session_state["failed_files"] = run_regex(wav_files, start_id_regex, end_id_regex) 207 | if len(st.session_state["failed_files"]) > 0: 208 | st.error("Some files could not be processed") 209 | st.write(st.session_state["failed_files"]) 210 | 211 | # create a folder of the wavs in one directory and rename them to their unique identifier 212 | 213 | tempdir_aggregated = tempfile.mkdtemp() 214 | for wav_file in wav_files: 215 | # move the wav file to the temp dir 216 | shutil.move(wav_file, tempdir_aggregated) 217 | # get all wav files in temp directory 218 | wav_files = glob(os.path.join(tempdir_aggregated, "**", "*.wav"), recursive=True) 219 | 220 | # remove old temp dir 221 | shutil.rmtree(temp_dir) 222 | 223 | temp_dir = tempdir_aggregated 224 | 225 | # create a dataframe with the wav files 226 | wav_df = pd.DataFrame(wav_files, columns=["local_path"]) 227 | # add the filename 228 | wav_df["file_name"] = wav_df["local_path"].apply(lambda x: os.path.basename(x)) 229 | 230 | if segmented: 231 | wav_df["file_name"] = wav_df["file_name"].apply(lambda x: x.upper().replace(".WAV", ".wav")) 232 | # 
merge the csv and the wav dataframe 233 | df = pd.merge(wav_df, csv, on="file_name", how="left") 234 | 235 | # check if all files were found 236 | not_found_files = df[df["text"].isnull()]["file_name"].tolist() 237 | 238 | if len(not_found_files) > 0: 239 | st.write("The following files were not found in the csv file:") 240 | st.write(not_found_files) 241 | st.warning( 242 | "Please make sure that the file names in the csv file match the file names in the zip file. Processing of the files will continue with the files that were found." 243 | ) 244 | df = df.dropna(subset=["text"]) 245 | # save df to a local dir 246 | csv_dir = os.path.join(temp_dir, f"{uploaded_file.name}") 247 | df.to_csv(csv_dir, index=False) 248 | # preprocess all files and save them to the database 249 | st.write("Uploading files to database...") 250 | 251 | params = { 252 | "wavs_path": temp_dir, 253 | "csv_path": csv_dir, 254 | "deliverable": None if deliverable == "" else deliverable, 255 | } 256 | response = requests.get( 257 | BACKEND_URL + "/datasets/{}/upload_segmented_async".format(st.session_state["dataset"]["id"]), params=params 258 | ) 259 | if response.status_code == 200: 260 | st.session_state["job_id"] = response.json()["job_id"] 261 | st.success("Files upload triggered successfully") 262 | else: 263 | st.error("An error occured while uploading the files") 264 | # remove temp_dir 265 | shutil.rmtree(temp_dir, ignore_errors=True) 266 | else: 267 | if not len(st.session_state["failed_files"]) > 0: 268 | csv_dir = os.path.join(temp_dir, f"{uploaded_file.name}") 269 | csv.to_csv(csv_dir, index=False) 270 | params = { 271 | "wavs_path": temp_dir, 272 | "csv_path": csv_dir, 273 | "deliverable": None if deliverable == "" else deliverable, 274 | "start_id_regex": start_id_regex, 275 | "end_id_regex": end_id_regex, 276 | } 277 | response = requests.get( 278 | BACKEND_URL + "/datasets/{}/upload_unsegmented_async".format(st.session_state["dataset"]["id"]), params=params 279 | ) 280 | if 
response.status_code == 200: 281 | st.session_state["job_id"] = response.json()["job_id"] 282 | st.success("Files upload triggered successfully") 283 | else: 284 | st.error("An error occured while uploading the files") 285 | # remove temp_dir 286 | shutil.rmtree(temp_dir, ignore_errors=True) 287 | else: 288 | st.error("Please fix the errors before uploading the files") 289 | # if st.session_state["job_id"] is not None and st.button("Check Status"): 290 | # progress_bar = st.progress(0) 291 | # job_id = st.session_state["job_id"] 292 | # response = requests.get(BACKEND_URL + f"/datasets/upload_segmented_async_status/{job_id}") 293 | # if response.status_code == 200: 294 | # # {"status": job.state, "progress": progress, "onboarded_samples": job.info.get("onboarded_samples", 0), "failed_samples": job.info.get("failed_samples", [])} 295 | # response_json = response.json() 296 | # progress_bar.progress(response_json["progress"]) 297 | # st.write(f"Samples onboarding for dataset {st.session_state['dataset']['name']} is {response_json['progress']}% complete") 298 | # st.write("Status: {}".format(response_json["status"])) 299 | # st.write("Onboarded Samples Count: {}".format(response_json["onboarded_samples"])) 300 | # st.write("Failed Samples: {}".format(response_json["failed_samples"])) 301 | # else: 302 | # st.error("An error occured while getting the status of the job") 303 | 304 | 305 | app() 306 | -------------------------------------------------------------------------------- /src/web_app/admin/pages/02_📊_Insights.py: -------------------------------------------------------------------------------- 1 | # This scripts contains the dataset investigate page 2 | # it select datalaset select list and after selecting the dataset it shows the dataset details 3 | # and the samples of the dataset 4 | import os 5 | import sys 6 | 7 | import plotly.express as px 8 | import plotly.figure_factory as ff # noqa: F401 9 | import requests 10 | import streamlit as st 11 | 12 | 13 | 
current_file_path = os.path.dirname(os.path.abspath(__file__)) 14 | # aapedn 3 parent directories to the path 15 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 16 | 17 | import pandas as pd 18 | from dotenv import load_dotenv 19 | 20 | from src.logger import root_logger 21 | from src.paths import paths 22 | 23 | 24 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 25 | # load the .env file 26 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 27 | 28 | app_logger = root_logger.getChild("web_app::create_dataset") 29 | BACKEND_URL = "http://{}:{}".format(os.environ.get("SERVER_HOST"), os.environ.get("SERVER_PORT")) 30 | 31 | 32 | import matplotlib.pyplot as plt 33 | from wordcloud import WordCloud # noqa: F401 34 | 35 | 36 | def get_samples(dataset_id): 37 | return requests.get(BACKEND_URL + f"/datasets/{dataset_id}/samples").json() 38 | 39 | 40 | def get_annotations(dataset_id): 41 | return requests.get(BACKEND_URL + f"/datasets/{dataset_id}/annotations").json() 42 | 43 | 44 | def get_annotators(dataset_id): 45 | return requests.get(BACKEND_URL + f"/datasets/{dataset_id}/annotators").json() 46 | 47 | 48 | def get_feedback(dataset_id): 49 | return requests.get(BACKEND_URL + f"/datasets/{dataset_id}/feedback").json() 50 | 51 | 52 | def get_datasets(): 53 | return requests.get(BACKEND_URL + "/datasets").json() 54 | 55 | 56 | def display_json(data): 57 | # If the data is a dictionary, display it as a table 58 | table_data = None 59 | if isinstance(data, dict): 60 | # edit keys: replace _ with " " and title, and make values small case 61 | data = {k.replace("_", " ").title(): v for k, v in data.items()} 62 | table_data = pd.DataFrame(data.items(), columns=["", ""]) 63 | # If the data is a list, display it as a table 64 | elif isinstance(data, list): 65 | table_data = pd.DataFrame(data) 66 | 67 | if table_data is not None: 68 | num_rows, num_cols = table_data.shape 69 | # Calculate the desired height and width of the table 70 | # You can adjust 
these values as per your requirement 71 | table_height = (num_rows + 1) * 30 # 30 pixels per row, plus 1 row for header 72 | table_width = num_cols * 250 # 150 pixels per column (adjust as desired) 73 | 74 | # Generate the CSS styling for the table 75 | table_style = f""" 76 | 97 | """ 98 | # Display the table using Streamlit 99 | st.markdown(table_style, unsafe_allow_html=True) 100 | st.markdown(f'
{table_data.to_html(index=False, header=False)}
', unsafe_allow_html=True) 101 | else: 102 | st.write(data) 103 | 104 | 105 | def app(): 106 | 107 | st.title("TTS Datasets") 108 | if "authentication_status" not in st.session_state: 109 | # forward to the page where the user can login 110 | st.warning("Please login first") 111 | st.stop() 112 | else: 113 | user = requests.get(f"{BACKEND_URL}/annotators/username/{st.session_state['username']}").json() 114 | if not user["isadmin"]: 115 | st.error("You are not authorized to access this page") 116 | st.stop() 117 | 118 | with st.sidebar: 119 | if st.session_state["authentication_status"]: 120 | st.write(f'Welcome *{st.session_state["name"]}*') 121 | 122 | datasets = get_datasets() 123 | dataset_names = [dataset["name"] for dataset in datasets] 124 | selected_dataset = st.selectbox("Select Dataset", dataset_names) 125 | 126 | if selected_dataset: 127 | selected_dataset = [dataset for dataset in datasets if dataset["name"] == selected_dataset][0] 128 | samples = pd.DataFrame(get_samples(selected_dataset["id"])) 129 | annotations = pd.DataFrame(get_annotations(selected_dataset["id"])) 130 | annotators = pd.DataFrame(get_annotators(selected_dataset["id"])) 131 | 132 | col1, col2 = st.columns((5, 10)) 133 | with col1: 134 | st.subheader(f"Dataset") 135 | display_json(selected_dataset) 136 | 137 | with col2: 138 | st.subheader("Annotators") 139 | display_json(annotators) 140 | 141 | st.markdown("---") 142 | 143 | # st.subheader("Sample Information") 144 | # display_json(samples) 145 | 146 | # col1, col2 = st.columns((1, 1)) 147 | # with col1: 148 | # st.subheader("Histogram: Sample Duration") 149 | # fig = ff.create_distplot([samples["duration"]], ["duration"]) 150 | # st.plotly_chart(fig) 151 | 152 | if len(annotations) > 0: 153 | # download anotations 154 | st.download_button( 155 | label="Download Annotations", 156 | data=annotations.to_csv(index=False), 157 | file_name=f"{selected_dataset['name']}_annotations.csv", 158 | mime="text/csv", 159 | ) 160 | 161 | 
st.subheader("Annotation Information") 162 | st.dataframe(annotations) 163 | 164 | st.subheader("Histogram: Annotation Status") 165 | fig = px.histogram(annotations, x="status") 166 | st.plotly_chart(fig) 167 | 168 | st.subheader("Histogram: Annotators") 169 | fig = px.histogram(annotations, x="annotator") 170 | st.plotly_chart(fig) 171 | 172 | st.subheader("Annotation Text Comparison") 173 | comparison_table = annotations[["filename", "original_text", "final_text"]] 174 | st.dataframe(comparison_table) 175 | 176 | st.subheader("Feedback Analysis") 177 | # check if there is any feedback 178 | feedbacks = " ".join(annotations["feedback"].dropna()) 179 | 180 | # remove the space from the feedbacks 181 | feedbacks = feedbacks.replace(" ", "") 182 | if feedbacks != "": 183 | wordcloud = WordCloud(background_color="white").generate(feedbacks) 184 | fig, ax = plt.subplots() 185 | ax.imshow(wordcloud, interpolation="bilinear") 186 | ax.axis("off") 187 | st.pyplot(fig) 188 | else: 189 | st.warning("No feedback found for this dataset") 190 | else: 191 | st.warning("No annotations found for this dataset") 192 | 193 | # st.subheader("Annotation Trends over Time") 194 | # annotations["date"] = pd.to_datetime(annotations["date"]) 195 | # fig, ax = plt.subplots() 196 | # annotations.groupby(annotations["date"].dt.date).size().plot(kind='line', ax=ax) 197 | # st.pyplot(fig) 198 | 199 | 200 | app() 201 | -------------------------------------------------------------------------------- /src/web_app/admin/🏠_Intro_admin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import streamlit as st 5 | import streamlit_authenticator as stauth 6 | 7 | 8 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 9 | # aapedn 3 parent directories to the path 10 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 11 | 12 | from dotenv import load_dotenv 13 | 14 | from src.logger import 
root_logger 15 | from src.paths import paths 16 | 17 | 18 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 19 | 20 | # load the .env file 21 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 22 | 23 | 24 | # set app name and icon 25 | st.set_page_config(page_title="TTS Data App", page_icon="🎙️", layout="wide") 26 | 27 | app_logger = root_logger.getChild("web_app::home") 28 | BACKEND_URL = "http://{}:{}".format(os.environ.get("SERVER_HOST"), os.environ.get("SERVER_PORT")) 29 | 30 | import requests 31 | import yaml 32 | from yaml.loader import SafeLoader 33 | 34 | 35 | config_file_path = paths.LOGIN_CONFIG_PATH 36 | with open(config_file_path) as file: 37 | config = yaml.load(file, Loader=SafeLoader) 38 | 39 | authenticator = stauth.Authenticate( 40 | config["credentials"], config["cookie"]["name"], config["cookie"]["key"], config["cookie"]["expiry_days"], config["preauthorized"] 41 | ) 42 | 43 | # sidebar 44 | with st.sidebar: 45 | name, authentication_status, username = authenticator.login("Login", "main") 46 | # ig logged in and not admin then logout get user @router.get("/username/{username}") 47 | if st.session_state["authentication_status"]: 48 | authenticator.logout("Logout", "main") 49 | user = requests.get(f"{BACKEND_URL}/annotators/username/{username}").json() 50 | if not user["isadmin"]: 51 | st.error("You are not authorized to access this page") 52 | st.stop() 53 | 54 | if st.session_state["authentication_status"]: 55 | 56 | st.write(f'Welcome *{st.session_state["name"]}*') 57 | choice = st.selectbox("Select an option", ["Create User", "Assign Dataset", "Delete User"]) 58 | if choice == "Create User": 59 | try: 60 | with st.form("Register user"): 61 | email = st.text_input("Email") 62 | username = st.text_input("Username") 63 | name = st.text_input("Name") 64 | isadmin = st.selectbox("Is admin?", ["False", "True"]) 65 | password = st.text_input("Password", type="password") 66 | repeat_password = st.text_input("Repeat password", type="password") 67 | 
submit_button = st.form_submit_button("Add User") 68 | 69 | if submit_button: 70 | if password == repeat_password: 71 | # create user 72 | params = {"password": password, "name": name, "email": email, "isadmin": isadmin, "ispreauthorized": True} 73 | response = requests.post(f"{BACKEND_URL}/annotators/{username}", params=params) 74 | if response.status_code == 200: 75 | response: str = response.json() # type: ignore 76 | if "message" in response: # type: ignore 77 | st.error(response["error"]) # type: ignore 78 | else: 79 | st.success("User created successfully") 80 | 81 | else: 82 | st.error("Something went wrong") 83 | else: 84 | st.error("Passwords do not match") 85 | except Exception as e: 86 | st.error(e) 87 | if choice == "Assign Dataset": 88 | try: 89 | all_datasets = requests.get(f"{BACKEND_URL}/datasets").json() 90 | annotators = requests.get(f"{BACKEND_URL}/annotators").json() 91 | annotator_selected = st.selectbox("To Annotator", [annotator["username"] for annotator in annotators]) 92 | with st.form("Assign dataset"): 93 | annotator_id = [annotator["id"] for annotator in annotators if annotator["username"] == annotator_selected][0] 94 | assigned_datasets = requests.get(f"{BACKEND_URL}/annotators/{annotator_id}/datasets").json() 95 | not_assigned_datasets = [dataset for dataset in all_datasets if dataset not in assigned_datasets] 96 | datasets_selected = st.multiselect("Dataset", [dataset["name"] for dataset in not_assigned_datasets]) 97 | if len(not_assigned_datasets) == 0: 98 | st.warning("No datasets or all datasets are assigned to this annotator") 99 | submit_button = st.form_submit_button("Assign Dataset") 100 | 101 | if submit_button: 102 | # assign dataset 103 | for dataset_selected in datasets_selected: 104 | dataset_id = [dataset["id"] for dataset in not_assigned_datasets if dataset["name"] == dataset_selected][0] 105 | response = requests.post(f"{BACKEND_URL}/annotators/{annotator_id}/datasets/{dataset_id}") 106 | if response.status_code == 200: 
107 | response = response.json() 108 | if response["message"] == "Failed": # type: ignore 109 | st.error(response["error"]) # type: ignore 110 | else: 111 | st.success(f"Dataset {dataset_selected} assigned to {annotator_selected}") 112 | else: 113 | st.error("Something went wrong") 114 | except Exception as e: 115 | st.error(e) 116 | if choice == "Delete User": 117 | try: 118 | annotators = requests.get(f"{BACKEND_URL}/annotators").json() 119 | with st.form("Delete user"): 120 | annotator_selected = st.selectbox("Annotator", [annotator["username"] for annotator in annotators]) 121 | submit_button = st.form_submit_button("Delete User") 122 | if submit_button: 123 | # delete user 124 | annotator_id = [annotator["id"] for annotator in annotators if annotator["username"] == annotator_selected][0] 125 | response = requests.delete(f"{BACKEND_URL}/annotators/{annotator_id}") 126 | if response.status_code == 200: 127 | response = response.json() 128 | if response["message"] == "Success": # type: ignore 129 | st.success(f"User {annotator_selected} deleted") 130 | else: 131 | st.error(response["error"]) # type: ignore 132 | else: 133 | st.error("Something went wrong") 134 | except Exception as e: 135 | st.error(f"Error: {e}") 136 | 137 | elif st.session_state["authentication_status"] is False: 138 | st.error("Username/password is incorrect") 139 | try: 140 | username_forgot_pw, email_forgot_password, random_password = authenticator.forgot_password("Forgot password") 141 | if username_forgot_pw: 142 | st.success("New password sent securely") 143 | with open(config_file_path, "w") as file: 144 | yaml.dump(config, file, default_flow_style=False) 145 | # Random password to be transferred to user securely 146 | else: 147 | st.error("Username not found") 148 | except Exception as e: 149 | st.error(e) 150 | elif st.session_state["authentication_status"] is None: 151 | st.warning("Please enter your username and password") 152 | 153 | 154 | st.title("TTS QA: Admin App") 155 | text = 
""" 156 | ## Create and Upload Datasets 157 | Welcome to the Create and Upload Datasets page of the Admin Application! Here, you can create new datasets or 158 | upload existing ones to the system for annotators to work on. To create a new dataset, simply click on the "Create Dataset" 159 | button and fill in the required information such as the dataset name, description, and any relevant tags. 160 | You can also upload a CSV file containing your data directly to the system. Once your dataset is created or uploaded, you can assign it to annotators to start the QA process. 161 | 162 | ## Insight 163 | Welcome to the Insight page of the Admin Application! Here, you can get valuable insights and statistics about the QA 164 | processes and datasets in the system. You can view the overall status of your datasets, such as the number of annotations completed, in progress, 165 | or pending. You can also view detailed statistics on individual datasets, such as the number of correct and incorrect annotations, the time taken for 166 | each annotation, and any issues encountered during the QA process. Use these insights to improve the quality of your datasets and optimize your QA process. 
167 | """ 168 | 169 | 170 | st.markdown(text) 171 | -------------------------------------------------------------------------------- /src/web_app/annotator/02_🔍_QA-depreciated.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import requests 5 | import streamlit as st 6 | 7 | 8 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 9 | # aapedn 3 parent directories to the path 10 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 11 | 12 | from dotenv import load_dotenv 13 | 14 | from src.logger import root_logger 15 | from src.paths import paths 16 | 17 | 18 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 19 | # load the .env file 20 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 21 | 22 | lang2idx = { 23 | "English": "en", 24 | "German": "de", 25 | "French": "fr", 26 | "Spanish": "es", 27 | "Italian": "it", 28 | } 29 | 30 | idx2lang = {v: k for k, v in lang2idx.items()} 31 | 32 | app_logger = root_logger.getChild("web_app::annotate") 33 | BACKEND_URL = "http://{}:{}".format(os.environ.get("SERVER_HOST"), os.environ.get("SERVER_PORT")) 34 | 35 | 36 | def app(): 37 | 38 | st.title("TTS QA") 39 | if "authentication_status" not in st.session_state: 40 | # forward to the page where the user can login 41 | st.warning("Please login first") 42 | st.stop() 43 | with st.sidebar: 44 | if st.session_state["authentication_status"]: 45 | st.write(f'Welcome *{st.session_state["name"]}*') 46 | 47 | def get_datasets(annotator_id: int): 48 | return requests.get(BACKEND_URL + f"/annotators/{annotator_id}/datasets").json() 49 | 50 | st.markdown( 51 | """ 52 | 68 | """, 69 | unsafe_allow_html=True, 70 | ) 71 | 72 | columns_sizes = (40, 5, 5) 73 | 74 | if "sample" not in st.session_state: 75 | st.session_state["sample"] = None 76 | 77 | if "test_count" not in st.session_state: 78 | st.session_state["test_count"] = 0 79 | 80 | if "user_input" not in st.session_state: 81 | 
st.session_state["user_input"] = { 82 | "final_text": "", 83 | "final_sentence_type": "statement", 84 | "isRepeated": False, 85 | "isAccentRight": False, 86 | "isPronunciationRight": False, 87 | "isClean": False, 88 | "isPausesRight": False, 89 | "isSpeedRight": False, 90 | "isConsisent": False, 91 | "feedback": "", 92 | "status": "NotReviewed", 93 | } 94 | 95 | if "query_button" not in st.session_state: 96 | st.session_state["query_button"] = True 97 | 98 | if "annotate_button" not in st.session_state: 99 | st.session_state["annotate_button"] = False 100 | 101 | if "isFirstRun" not in st.session_state: 102 | st.session_state["isFirstRun"] = True 103 | 104 | if "annotator_id" not in st.session_state: 105 | annotator = requests.get(BACKEND_URL + f"/annotators/username/{st.session_state['username']}").json() 106 | st.session_state["annotator_id"] = annotator["id"] 107 | 108 | if "dataset_id" not in st.session_state: 109 | st.session_state["dataset_id"] = None 110 | 111 | if "prev_dataset_id" not in st.session_state: 112 | st.session_state["prev_dataset_id"] = None 113 | 114 | if "datasets" not in st.session_state: 115 | datasets = get_datasets(st.session_state["annotator_id"]) 116 | st.session_state["datasets"] = datasets 117 | 118 | if "stats" not in st.session_state: 119 | st.session_state["stats"] = { 120 | "total": 0, 121 | "annotated": 0, 122 | "not_annotated": 0, 123 | } 124 | 125 | def annotate_sample( 126 | id: int, 127 | annotator_id: int, 128 | final_text: str, 129 | final_sentence_type: str, 130 | isRepeated: bool, 131 | isAccentRight: bool, 132 | isPronunciationRight: bool, 133 | isClean: bool, 134 | isPausesRight: bool, 135 | isSpeedRight: bool, 136 | isConsisent: bool, 137 | feedback: str, 138 | status: str = "NotReviewed", 139 | ): 140 | 141 | data = { 142 | "annotator_id": annotator_id, 143 | "final_text": final_text, 144 | "final_sentence_type": final_sentence_type, 145 | "isRepeated": isRepeated, 146 | "isAccentRight": isAccentRight, 147 | 
"isPronunciationRight": isPronunciationRight, 148 | "isClean": isClean, 149 | "isPausesRight": isPausesRight, 150 | "isSpeedRight": isSpeedRight, 151 | "isConsisent": isConsisent, 152 | "feedback": feedback, 153 | "status": status, 154 | } 155 | response = requests.put(BACKEND_URL + f"/samples/{id}", json=data) 156 | if response.status_code == 200: 157 | app_logger.info("Sample annotated") 158 | else: 159 | app_logger.error("Sample annotation failed") 160 | st.error("Sample annotation failed") 161 | return response 162 | 163 | def query(): 164 | if st.session_state["annotate_button"]: 165 | if st.session_state["user_input"]["status"] in ["Discarded", "Reviewed"]: 166 | st.session_state["user_input"].update({"id": st.session_state["sample"]["id"], "annotator_id": st.session_state["annotator_id"]}) 167 | response = annotate_sample(**st.session_state["user_input"]) 168 | st.session_state["annotate_button"] = False 169 | 170 | try: 171 | # send a request to get next sample 172 | response = requests.get(BACKEND_URL + f"/datasets/{st.session_state['dataset_id']}/next_sample") 173 | if response.status_code == 200: 174 | response = response.json() 175 | if "message" in response: 176 | st.session_state["sample"] = None 177 | st.session_state["stats"] = None 178 | return 179 | sample = response["sample"] 180 | stats = response["stats"] 181 | st.session_state["sample"] = sample 182 | st.session_state["stats"] = stats 183 | st.session_state["user_input"] = { 184 | "final_text": sample["final_text"], 185 | "final_sentence_type": sample["sentence_type"], 186 | "isRepeated": False, 187 | "isAccentRight": False, 188 | "isPronunciationRight": False, 189 | "isClean": False, 190 | "isPausesRight": False, 191 | "isSpeedRight": False, 192 | "isConsisent": False, 193 | "feedback": "", 194 | "status": "NotReviewed", 195 | } 196 | st.session_state["query_button"] = False 197 | app_logger.info("Next sample retrieved") 198 | else: 199 | st.error("No more samples to annotate") 200 | 
app_logger.error("No more samples to annotate") 201 | except Exception as e: 202 | app_logger.error(e) 203 | 204 | def sample_container(sample): 205 | container = st.container() 206 | 207 | col1, col2, col3, col4 = container.columns(4) 208 | col1.metric("ID", sample["filename"]) 209 | col2.metric("Sentence Type", f"{sample['sentence_type']}") 210 | col3.metric("Length", "10") 211 | col4.metric("WER", sample["wer"]) 212 | 213 | container.markdown("---") 214 | 215 | container.markdown("## Listen Audio") 216 | # audio player 217 | audio_file = open(st.session_state["sample"]["local_trimmed_path"], "rb") 218 | container.audio(audio_file, format="audio/wav") 219 | 220 | container.markdown("---") 221 | # Divide screen into 2 columns 222 | col1, col2 = container.columns(2) 223 | col1.subheader("Original Text") 224 | col2.subheader("ASR") 225 | 226 | # For all systems show all output sentences under each other and system names 227 | original_text = col1.text_area( 228 | "Original Text", 229 | sample["original_text"], 230 | key=f"original_text", 231 | label_visibility="hidden", 232 | ) 233 | 234 | asr_text = col2.text_area( 235 | "ASR", 236 | sample["asr_text"], 237 | key=f"asr_text", 238 | label_visibility="hidden", 239 | ) 240 | 241 | # add vertival radio button for each system 242 | better = st.radio("Selected Transcription", ["Original", "ASR"], key=f"better_select", index=0, horizontal=True) 243 | 244 | ph = st.empty() 245 | if better != "": 246 | st.session_state["user_input"]["final_text"] = ph.text_area( 247 | "Please post edit the text if needed", 248 | st.session_state["sample"][f"{better.lower()}_text"], 249 | key=f"best_sys", 250 | ) 251 | sentence_type_list = ["Statement", "Question", "Exclamation"] 252 | defult_idx = sentence_type_list.index(sample["sentence_type"].title()) 253 | sentence_type = st.radio("Sentence Type", sentence_type_list, key=f"sentence_type", index=defult_idx, horizontal=True) 254 | st.session_state["user_input"]["final_sentence_type"] = 
sentence_type.lower() 255 | 256 | st.session_state["prev_dataset_id"] = st.session_state["dataset_id"] 257 | 258 | st.session_state["dataset_id"] = st.sidebar.selectbox( 259 | "Dataset ", 260 | [d["name"] for d in st.session_state["datasets"]], 261 | ) 262 | st.session_state["dataset_id"] = [d["id"] for d in st.session_state["datasets"] if d["name"] == st.session_state["dataset_id"]][0] 263 | 264 | if st.session_state["prev_dataset_id"] != st.session_state["dataset_id"]: 265 | st.session_state["query_button"] = True 266 | st.session_state["annotate_button"] = False 267 | st.session_state["isFirstRun"] = True 268 | st.experimental_rerun() 269 | 270 | if st.session_state["dataset_id"] is not None: 271 | 272 | if st.session_state["query_button"] and st.session_state["isFirstRun"]: 273 | st.balloons() 274 | query() 275 | st.session_state["isFirstRun"] = False 276 | elif st.session_state["query_button"]: 277 | query() 278 | st.session_state["query_button"] = False 279 | st.session_state["annotate_button"] = False 280 | # add progresss bar 281 | if st.session_state["stats"] is not None: 282 | progress_bar = st.progress(0, text="Progress") 283 | progress = st.session_state["stats"]["annotated"] / st.session_state["stats"]["total"] 284 | progress_bar.progress( 285 | progress, 286 | text=f"Progress: {st.session_state['stats']['annotated']} Rated, {st.session_state['stats']['total'] - st.session_state['stats']['annotated']} Remaining out of {st.session_state['stats']['total']} recordings", 287 | ) 288 | 289 | if st.session_state["sample"] is not None: 290 | 291 | # Input sentence 292 | sample_container(st.session_state["sample"]) 293 | 294 | # create a divider 295 | st.markdown("---") 296 | # ask if you want to Discard or Save 297 | col1, col2 = st.columns(2) 298 | with col1: 299 | discard = st.checkbox("Discard", value=False) 300 | 301 | if discard: 302 | col1, col2 = st.columns(2) 303 | with col1: 304 | isRepeated = True if st.checkbox("Has Repeation", value=False) else 
False 305 | isAccentRight = True if st.checkbox("Accent is Wrong", value=False) else False 306 | isPronunciationRight = True if st.checkbox("Pronunciation is Wrong", value=False) else False 307 | isClean = True if st.checkbox("Recording is not Clean", value=False) else False 308 | isPausesRight = True if st.checkbox("Pauses are not right", value=False) else False 309 | isSpeedRight = True if st.checkbox("Speed is not right", value=False) else False 310 | isConsisent = True if st.checkbox("Voice is not consistent", value=False) else False 311 | with col2: 312 | feedback = st.text_area("Feedback", value=st.session_state["user_input"]["feedback"]) 313 | 314 | st.session_state["user_input"]["isRepeated"] = isRepeated 315 | st.session_state["user_input"]["isAccentRight"] = isAccentRight 316 | st.session_state["user_input"]["isPronunciationRight"] = isPronunciationRight 317 | st.session_state["user_input"]["isClean"] = isClean 318 | st.session_state["user_input"]["isPausesRight"] = isPausesRight 319 | st.session_state["user_input"]["isSpeedRight"] = isSpeedRight 320 | st.session_state["user_input"]["isConsisent"] = isConsisent 321 | st.session_state["user_input"]["feedback"] = feedback 322 | else: 323 | st.session_state["user_input"]["isRepeated"] = True 324 | st.session_state["user_input"]["isAccentRight"] = True 325 | st.session_state["user_input"]["isPronunciationRight"] = True 326 | st.session_state["user_input"]["isClean"] = True 327 | st.session_state["user_input"]["isPausesRight"] = True 328 | st.session_state["user_input"]["isSpeedRight"] = True 329 | st.session_state["user_input"]["isConsisent"] = True 330 | st.session_state["user_input"]["feedback"] = "" 331 | col1, col2 = st.columns((10, 1)) 332 | with col2: 333 | submitted = st.button("Submit") 334 | if submitted: 335 | if discard: 336 | status = "Discarded" 337 | else: 338 | status = "Reviewed" 339 | st.session_state["user_input"]["status"] = status 340 | st.success("Submitted!") 341 | 
st.session_state["query_button"] = True 342 | st.session_state["annotate_button"] = True 343 | st.experimental_rerun() 344 | 345 | else: 346 | st.warning("No more samples to rate") 347 | 348 | else: 349 | st.warning("Select Annotator and Dataset") 350 | 351 | 352 | app() 353 | -------------------------------------------------------------------------------- /src/web_app/annotator/pages/01_🎙️_Record: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import sys 4 | from glob import glob 5 | 6 | import pandas as pd 7 | import streamlit as st 8 | from audio_recorder_streamlit import audio_recorder 9 | 10 | from src.utils.audio import convert_to_mono, normalize_audio 11 | 12 | 13 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 14 | # aapedn 3 parent directories to the path 15 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 16 | 17 | from src.logger import root_logger 18 | from src.paths import paths 19 | 20 | 21 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 22 | app_logger = root_logger.getChild("web_app::record") 23 | 24 | sample_df = pd.read_csv(os.path.join(BASE_DIR, "src", "web_app", "annotator", "data", "sample_record.csv")) 25 | 26 | SAVE_DIR_AUDIO = os.path.join(BASE_DIR, "src", "web_app", "annotator", "data", "audio") 27 | SAVE_DIR_CSV = os.path.join(BASE_DIR, "src", "web_app", "annotator", "data", "csv") 28 | 29 | 30 | def app(): 31 | if "authentication_status" not in st.session_state: 32 | # forward to the page where the user can login 33 | st.warning("Please login first") 34 | st.stop() 35 | 36 | if not os.path.exists(SAVE_DIR_AUDIO): 37 | os.makedirs(SAVE_DIR_AUDIO) 38 | 39 | if not os.path.exists(SAVE_DIR_CSV): 40 | os.makedirs(SAVE_DIR_CSV) 41 | 42 | if "csv" not in st.session_state: 43 | st.session_state.csv = None 44 | 45 | if "row" not in st.session_state: 46 | st.session_state.row = None 47 | 48 | if "index" not in st.session_state: 49 | 
st.session_state.index = None 50 | 51 | if "csv_name" not in st.session_state: 52 | st.session_state.csv_name = None 53 | 54 | st.title("Record Audio") 55 | with st.sidebar: 56 | if st.session_state["authentication_status"]: 57 | st.write(f'Welcome *{st.session_state["name"]}*') 58 | 59 | # show a sample csv 60 | csv = sample_df.to_csv(index=False) 61 | b64 = base64.b64encode(csv.encode()).decode() 62 | href = f'Download example csv file' 63 | st.markdown(href, unsafe_allow_html=True) 64 | 65 | # select_a_file or upload a file 66 | files = glob(os.path.join(SAVE_DIR_CSV, "*.csv")) 67 | files = [os.path.basename(f) for f in files] 68 | option = st.selectbox("Select a file or upload a file", files + ["Upload a file"]) 69 | if option == "Upload a file": 70 | # upload csv file 71 | uploaded_file = st.file_uploader("Upload CSV File", type=["csv"]) 72 | if uploaded_file is not None: 73 | # Name the file 74 | st.session_state.csv_name = st.text_input("Name the file", value="recordings_1.csv") 75 | csv = pd.read_csv(uploaded_file, delimiter=",", usecols=["unique_identifier", "text", "sentence_length", "sentence_type"]) 76 | if "record_status" not in csv.columns: 77 | csv["record_status"] = "Not Recorded" 78 | if st.button("Upload"): 79 | st.session_state.csv = csv 80 | # check the audio files and update if needed 81 | audio_files_ids = [os.path.splitext(os.path.basename(f))[0] for f in glob(os.path.join(SAVE_DIR_AUDIO, "*.wav"))] 82 | # the ones in the audio_files shold be marked as recorded 83 | st.session_state.csv.loc[st.session_state.csv["unique_identifier"].isin(audio_files_ids), "record_status"] = "Recorded" 84 | st.session_state.csv.to_csv(os.path.join(SAVE_DIR_CSV, st.session_state.csv_name), index=False) 85 | else: 86 | # select a file 87 | st.session_state.csv_name = os.path.basename(option) 88 | st.session_state.csv = pd.read_csv( 89 | os.path.join(SAVE_DIR_CSV, option), delimiter=",", usecols=["unique_identifier", "text", "sentence_length", "sentence_type", 
"record_status"] 90 | ) 91 | 92 | # now from the csv file, we go in order and record the audio if it is not recorded 93 | if st.session_state.csv is not None: 94 | st.session_state.index = st.session_state.csv[st.session_state.csv["record_status"] == "Not Recorded"].index[0] 95 | st.session_state.row = st.session_state.csv.loc[st.session_state.index] 96 | # get first row that is not recorded 97 | if st.session_state.row["record_status"] == "Not Recorded": 98 | st.write(f"Recording {st.session_state.row['unique_identifier']}") 99 | # wrte text very bold and big 100 | col1, col2 = st.columns(2) 101 | col1.markdown(f"**{st.session_state.row['text']}**", unsafe_allow_html=True) 102 | with col2: 103 | audio_bytes = audio_recorder( 104 | pause_threshold=1.0, neutral_color="#303030", recording_color="#de1212", icon_name="microphone", icon_size="3x", sample_rate=88_000 105 | ) 106 | 107 | if audio_bytes: 108 | st.audio(audio_bytes, format="audio/wav") 109 | col1, col2, col3 = st.columns(3) 110 | 111 | if col1.button("Save"): 112 | st.session_state.csv.at[st.session_state.index, "record_status"] = "Recorded" 113 | st.session_state.csv.to_csv(os.path.join(SAVE_DIR_CSV, st.session_state.csv_name), index=False) 114 | save_path = os.path.join(SAVE_DIR_AUDIO, st.session_state.row["unique_identifier"] + ".wav") 115 | with open(save_path, "wb") as f: 116 | f.write(audio_bytes) 117 | normalize_audio(save_path, save_path) 118 | convert_to_mono(save_path, save_path) 119 | st.success("Saved") 120 | 121 | if col2.button("Next"): 122 | st.session_state.index = st.session_state.csv[st.session_state.csv["record_status"] == "Not Recorded"].index[0] 123 | st.session_state.row = st.session_state.csv.loc[st.session_state.index] 124 | st.experimental_rerun() 125 | 126 | 127 | app() 128 | -------------------------------------------------------------------------------- /src/web_app/annotator/pages/02_🔍_QA.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import sys 3 | import traceback 4 | 5 | import pandas as pd 6 | import requests 7 | import streamlit as st 8 | 9 | 10 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 11 | # aapedn 3 parent directories to the path 12 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 13 | 14 | from dotenv import load_dotenv 15 | 16 | from src.logger import root_logger 17 | from src.paths import paths 18 | 19 | 20 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 21 | # load the .env file 22 | load_dotenv(os.path.join(BASE_DIR, "vars.env")) 23 | 24 | lang2idx = { 25 | "English": "en", 26 | "German": "de", 27 | "French": "fr", 28 | "Spanish": "es", 29 | "Italian": "it", 30 | } 31 | 32 | idx2lang = {v: k for k, v in lang2idx.items()} 33 | 34 | app_logger = root_logger.getChild("web_app::annotate") 35 | BACKEND_URL = "http://{}:{}".format(os.environ.get("SERVER_HOST"), os.environ.get("SERVER_PORT")) 36 | 37 | 38 | # Function to display json data in a structured way 39 | def display_json(data, col): 40 | # If the data is a dictionary, display it as a table 41 | table_data = None 42 | if isinstance(data, dict): 43 | # edit keys: replace _ with " " and title, and make values small case 44 | # data = {k.replace("_", " ").title(): v for k, v in data.items()} 45 | table_data = pd.DataFrame(data.items(), columns=["", ""]) 46 | # If the data is a list, display it as a table 47 | elif isinstance(data, list): 48 | table_data = pd.DataFrame(data) 49 | 50 | if table_data is not None: 51 | # transpose the table data 52 | table_data = table_data.T 53 | num_rows, num_cols = table_data.shape 54 | # Calculate the desired height and width of the table 55 | # You can adjust these values as per your requirement 56 | table_height = (num_rows + 1) * 50 # 30 pixels per row, plus 1 row for header 57 | table_width = num_cols * 450 # 150 pixels per column (adjust as desired) 58 | 59 | # Generate the CSS styling for the table 60 | table_style = f""" 61 | 82 
| """ 83 | # Display the table using Streamlit 84 | col.markdown(table_style, unsafe_allow_html=True) 85 | col.markdown(f'
{table_data.to_html(index=False, header=False)}
', unsafe_allow_html=True) 86 | else: 87 | col.write(data) 88 | 89 | 90 | def app(): 91 | 92 | st.title("TTS QA") 93 | if "authentication_status" not in st.session_state: 94 | # forward to the page where the user can login 95 | st.warning("Please login first") 96 | st.stop() 97 | 98 | with st.sidebar: 99 | if st.session_state["authentication_status"]: 100 | st.write(f'Welcome *{st.session_state["name"]}*') 101 | 102 | def get_datasets(annotator_id: int): 103 | return requests.get(BACKEND_URL + f"/annotators/{annotator_id}/datasets").json() 104 | 105 | st.markdown( 106 | """ 107 | 123 | """, 124 | unsafe_allow_html=True, 125 | ) 126 | 127 | if "sample" not in st.session_state: 128 | st.session_state["sample"] = None 129 | 130 | if "run_id" not in st.session_state: 131 | st.session_state["run_id"] = 0 132 | 133 | if "test_count" not in st.session_state: 134 | st.session_state["test_count"] = 0 135 | 136 | if "user_input" not in st.session_state: 137 | st.session_state["user_input"] = { 138 | "final_text": "", 139 | "final_sentence_type": "statement", 140 | "isRepeated": True, 141 | # "isAccentRight": False, 142 | # "isPronunciationRight": False, 143 | # "isClean": False, 144 | # "isPausesRight": False, 145 | # "isSpeedRight": False, 146 | # "isConsisent": False, 147 | "incorrectProsody": True, 148 | "inconsistentTextAudio": True, 149 | "incorrectTrancuation": True, 150 | "soundArtifacts": True, 151 | "feedback": "", 152 | "status": "NotReviewed", 153 | } 154 | 155 | if "query_button" not in st.session_state: 156 | st.session_state["query_button"] = True 157 | 158 | if "annotate_button" not in st.session_state: 159 | st.session_state["annotate_button"] = False 160 | 161 | if "isFirstRun" not in st.session_state: 162 | st.session_state["isFirstRun"] = True 163 | 164 | if "annotator_id" not in st.session_state: 165 | annotator = requests.get(BACKEND_URL + f"/annotators/username/{st.session_state['username']}").json() 166 | st.session_state["annotator_id"] = 
annotator["id"] 167 | 168 | if "dataset_id" not in st.session_state: 169 | st.session_state["dataset_id"] = None 170 | 171 | if "prev_dataset_id" not in st.session_state: 172 | st.session_state["prev_dataset_id"] = None 173 | 174 | if "datasets" not in st.session_state: 175 | datasets = get_datasets(st.session_state["annotator_id"]) 176 | st.session_state["datasets"] = datasets 177 | 178 | if "stats" not in st.session_state: 179 | st.session_state["stats"] = None 180 | 181 | def annotate_sample( 182 | id: int, 183 | annotator_id: int, 184 | final_text: str, 185 | final_sentence_type: str, 186 | isRepeated: bool, 187 | # isAccentRight: bool, 188 | # isPronunciationRight: bool, 189 | # isClean: bool, 190 | # isPausesRight: bool, 191 | # isSpeedRight: bool, 192 | # isConsisent: bool, 193 | incorrectProsody: bool, 194 | inconsistentTextAudio: bool, 195 | incorrectTrancuation: bool, 196 | soundArtifacts: bool, 197 | feedback: str, 198 | status: str = "NotReviewed", 199 | ): 200 | 201 | data = { 202 | "annotator_id": annotator_id, 203 | "final_text": final_text, 204 | "final_sentence_type": final_sentence_type, 205 | "isRepeated": isRepeated, 206 | # "isAccentRight": isAccentRight, 207 | # "isPronunciationRight": isPronunciationRight, 208 | # "isClean": isClean, 209 | # "isPausesRight": isPausesRight, 210 | # "isSpeedRight": isSpeedRight, 211 | # "isConsisent": isConsisent, 212 | "incorrectProsody": incorrectProsody, 213 | "inconsistentTextAudio": inconsistentTextAudio, 214 | "incorrectTrancuation": incorrectTrancuation, 215 | "soundArtifacts": soundArtifacts, 216 | "feedback": feedback, 217 | "status": status, 218 | } 219 | response = requests.put(BACKEND_URL + f"/samples/{id}", json=data) 220 | if response.status_code == 200: 221 | app_logger.info(f"Sample {id} annotated") 222 | st.success("Sample annotated") 223 | # unlock the sample 224 | response = requests.put(BACKEND_URL + f"/samples/{id}/unlock") 225 | if response.status_code == 200: 226 | 
app_logger.info(f"Sample {id} unlocked") 227 | else: 228 | app_logger.error("Sample annotation failed") 229 | st.error("Sample annotation failed") 230 | return response 231 | 232 | def query(): 233 | if st.session_state["annotate_button"]: 234 | if st.session_state["user_input"]["status"] in ["Discarded", "Reviewed"]: 235 | st.session_state["user_input"].update({"id": st.session_state["sample"]["id"], "annotator_id": st.session_state["annotator_id"]}) 236 | response = annotate_sample(**st.session_state["user_input"]) 237 | st.session_state["annotate_button"] = False 238 | 239 | try: 240 | # send a request to get next sample 241 | response = requests.get(BACKEND_URL + f"/datasets/{st.session_state['dataset_id']}/next_sample") 242 | if response.status_code == 200: 243 | response = response.json() 244 | if "message" in response: 245 | 246 | st.session_state["sample"] = None 247 | st.session_state["stats"] = None 248 | if "error" in response: 249 | app_logger.error(f"Failed to get next sample. 
Error: {response['error']}") 250 | return 251 | 252 | sample = response["sample"] 253 | stats = response["stats"] 254 | st.session_state["sample"] = sample 255 | st.session_state["stats"] = stats 256 | if sample is not None: 257 | st.session_state["user_input"] = { 258 | "final_text": sample["original_text"], 259 | "final_sentence_type": sample["sentence_type"], 260 | "isRepeated": True, 261 | # "isAccentRight": False, 262 | # "isPronunciationRight": False, 263 | # "isClean": False, 264 | # "isPausesRight": False, 265 | # "isSpeedRight": False, 266 | # "isConsisent": False, 267 | "incorrectProsody": True, 268 | "inconsistentTextAudio": True, 269 | "incorrectTrancuation": True, 270 | "soundArtifacts": True, 271 | "feedback": "", 272 | "status": "NotReviewed", 273 | } 274 | 275 | # lock the sample 276 | 277 | response = requests.put(BACKEND_URL + f"/samples/{sample['id']}/lock") 278 | if response.status_code == 200: 279 | app_logger.info(f"Sample {sample['id']} locked") 280 | st.session_state["query_button"] = False 281 | app_logger.info("Next sample retrieved") 282 | else: 283 | st.error(f"Failed to get next sample, status code: {response.status_code}") 284 | app_logger.error(f"Failed to get next sample, status code: {response.status_code}") 285 | except Exception as e: 286 | app_logger.error(f"Failed to get next sample. 
Error: {traceback.format_exc()}") 287 | 288 | if st.button("See Latest Annotation"): 289 | # @router.get("/{id}/samples/latest") 290 | 291 | response = requests.get(BACKEND_URL + f"/annotators/{st.session_state['annotator_id']}/samples/{st.session_state['dataset_id']}/latest").json() 292 | if "message" in response: 293 | st.warning(response["message"]) 294 | else: 295 | st.session_state["sample"] = response 296 | st.session_state["query_button"] = False 297 | st.session_state["annotate_button"] = False 298 | 299 | columns_sizes = (15, 15, 5) 300 | 301 | st.session_state["prev_dataset_id"] = st.session_state["dataset_id"] 302 | 303 | st.session_state["dataset_id"] = st.sidebar.selectbox( 304 | "Dataset ", 305 | [d["name"] for d in st.session_state["datasets"]], 306 | ) 307 | st.session_state["dataset_id"] = [d["id"] for d in st.session_state["datasets"] if d["name"] == st.session_state["dataset_id"]][0] 308 | 309 | if st.session_state["prev_dataset_id"] != st.session_state["dataset_id"]: 310 | st.session_state["query_button"] = True 311 | st.session_state["annotate_button"] = False 312 | st.session_state["isFirstRun"] = True 313 | if st.session_state["sample"] is not None: 314 | response = requests.put(BACKEND_URL + f"/samples/{int(st.session_state['sample']['id'])}/unlock") 315 | if response.status_code == 200: 316 | app_logger.info(f"Sample {int(st.session_state['sample']['id'])} unlocked") 317 | st.experimental_rerun() 318 | 319 | if st.session_state["dataset_id"] is not None: 320 | 321 | if st.session_state["query_button"] and st.session_state["isFirstRun"]: 322 | st.balloons() 323 | query() 324 | st.session_state["isFirstRun"] = False 325 | elif st.session_state["query_button"]: 326 | query() 327 | st.session_state["query_button"] = False 328 | st.session_state["annotate_button"] = False 329 | # add progresss bar 330 | if "stats" in st.session_state and st.session_state["stats"] is not None: 331 | progress_bar = st.progress(0, text="Progress") 332 | progress = 
st.session_state["stats"]["annotated"] / st.session_state["stats"]["total"] 333 | progress_bar.progress( 334 | progress, 335 | text=f"Progress: {st.session_state['stats']['annotated']} Rated, {st.session_state['stats']['total'] - st.session_state['stats']['annotated']} Remaining out of {st.session_state['stats']['total']} recordings", 336 | ) 337 | 338 | if st.session_state["sample"] is not None: 339 | 340 | # Input sentence 341 | # sample_container(st.session_state["sample"]) 342 | sample = st.session_state["sample"] 343 | col1, col2, col3 = st.columns(columns_sizes) 344 | col1.metric("ID", sample["filename"]) 345 | col2.metric("Sentence Type", f"{sample['sentence_type']}") 346 | col3.metric("WER", sample["wer"]) 347 | 348 | st.markdown("### Audio") 349 | col1, col2 = st.columns((20, 2)) 350 | # audio player 351 | audio_file = open(st.session_state["sample"]["local_trimmed_path"], "rb") 352 | col1.audio(audio_file, format="audio/wav") 353 | submitted1 = col2.button("Submit", key=f"submit_container_{st.session_state['run_id']}") 354 | 355 | # st.markdown("---") 356 | col1, col2 = st.columns(2) 357 | original_text = col1.text_area( 358 | "Original Text", 359 | f'{sample["original_text"]}', 360 | ) 361 | asr_text = col2.text_area( 362 | "ASR Text", 363 | sample["asr_text"], 364 | ) 365 | 366 | st.markdown("---") 367 | 368 | postedit_columns_sizes = (10, 35, 10, 10) 369 | # Divide screen into 2 columns 370 | col2, col3, col4, col5 = st.columns(postedit_columns_sizes) 371 | # For all systems show all output sentences under each other and system names 372 | 373 | # add vertival radio button for each system 374 | col2.markdown("### Select Better") 375 | better = col2.radio( 376 | "Selected Transcription", 377 | ["Original", "ASR"], 378 | key=f"better_select_{st.session_state['run_id']}", 379 | index=0, 380 | horizontal=True, 381 | label_visibility="collapsed", 382 | ) 383 | 384 | col3.markdown("### Post Edit") 385 | ph = col3.empty() 386 | if better != "": 387 | 
st.session_state["user_input"]["final_text"] = ph.text_area( 388 | "Please post edit the text if needed", 389 | st.session_state["sample"][f"{better.lower()}_text"], 390 | key=f"best_sys_{st.session_state['run_id']}", 391 | label_visibility="collapsed", 392 | ) 393 | col4.markdown("### Sentence Type") 394 | sentence_type_list = ["Statement", "Question", "Exclamation"] 395 | defult_idx = sentence_type_list.index(sample["sentence_type"].title()) 396 | sentence_type = col4.radio( 397 | "Sentence Type", sentence_type_list, key=f"sentence_type_{st.session_state['run_id']}", index=defult_idx, label_visibility="collapsed" 398 | ) 399 | st.session_state["user_input"]["final_sentence_type"] = sentence_type.lower() 400 | 401 | col5.markdown("#") 402 | col5.markdown("#") 403 | col5.markdown("#") 404 | col5.markdown("###") 405 | 406 | # create a divider 407 | st.markdown("---") 408 | # ask if you want to Discard or Save 409 | col1, col2 = st.columns(2) 410 | with col1: 411 | discard = st.checkbox("Discard", value=False, key=f"discard_{st.session_state['run_id']}") 412 | 413 | if discard: 414 | col1, col2 = st.columns(2) 415 | with col1: 416 | isRepeated = True if st.checkbox("Has Repeation", value=False) else False 417 | # isAccentRight = True if st.checkbox("Accent is Wrong", value=False) else False 418 | # isPronunciationRight = True if st.checkbox("Pronunciation is Wrong", value=False) else False 419 | # isClean = True if st.checkbox("Recording is not Clean", value=False) else False 420 | # isPausesRight = True if st.checkbox("Pauses are not right", value=False) else False 421 | # isSpeedRight = True if st.checkbox("Speed is not right", value=False) else False 422 | # isConsisent = True if st.checkbox("Voice is not consistent", value=False) else False 423 | incorrectProsody = True if st.checkbox("Incorrect prosody", value=False) else False 424 | inconsistentTextAudio = True if st.checkbox("Inconsistent text and audio", value=False) else False 425 | incorrectTrancuation = 
True if st.checkbox("Incorrect trancuation", value=False) else False 426 | soundArtifacts = True if st.checkbox("Sound artifacts", value=False) else False 427 | with col2: 428 | feedback = st.text_area("Feedback", value=st.session_state["user_input"]["feedback"]) 429 | 430 | st.session_state["user_input"]["isRepeated"] = isRepeated 431 | # st.session_state["user_input"]["isAccentRight"] = isAccentRight 432 | # st.session_state["user_input"]["isPronunciationRight"] = isPronunciationRight 433 | # st.session_state["user_input"]["isClean"] = isClean 434 | # st.session_state["user_input"]["isPausesRight"] = isPausesRight 435 | # st.session_state["user_input"]["isSpeedRight"] = isSpeedRight 436 | # st.session_state["user_input"]["isConsisent"] = isConsisent 437 | st.session_state["user_input"]["incorrectProsody"] = incorrectProsody 438 | st.session_state["user_input"]["inconsistentTextAudio"] = inconsistentTextAudio 439 | st.session_state["user_input"]["incorrectTrancuation"] = incorrectTrancuation 440 | st.session_state["user_input"]["soundArtifacts"] = soundArtifacts 441 | st.session_state["user_input"]["feedback"] = feedback 442 | else: 443 | st.session_state["user_input"]["isRepeated"] = False 444 | # st.session_state["user_input"]["isAccentRight"] = True 445 | # st.session_state["user_input"]["isPronunciationRight"] = True 446 | # st.session_state["user_input"]["isClean"] = True 447 | # st.session_state["user_input"]["isPausesRight"] = True 448 | # st.session_state["user_input"]["isSpeedRight"] = True 449 | # st.session_state["user_input"]["isConsisent"] = True 450 | st.session_state["user_input"]["incorrectProsody"] = False 451 | st.session_state["user_input"]["inconsistentTextAudio"] = False 452 | st.session_state["user_input"]["incorrectTrancuation"] = False 453 | st.session_state["user_input"]["soundArtifacts"] = False 454 | st.session_state["user_input"]["feedback"] = "" 455 | col1, col2 = st.columns((10, 1)) 456 | with col2: 457 | submitted2 = 
st.button("Submit", key=f"submit_st_{st.session_state['run_id']}") 458 | if submitted1 or submitted2: 459 | st.session_state["run_id"] += 1 460 | if discard: 461 | status = "Discarded" 462 | else: 463 | status = "Reviewed" 464 | st.session_state["user_input"]["status"] = status 465 | st.success("Submitted!") 466 | st.session_state["query_button"] = True 467 | st.session_state["annotate_button"] = True 468 | st.experimental_rerun() 469 | 470 | else: 471 | st.warning("No more samples to rate") 472 | 473 | else: 474 | st.warning("Select Annotator and Dataset") 475 | 476 | # Inject the JavaScript code to modify the controlsList attribute of the audio element 477 | st.markdown( 478 | """ 479 | 487 | """, 488 | unsafe_allow_html=True, 489 | ) 490 | 491 | 492 | app() 493 | -------------------------------------------------------------------------------- /src/web_app/annotator/🏠_Intro_annotator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import streamlit as st 5 | import streamlit_authenticator as stauth 6 | 7 | 8 | current_file_path = os.path.dirname(os.path.abspath(__file__)) 9 | # aapedn 3 parent directories to the path 10 | sys.path.append(os.path.join(current_file_path, "..", "..", "..", "..")) 11 | 12 | from src.logger import root_logger 13 | from src.paths import paths 14 | 15 | 16 | BASE_DIR = str(paths.PROJECT_ROOT_DIR.resolve()) 17 | 18 | app_logger = root_logger.getChild("web_app::home") 19 | 20 | import yaml 21 | from yaml.loader import SafeLoader 22 | 23 | 24 | config_file_path = paths.LOGIN_CONFIG_PATH 25 | with open(config_file_path) as file: 26 | config = yaml.load(file, Loader=SafeLoader) 27 | 28 | st.set_page_config(page_title="TTS QA App", page_icon="🔍", layout="wide") 29 | 30 | authenticator = stauth.Authenticate( 31 | config["credentials"], config["cookie"]["name"], config["cookie"]["key"], config["cookie"]["expiry_days"], config["preauthorized"] 32 | ) 33 | 34 | # sidebar 35 
| 36 | name, authentication_status, username = authenticator.login("Login", "sidebar") 37 | if st.session_state["authentication_status"]: 38 | authenticator.logout("Logout", "sidebar") 39 | st.sidebar.write(f'Welcome *{st.session_state["name"]}*') 40 | 41 | elif st.session_state["authentication_status"] is False: 42 | st.sidebar.error("Username/password is incorrect") 43 | try: 44 | username_forgot_pw, email_forgot_password, random_password = authenticator.forgot_password("Forgot password", "sidebar") 45 | if username_forgot_pw: 46 | st.sidebar.success("New password sent securely") 47 | with open(config_file_path, "w") as file: 48 | yaml.dump(config, file, default_flow_style=False) 49 | # Random password to be transferred to user securely 50 | else: 51 | st.sidebar.error("Username not found") 52 | except Exception as e: 53 | st.error(e) 54 | elif st.session_state["authentication_status"] is None: 55 | st.sidebar.warning("Please enter your username and password") 56 | 57 | st.title("TTS QA: Annotator App") 58 | text = """ 59 | ## Record Audio from Uploaded CSV Files Prompt 60 | Welcome to the Record Audio from Uploaded CSV Files Prompt page of the Annotator Application! 61 | Here, you will be prompted to record audio for the uploaded CSV files assigned to you for QA. 62 | Simply click on the "Record" button and start speaking when prompted. You can listen to your 63 | recording and re-record if necessary. Once you are satisfied with your recording, submit it for review. 64 | You can track the status of your submissions on the "Dashboard" page. 65 | 66 | ## Do the QA for the Recordings Done in a Dataset Uploaded to the System 67 | Welcome to the Do the QA for the Recordings Done in a Dataset Uploaded to the System page of the Annotator Application! 68 | Here, you will be assigned datasets for QA based on your expertise and availability. Simply listen to the audio recordings 69 | and review them for accuracy and completeness. 
You can mark the recordings as correct, incorrect, or request further review if necessary. 70 | You can also provide comments or feedback to the admin team if you encounter any issues during the QA process. Use your expertise and attention 71 | to detail to ensure the quality of the datasets in the system. 72 | """ 73 | 74 | 75 | st.markdown(text) 76 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from src.utils.audio import evaluate_audio\n", 10 | "evaluate_audio(\"/home/ubuntu/repos/tts-qa/DE00000044.wav\")" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [] 19 | } 20 | ], 21 | "metadata": { 22 | "kernelspec": { 23 | "display_name": "corpus-insight", 24 | "language": "python", 25 | "name": "python3" 26 | }, 27 | "language_info": { 28 | "codemirror_mode": { 29 | "name": "ipython", 30 | "version": 3 31 | }, 32 | "file_extension": ".py", 33 | "mimetype": "text/x-python", 34 | "name": "python", 35 | "nbconvert_exporter": "python", 36 | "pygments_lexer": "ipython3", 37 | "version": "3.8.8" 38 | } 39 | }, 40 | "nbformat": 4, 41 | "nbformat_minor": 2 42 | } 43 | -------------------------------------------------------------------------------- /vars.env: -------------------------------------------------------------------------------- 1 | # Environment variables go here, can be read by `python-dotenv` package: 2 | # 3 | # `src/script.py` 4 | # ---------------------------------------------------------------- 5 | # import dotenv 6 | # 7 | # project_dir = os.path.join(os.path.dirname(__file__), os.pardir) 8 | # dotenv_path = os.path.join(project_dir, '.env') 9 | # dotenv.load_dotenv(dotenv_path) 10 | # 
---------------------------------------------------------------- 11 | 12 | S3_BUCKET_NAME= 13 | S3_DATASET_DIR= 14 | 15 | SERVER_HOST=127.0.0.1 16 | SERVER_PORT=8089 17 | SERVER_API_VERSION=v1 18 | 19 | ENV=dev # dev or prod 20 | 21 | POSTGRES_HOST=localhost 22 | POSTGRES_PORT=5432 23 | POSTGRES_USER=postgres 24 | POSTGRES_PWD=postgres 25 | POSTGRES_DB=${ENV}_tts_db 26 | POSTGRES_URL=postgresql+psycopg2://${POSTGRES_USER}:${POSTGRES_PWD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} 27 | 28 | MAX_LOCKING_MIN=5 29 | 30 | # AWS_ACCESS_KEY_ID=your-access-key-id 31 | # AWS_SECRET_ACCESS_KEY=your-secret-access-key 32 | # AWS_REGION=us-east-1 33 | --------------------------------------------------------------------------------