├── images ├── summary.PNG ├── special_char.PNG └── sphx_glr_forced_alignment_tutorial_005.png ├── setup ├── create_compute_clusters.sh ├── create-gpu-cluster-t4.yaml ├── create-gpu-cluster-a100.yaml └── create-cpu-cluster.yaml ├── stt_aml_deploy ├── config │ ├── batch_endpoint.yaml │ └── parallel_job.yaml ├── components │ ├── remove_stt_data │ │ ├── docker │ │ │ └── Dockerfile │ │ ├── remove_stt_data.yaml │ │ └── main.py │ ├── asr │ │ ├── docker │ │ │ └── Dockerfile │ │ └── src │ │ │ └── main.py │ ├── diar │ │ ├── docker │ │ │ └── Dockerfile │ │ └── src │ │ │ └── main.py │ ├── nfa │ │ ├── docker │ │ │ └── Dockerfile │ │ └── src │ │ │ └── main.py │ ├── merge_align │ │ ├── docker │ │ │ └── Dockerfile │ │ └── src │ │ │ └── main.py │ └── prep │ │ ├── docker │ │ └── Dockerfile │ │ └── src │ │ └── main.py ├── setup │ ├── register_components │ │ ├── register_rsttd_component.py │ │ ├── register_nfa_component.py │ │ ├── register_diar_component.py │ │ ├── register_ma_component.py │ │ ├── register_prep_component.py │ │ └── register_asr_component.py │ ├── register_environments.py │ ├── create_batch_deployment.py │ └── register_pipeline.py └── parallel_job.py ├── LICENSE ├── .gitignore └── README.md /images/summary.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hedrergudene/asr-sd-pipeline/HEAD/images/summary.PNG -------------------------------------------------------------------------------- /images/special_char.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hedrergudene/asr-sd-pipeline/HEAD/images/special_char.PNG -------------------------------------------------------------------------------- /images/sphx_glr_forced_alignment_tutorial_005.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hedrergudene/asr-sd-pipeline/HEAD/images/sphx_glr_forced_alignment_tutorial_005.png -------------------------------------------------------------------------------- /setup/create_compute_clusters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | az ml compute create -f ./create-cpu-cluster.yaml 3 | az ml compute create -f ./create-gpu-cluster-t4.yaml 4 | az ml compute create -f ./create-gpu-cluster-a100.yaml -------------------------------------------------------------------------------- /stt_aml_deploy/config/batch_endpoint.yaml: -------------------------------------------------------------------------------- 1 | azure: 2 | subscription_id : 'XXXXXX' 3 | resource_group : 'XXXXXX' 4 | aml_workspace_name : 'XXXXXX' 5 | endpoint: 6 | name: 'stt-service' 7 | description: 'API to run a STT batch inference AzureML pipeline.' 8 | deployment: 9 | name: 10 | description: 11 | default_compute: 'XXXXXX' -------------------------------------------------------------------------------- /stt_aml_deploy/components/remove_stt_data/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image 2 | FROM mcr.microsoft.com/azureml/curated/python-sdk-v2:10 3 | 4 | # Install dependencies 5 | RUN pip install 'azure-storage-blob==12.13.0' \ 6 | 'pandas==2.0.2' \ 7 | 'numpy==1.24.3' \ 8 | 'protobuf==3.20.0' \ 9 | 'pyyaml==5.4.1' \ 10 | 'fire==0.4.0' -------------------------------------------------------------------------------- /setup/create-gpu-cluster-t4.yaml: -------------------------------------------------------------------------------- 1 | $schema: https://azuremlschemas.azureedge.net/latest/amlCompute.schema.json 2 | name: gpu-cluster-t4 3 | type: amlcompute 4 | size: Standard_NC16as_T4_v3 # (16 cores, 110 GB RAM, 352 GB disk) 5 | min_instances: 0 6 | max_instances: 32 7 | idle_time_before_scale_down: 120 8 | tier: dedicated 9 | 
ssh_public_access_enabled: false 10 | identity.type: system_assigned 11 | location: eastus2 -------------------------------------------------------------------------------- /setup/create-gpu-cluster-a100.yaml: -------------------------------------------------------------------------------- 1 | $schema: https://azuremlschemas.azureedge.net/latest/amlCompute.schema.json 2 | name: gpu-cluster-a100 3 | type: amlcompute 4 | size: Standard_NC24ads_A100_v4 # (24 cores, 220 GB RAM, 1024 GB disk) 5 | min_instances: 0 6 | max_instances: 32 7 | idle_time_before_scale_down: 120 8 | tier: dedicated 9 | ssh_public_access_enabled: false 10 | identity.type: system_assigned 11 | location: eastus2 -------------------------------------------------------------------------------- /setup/create-cpu-cluster.yaml: -------------------------------------------------------------------------------- 1 | $schema: https://azuremlschemas.azureedge.net/latest/amlCompute.schema.json 2 | name: cpu-cluster 3 | type: amlcompute 4 | size: STANDARD_DS11_v2 # In case 4 cores are needed, switch to STANDARD_DS3_v2. 
For large ML trainings in CPU, use Standard_D13_v2 5 | min_instances: 0 6 | max_instances: 32 7 | idle_time_before_scale_down: 120 8 | tier: dedicated 9 | ssh_public_access_enabled: false 10 | identity.type: system_assigned 11 | location: eastus2 -------------------------------------------------------------------------------- /stt_aml_deploy/components/asr/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image 2 | FROM mcr.microsoft.com/azureml/curated/acpt-pytorch-2.1-cuda12.1:6 3 | 4 | RUN apt-get update && apt-get install -y sox libsndfile1 ffmpeg 5 | RUN pip install 'mltable>=1.2.0' \ 6 | 'azureml-dataset-runtime[pandas,fuse]==1.53.0' \ 7 | 'azureml-telemetry==1.53.0' \ 8 | 'azureml-core==1.53.0' \ 9 | 'azure-identity==1.15.0' \ 10 | 'azure-keyvault-secrets==4.7.0' \ 11 | 'pgpy==0.6.0' \ 12 | 'faster-whisper==1.0.1' -------------------------------------------------------------------------------- /stt_aml_deploy/components/diar/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image (check CUDA version is the same as PyTorch one) 2 | FROM nvcr.io/nvidia/nemo:23.06 3 | 4 | # Install pip dependencies 5 | ## Diarization 6 | RUN apt-get update && apt-get install -y sox libsndfile1 ffmpeg 7 | RUN pip install 'mltable>=1.2.0' \ 8 | 'azureml-dataset-runtime[pandas,fuse]==1.53.0' \ 9 | 'azureml-telemetry==1.53.0' \ 10 | 'azureml-core==1.53.0' \ 11 | 'azure-identity==1.15.0' \ 12 | 'azure-keyvault-secrets==4.7.0' \ 13 | 'pgpy==0.6.0' \ 14 | 'pyyaml==5.4.1' \ 15 | 'fire==0.5.0' -------------------------------------------------------------------------------- /stt_aml_deploy/components/nfa/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image (check CUDA version is the same as PyTorch one) 2 | #FROM mcr.microsoft.com/azureml/minimal-ubuntu20.04-py38-cuda11.6.2-gpu-inference:latest 3 | FROM 
nvcr.io/nvidia/nemo:23.06 4 | 5 | # Install pip dependencies 6 | ## ASR 7 | RUN apt-get update && apt-get install -y sox libsndfile1 ffmpeg 8 | RUN pip install 'mltable>=1.2.0' \ 9 | 'azureml-dataset-runtime[pandas,fuse]==1.53.0' \ 10 | 'azureml-telemetry==1.53.0' \ 11 | 'azureml-core==1.53.0' \ 12 | 'azure-ai-ml==1.10.1' \ 13 | 'azure-identity==1.15.0' \ 14 | 'azure-keyvault-secrets==4.7.0' \ 15 | 'pgpy==0.6.0' \ 16 | 'pyyaml==5.4.1' \ 17 | 'fire==0.5.0' -------------------------------------------------------------------------------- /stt_aml_deploy/components/merge_align/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image (check CUDA version is the same as PyTorch one) 2 | #FROM mcr.microsoft.com/azureml/minimal-ubuntu20.04-py38-cuda11.6.2-gpu-inference:latest 3 | FROM mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7 4 | 5 | # Install pip dependencies 6 | ## ASR 7 | RUN pip install 'transformers==4.35.0' \ 8 | 'protobuf==3.20.0' \ 9 | 'pyyaml==5.4.1' \ 10 | 'mltable>=1.2.0' \ 11 | 'azureml-dataset-runtime[pandas,fuse]==1.53.0' \ 12 | 'azureml-telemetry==1.53.0' \ 13 | 'azureml-core==1.53.0' \ 14 | 'azure-identity==1.15.0' \ 15 | 'azure-keyvault-secrets==4.7.0' \ 16 | 'pgpy==0.6.0' \ 17 | 'pymongo==4.6.1' -------------------------------------------------------------------------------- /stt_aml_deploy/components/prep/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image (check CUDA version is the same as PyTorch one) 2 | #FROM mcr.microsoft.com/azureml/minimal-ubuntu20.04-py38-cuda11.6.2-gpu-inference:latest 3 | FROM mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7 4 | 5 | # Install pip dependencies 6 | ## ASR 7 | RUN apt-get update && apt-get install -y sox libsndfile1 ffmpeg 8 | RUN pip install 'mltable>=1.2.0' \ 9 | 'azureml-dataset-runtime[pandas,fuse]==1.53.0' \ 10 | 'azureml-telemetry==1.53.0' \ 11 | 
'azureml-core==1.53.0' \ 12 | 'azure-identity==1.15.0' \ 13 | 'azure-keyvault-secrets==4.7.0' \ 14 | 'pgpy==0.6.0' \ 15 | 'pymongo==4.6.1' \ 16 | 'demucs==4.0.1' \ 17 | 'protobuf==3.20.0' \ 18 | 'pyyaml==5.4.1' -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Antonio Zarauz Moreno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /stt_aml_deploy/components/remove_stt_data/remove_stt_data.yaml: -------------------------------------------------------------------------------- 1 | $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json 2 | type: command 3 | 4 | # General information about the component 5 | name: remove_stt_data 6 | display_name: Remove STT data 7 | description: A template component to drop intermediate data from target container after azureml pipeline run. 8 | tags: 9 | author: IA-Cognitive 10 | 11 | # Inputs and outputs 12 | inputs: 13 | input_path: 14 | type: uri_folder 15 | optional: false 16 | storage_id: 17 | type: string 18 | optional: false 19 | container_name: 20 | type: string 21 | optional: false 22 | blob_filepath: 23 | type: string 24 | optional: false 25 | 26 | # The source code path is defined in the code section; when the 27 | # component is run in the cloud, all files from that path are uploaded 28 | # as the snapshot of this component 29 | code: ./ 30 | 31 | # Environment takes care of source image and dependencies 32 | # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-manage-environments-v2?view=azureml-api-2&tabs=cli 33 | environment: "azureml:remove_stt_data_env:1" 34 | 35 | # Cluster instance 36 | compute: azureml:cpu-cluster 37 | 38 | # Distribution type 39 | distribution: 40 | type: mpi 41 | process_count_per_instance: 1 # Number of processes per instance 42 | 43 | # How many VMs we need 44 | resources: 45 | instance_count: 1 # Number of instances to create 46 | 47 | # The command section specifies the command to execute while running 48 | # this component 49 | command: python ./main.py --input_path ${{inputs.input_path}} --storage_id ${{inputs.storage_id}} --container_name ${{inputs.container_name}} --blob_filepath ${{inputs.blob_filepath}} -------------------------------------------------------------------------------- 
/stt_aml_deploy/config/parallel_job.yaml: -------------------------------------------------------------------------------- 1 | aml: 2 | subscription_id : 'XXXXXX' 3 | resource_group : 'XXXXXX' 4 | workspace_name : 'XXXXXX' 5 | computing: 6 | cpu_cluster: 'XXXXXX' 7 | gpu_cluster_t4 : 'XXXXXX' 8 | gpu_cluster_a100 : 'XXXXXX' 9 | project_name : 'stt_pipeline' 10 | 11 | blob: # azureml://datastores//paths/ (blob_path might have some folder structure underneath) 12 | storage_id: 'XXXXXX' 13 | container_name: 'XXXXXX' 14 | input_path: 'XXXXXX' 15 | output_path: 'XXXXXX' 16 | 17 | keyvault: 18 | name: 'XXXXXX' 19 | secret_tenant_sp: "XXXXXX" 20 | secret_client_sp: "XXXXXX" 21 | secret_sp: "XXXXXX" 22 | pk_secret: 'XXXXXX' 23 | pk_pass_secret: 'XXXXXX' 24 | pubk_secret: 'XXXXXX' 25 | 26 | cosmosdb: 27 | name: 'XXXXXX' 28 | collection: 'XXXXXX' 29 | cs_secret: 'XXXXXX' 30 | 31 | preprocessing: 32 | vad_threshold: 0.8 33 | min_speech_duration_ms: 250 34 | min_silence_duration_ms: 400 35 | demucs_model: 'htdemucs' 36 | 37 | asr: 38 | model_name: "large-v3" 39 | num_workers: 4 40 | beam_size: 5 41 | word_level_timestamps: true 42 | condition_on_previous_text: true 43 | language_code: 'es' 44 | compute_type: 'float16' 45 | 46 | fa: 47 | model_name: "stt_es_fastconformer_hybrid_large_pc" 48 | batch_size: 16 49 | 50 | diarization: 51 | event_type: 'telephonic' #could also be 'meeting' 52 | max_num_speakers: 3 53 | min_window_length: 0.2 54 | overlap_threshold: 0.8 55 | 56 | align: 57 | max_words_in_sentence: 60 58 | ner_chunk_size: 50 59 | ner_stride: 5 60 | 61 | job: 62 | instance_count_large: 24 63 | instance_count_small: 6 64 | max_concurrency_per_instance: 1 65 | mini_batch_size: "1" 66 | mini_batch_error_threshold: 1 67 | error_threshold: 1 68 | max_retries: 2 69 | timeout: 3000 70 | allowed_failed_percent: 1 71 | task_overhead_timeout: 300 72 | first_task_creation_timeout: 1200 73 | resource_monitor_interval: 300 
-------------------------------------------------------------------------------- /stt_aml_deploy/components/remove_stt_data/main.py: -------------------------------------------------------------------------------- 1 | # Requirements 2 | import logging as log 3 | import re 4 | import os 5 | import sys 6 | from pathlib import Path 7 | import pandas as pd 8 | from azure.storage.blob import BlobServiceClient 9 | from azure.identity import DefaultAzureCredential 10 | import fire 11 | 12 | # Setup logs: the root logger emits DEBUG-and-above records to stdout 13 | root = log.getLogger() 14 | root.setLevel(log.DEBUG) 15 | handler = log.StreamHandler(sys.stdout) 16 | handler.setLevel(log.DEBUG) 17 | formatter = log.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 18 | handler.setFormatter(formatter) 19 | root.addHandler(handler) 20 | 21 | # Helper method to remove a single blob: gets a client for `blob_name` inside `container_name` and deletes that blob 22 | def delete_blob(blob_service_client: BlobServiceClient, container_name: str, blob_name: str): 23 | blob_client = blob_service_client.get_blob_client(container=container_name, blob=blob_name) 24 | blob_client.delete_blob() 25 | 26 | # Main method. 
Fire automatically allign method arguments with parse commands from console 27 | def main( 28 | input_path, 29 | storage_id, 30 | container_name, 31 | blob_filepath 32 | ): 33 | # Check if given credential can get token successfully 34 | credential = DefaultAzureCredential() 35 | credential.get_token("https://management.azure.com/.default") 36 | # Create a blob client using the local file name as the name for the blob 37 | account_url = f"https://{storage_id}.blob.core.windows.net" 38 | blob_service_client = BlobServiceClient(account_url, credential=credential) 39 | # Loop 40 | regex_fn = lambda pattern,text : len(re.findall(pattern, text))>0 41 | for elem in os.listdir(input_path): 42 | if ((regex_fn('\.wav\.pgp', elem)) | (regex_fn('_prep', elem))| (regex_fn('_asr', elem))| (regex_fn('_nfa', elem))| (regex_fn('_diar', elem))): 43 | delete_blob(blob_service_client, container_name, os.path.join(blob_filepath,elem)) 44 | else: 45 | continue 46 | 47 | if __name__=="__main__": 48 | fire.Fire(main) -------------------------------------------------------------------------------- /stt_aml_deploy/setup/register_components/register_rsttd_component.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import yaml 3 | import sys 4 | import logging as log 5 | from azure.identity import DefaultAzureCredential 6 | from azure.ai.ml import MLClient, load_component 7 | import fire 8 | 9 | # Setup logs 10 | root = log.getLogger() 11 | root.setLevel(log.DEBUG) 12 | handler = log.StreamHandler(sys.stdout) 13 | handler.setLevel(log.DEBUG) 14 | formatter = log.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 15 | handler.setFormatter(formatter) 16 | root.addHandler(handler) 17 | 18 | # Main method. 
Fire automatically align method arguments with parse commands from console 19 | def main( 20 | config_path='.../config/parallel_job.yaml' 21 | ): 22 | 23 | # Get credential token 24 | log.info("Get credential token:") 25 | try: 26 | credential = DefaultAzureCredential() 27 | credential.get_token("https://management.azure.com/.default") 28 | except Exception as ex: 29 | log.error(f"Something went wrong regarding authentication. Returned error is: {ex.message}") 30 | return (f"Something went wrong regarding authentication. Returned error is: {ex.message}", 500) 31 | 32 | # Fetch configuration file 33 | log.info("Fetch configuration file:") 34 | with open(config_path) as file: 35 | config_dct = yaml.load(file, Loader=yaml.FullLoader) 36 | # Get a handle to workspace 37 | log.info("Set up ML Client:") 38 | ml_client = MLClient( 39 | credential=credential, 40 | subscription_id=config_dct['aml']['subscription_id'], 41 | resource_group_name=config_dct['aml']['resource_group'], 42 | workspace_name=config_dct['aml']['workspace_name'], 43 | ) 44 | 45 | # Load component 46 | rsttd_comp = load_component(source=".../components/remove_stt_data/remove_stt_data.yaml") 47 | 48 | # Component register 49 | ml_client.components.create_or_update(rsttd_comp, version="1") 50 | 51 | 52 | if __name__=="__main__": 53 | fire.Fire(main) -------------------------------------------------------------------------------- /stt_aml_deploy/setup/register_environments.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import yaml 3 | import sys 4 | import os 5 | import logging as log 6 | from azure.identity import DefaultAzureCredential 7 | from azure.ai.ml import MLClient 8 | from azure.ai.ml.entities import Environment, BuildContext 9 | import fire 10 | 11 | # Setup logs 12 | root = log.getLogger() 13 | root.setLevel(log.DEBUG) 14 | handler = log.StreamHandler(sys.stdout) 15 | handler.setLevel(log.DEBUG) 16 | formatter = 
log.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 17 | handler.setFormatter(formatter) 18 | root.addHandler(handler) 19 | 20 | # Main method. Fire automatically allign method arguments with parse commands from console 21 | def main( 22 | config_path:str='./config/online_endpoint.yaml' 23 | ): 24 | 25 | # Get credential token 26 | log.info("Get credential token:") 27 | try: 28 | credential = DefaultAzureCredential() 29 | credential.get_token("https://management.azure.com/.default") 30 | except Exception as ex: 31 | log.error(f"Something went wrong regarding authentication. Returned error is: {ex.message}") 32 | return (f"Something went wrong regarding authentication. Returned error is: {ex.message}", 500) 33 | 34 | # Fetch configuration file 35 | log.info("Fetch configuration file:") 36 | with open(config_path) as file: 37 | config_dct = yaml.load(file, Loader=yaml.FullLoader) 38 | # Get a handle to workspace 39 | log.info("Set up ML Client:") 40 | ml_client = MLClient( 41 | credential=credential, 42 | subscription_id=config_dct['azure']['subscription_id'], 43 | resource_group_name=config_dct['azure']['resource_group'], 44 | workspace_name=config_dct['azure']['aml_workspace_name'], 45 | ) 46 | 47 | # Build environments for AML pipeline 48 | for comp_name in os.listdir('./components'): 49 | log.info(f"Building environment {comp_name}:") 50 | env_docker_context = Environment( 51 | build=BuildContext(path=f"../components/{comp_name}/docker"), 52 | name=f"{comp_name}_env", 53 | description=f"Environment for {comp_name} component of speech to text solution.", 54 | ) 55 | ml_client.environments.create_or_update(env_docker_context) 56 | 57 | if __name__=="__main__": 58 | fire.Fire(main) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 
| # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /stt_aml_deploy/setup/create_batch_deployment.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import yaml 3 | import sys 4 | import os 5 | import logging as log 6 | from azure.ai.ml import MLClient, load_component 7 | from azure.ai.ml.entities import BatchEndpoint, PipelineComponentBatchDeployment 8 | from azure.identity import DefaultAzureCredential 9 | import fire 10 | 11 | # Setup logs 12 | root = log.getLogger() 13 | root.setLevel(log.DEBUG) 14 | handler = log.StreamHandler(sys.stdout) 15 | handler.setLevel(log.DEBUG) 16 | formatter = log.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 17 | handler.setFormatter(formatter) 18 | root.addHandler(handler) 19 | 20 | # Main method. Fire automatically allign method arguments with parse commands from console 21 | def main( 22 | config_path:str='./config/batch_endpoint.yaml' 23 | ): 24 | 25 | # Get credential token 26 | log.info("Get credential token:") 27 | try: 28 | credential = DefaultAzureCredential() 29 | credential.get_token("https://management.azure.com/.default") 30 | except Exception as ex: 31 | log.error(f"Something went wrong regarding authentication. Returned error is: {ex.message}") 32 | return (f"Something went wrong regarding authentication. 
Returned error is: {ex.message}", 500) 33 | 34 | # Fetch configuration file 35 | log.info("Fetch configuration file:") 36 | with open(config_path) as file: 37 | config_dct = yaml.load(file, Loader=yaml.FullLoader) 38 | # Get a handle to workspace 39 | log.info("Set up ML Client:") 40 | ml_client = MLClient( 41 | credential=credential, 42 | subscription_id=config_dct['azure']['subscription_id'], 43 | resource_group_name=config_dct['azure']['resource_group'], 44 | workspace_name=config_dct['azure']['aml_workspace_name'], 45 | ) 46 | 47 | # Define the endpoint 48 | log.info("Define batch endpoint:") 49 | endpoint = BatchEndpoint( 50 | name=config_dct['endpoint']['name'], 51 | description=config_dct['endpoint']['description'] 52 | ) 53 | ml_client.batch_endpoints.begin_create_or_update(endpoint).result() 54 | 55 | # Load registered component 56 | log.info("Load registered component:") 57 | stt_batch = load_component(client=ml_client, name="stt", version="1") 58 | 59 | # Deploy pipeline component 60 | log.info("Deploy pipeline component:") 61 | deployment = PipelineComponentBatchDeployment( 62 | name=config_dct['deployment']['name'], 63 | description=config_dct['deployment']['description'], 64 | endpoint_name=endpoint.name, 65 | component=stt_batch, 66 | settings={"continue_on_step_failure": False, "default_compute": config_dct['deployment']['default_compute']}, 67 | ) 68 | ml_client.batch_deployments.begin_create_or_update(deployment).result() 69 | 70 | # Set as default one 71 | endpoint = ml_client.batch_endpoints.get(config_dct['endpoint']['name']) 72 | endpoint.defaults.deployment_name = deployment.name 73 | ml_client.batch_endpoints.begin_create_or_update(endpoint).result() 74 | 75 | 76 | if __name__=="__main__": 77 | fire.Fire(main) -------------------------------------------------------------------------------- /stt_aml_deploy/parallel_job.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import yaml 3 | 
def main(
    config_path:str='./config/parallel_job.yaml'
):
    """Submit one run of the registered ``stt`` pipeline component as an AzureML job.

    Authenticates with ``DefaultAzureCredential``, reads the YAML configuration,
    builds the input dataset reference, instantiates the registered ``stt``
    component (version ``"1"``) with every parameter taken from the
    configuration, and submits it under the configured experiment name.

    Args:
        config_path: Path to the YAML file. The code reads the sections
            ``aml``, ``blob``, ``keyvault``, ``cosmosdb``, ``preprocessing``,
            ``asr``, ``fa``, ``diarization`` and ``align``.

    Returns:
        ``None`` on success. If the credential token cannot be obtained, a
        ``(message, 500)`` tuple describing the failure.
    """
    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Python 3 exceptions have no `.message` attribute; formatting the
        # exception itself avoids an AttributeError inside this handler.
        error_msg = f"Something went wrong regarding authentication. Returned error is: {ex}"
        log.error(error_msg)
        return (error_msg, 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path, encoding="utf-8") as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)

    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )

    # Set the input and output URI paths for the data.
    input_dts = Input(
        path=config_dct['blob']['input_path'],
        type=AssetTypes.URI_FOLDER,
        mode=InputOutputModes.RO_MOUNT  # Alternative, DOWNLOAD
    )

    # Load registered component
    stt_batch = load_component(client=ml_client, name="stt", version="1")

    # Create a pipeline
    pipeline_job = stt_batch(
        input_dts = input_dts,
        output_dts = config_dct['blob']['output_path'],
        storage_account_name = config_dct['blob']['storage_id'],
        container_name = config_dct['blob']['container_name'],
        blob_filepath = config_dct['blob']['blob_filepath'],
        aml_cpu_cluster = config_dct['aml']['computing']['cpu_cluster'],
        aml_t4_cluster = config_dct['aml']['computing']['gpu_cluster_t4'],
        aml_a100_cluster = config_dct['aml']['computing']['gpu_cluster_a100'],
        keyvault_name = config_dct['keyvault']['name'],
        secret_tenant_sp = config_dct['keyvault']['secret_tenant_sp'],
        secret_client_sp = config_dct['keyvault']['secret_client_sp'],
        secret_sp = config_dct['keyvault']['secret_sp'],
        pk_secret = config_dct['keyvault']['pk_secret'],
        pk_pass_secret = config_dct['keyvault']['pk_pass_secret'],
        pubk_secret = config_dct['keyvault']['pubk_secret'],
        cosmosdb_name = config_dct['cosmosdb']['name'],
        cosmosdb_collection = config_dct['cosmosdb']['collection'],
        cosmosdb_cs_secret = config_dct['cosmosdb']['cs_secret'],
        vad_threshold = config_dct['preprocessing']['vad_threshold'],
        min_speech_duration_ms = config_dct['preprocessing']['min_speech_duration_ms'],
        min_silence_duration_ms = config_dct['preprocessing']['min_silence_duration_ms'],
        demucs_model = config_dct['preprocessing']['demucs_model'],
        asr_model_name = config_dct['asr']['model_name'],
        asr_num_workers = config_dct['asr']['num_workers'],
        asr_beam_size = config_dct['asr']['beam_size'],
        word_level_timestamps = config_dct['asr']['word_level_timestamps'],
        condition_on_previous_text = config_dct['asr']['condition_on_previous_text'],
        asr_compute_type = config_dct['asr']['compute_type'],
        asr_language_code = config_dct['asr']['language_code'],
        nfa_model_name = config_dct['fa']['model_name'],
        nfa_batch_size = config_dct['fa']['batch_size'],
        diar_event_type = config_dct['diarization']['event_type'],
        diar_max_num_speakers = config_dct['diarization']['max_num_speakers'],
        diar_min_window_length = config_dct['diarization']['min_window_length'],
        diar_overlap_threshold = config_dct['diarization']['overlap_threshold'],
        ma_ner_chunk_size = config_dct['align']['ner_chunk_size'],
        ma_ner_stride = config_dct['align']['ner_stride'],
        ma_max_words_in_sentence = config_dct['align']['max_words_in_sentence']
    )

    # Run job (the returned job object is not used afterwards)
    ml_client.jobs.create_or_update(
        pipeline_job, experiment_name=config_dct['aml']['project_name']
    )
def main(
    config_path='.../config/parallel_job.yaml'
):
    """Register the parallel forced-alignment (NFA) pipeline component in AzureML.

    Authenticates, reads the YAML configuration, declares a parallel component
    (``pNFA``) that runs ``main.py`` from the NFA source folder, wraps it in a
    single-node pipeline named ``nfa`` and registers that pipeline's component
    as version ``"1"``.

    Args:
        config_path: Path to the YAML file. The code reads the sections
            ``aml``, ``fa`` and ``job``.
            NOTE(review): the default starts with ``.../``, which names a
            directory literally called ``...`` - confirm the intended path.

    Returns:
        ``None`` on success. If the credential token cannot be obtained, a
        ``(message, 500)`` tuple describing the failure.
    """
    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Python 3 exceptions have no `.message` attribute; formatting the
        # exception itself avoids an AttributeError inside this handler.
        error_msg = f"Something went wrong regarding authentication. Returned error is: {ex}"
        log.error(error_msg)
        return (error_msg, 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path, encoding="utf-8") as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)

    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )

    #
    # Declare Parallel task to perform forced alignment, based on Viterbi algorithm by nvidia implementation
    # For detailed info, check: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-parallel-job-in-pipeline?view=azureml-api-2&tabs=python
    #
    nfa_component = parallel_run_function(
        name="pNFA",
        display_name="Parallel forced alignment",
        description="Parallel component to perform NFA on a large amount of audios",
        inputs=dict(
            input_audio_path=Input(type=AssetTypes.URI_FOLDER, description="Audios to be transcribed"),
            input_asr_path=Input(type=AssetTypes.URI_FOLDER, description="Transcriptions of audios to be analysed"),
            keyvault_name=Input(type="string"),
            secret_tenant_sp=Input(type="string"),
            secret_client_sp=Input(type="string"),
            secret_sp=Input(type="string"),
            pk_secret=Input(type="string"),
            pk_pass_secret=Input(type="string"),
            pubk_secret=Input(type="string"),
            nfa_model_name=Input(type="string", default=config_dct['fa']['model_name'], optional=True),
            batch_size=Input(type="integer", default=config_dct['fa']['batch_size'], optional=True)
        ),
        outputs=dict(output_fa_path=Output(type=AssetTypes.URI_FOLDER)),
        # The ASR output folder is the input that gets split into mini-batches.
        input_data="${{inputs.input_asr_path}}",
        instance_count=config_dct['job']['instance_count_large'],
        max_concurrency_per_instance=config_dct['job']['max_concurrency_per_instance'],
        mini_batch_size=config_dct['job']['mini_batch_size'],
        mini_batch_error_threshold=config_dct['job']['mini_batch_error_threshold'],
        logging_level="DEBUG",
        error_threshold=config_dct['job']['error_threshold'],
        retry_settings=RetrySettings(
            max_retries=config_dct['job']['max_retries'],
            timeout=config_dct['job']['timeout']
        ),
        task=RunFunction(
            # NOTE(review): same literal ".../" prefix as the default config path - confirm.
            code=".../components/nfa/src",
            entry_script="main.py",
            environment=ml_client.environments.get(name="nfa_env", version="1"),
            # The "${{...}}" / "$[[...]]" pieces are plain strings left for AzureML
            # to resolve; only the trailing job settings are f-strings filled in
            # here from the configuration.
            program_arguments="--input_audio_path ${{inputs.input_audio_path}} "
                              "--input_asr_path ${{inputs.input_asr_path}} "
                              "--keyvault_name ${{inputs.keyvault_name}} "
                              "--secret_tenant_sp ${{inputs.secret_tenant_sp}} "
                              "--secret_client_sp ${{inputs.secret_client_sp}} "
                              "--secret_sp ${{inputs.secret_sp}} "
                              "--pk_secret ${{inputs.pk_secret}} "
                              "--pk_pass_secret ${{inputs.pk_pass_secret}} "
                              "--pubk_secret ${{inputs.pubk_secret}} "
                              "$[[--nfa_model_name ${{inputs.nfa_model_name}}]] "
                              "$[[--batch_size ${{inputs.batch_size}}]] "
                              "--output_fa_path ${{outputs.output_fa_path}} "
                              f"--allowed_failed_percent {config_dct['job']['allowed_failed_percent']} "
                              f"--task_overhead_timeout {config_dct['job']['task_overhead_timeout']} "
                              f"--first_task_creation_timeout {config_dct['job']['first_task_creation_timeout']} "
                              f"--resource_monitor_interval {config_dct['job']['resource_monitor_interval']} ",
            # All values output by run() method invocations will be aggregated into one unique file which is created in the output location.
            # If it is not set, 'summary_only' would be invoked, which means the user script is expected to store the output itself.
            #append_row_to="${{outputs.output_path}}"
        ),
    )

    #
    # Create pipeline
    #

    @pipeline(default_compute=config_dct['aml']['computing']['gpu_cluster_t4'])
    def nfa(
        input_dts:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        input_asr:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type="string"),
        aml_compute:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        model_name:Input(type="string", default=config_dct['fa']['model_name'], optional=True),
        batch_size:Input(type="integer", default=config_dct['fa']['batch_size'], optional=True)
    ):
        # Single-node pipeline: forwards its inputs to the parallel component,
        # writes the node output to `output_dts` and runs on `aml_compute`.
        nfa_node = nfa_component(
            input_audio_path=input_dts,
            input_asr_path=input_asr,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            nfa_model_name=model_name,
            batch_size=batch_size
        )
        nfa_node.outputs.output_fa_path = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        nfa_node.compute = aml_compute

        return {'output_dts': nfa_node.outputs.output_fa_path}

    # Create a pipeline
    pipeline_job = nfa()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
def main(
    config_path='.../config/parallel_job.yaml'
):
    """Register the parallel speaker-diarization pipeline component in AzureML.

    Authenticates, reads the YAML configuration, declares a parallel component
    (``pMSDD``) that runs ``main.py`` from the diarization source folder, wraps
    it in a single-node pipeline named ``diar`` and registers that pipeline's
    component as version ``"1"``.

    Args:
        config_path: Path to the YAML file. The code reads the sections
            ``aml``, ``diarization`` and ``job``.
            NOTE(review): the default starts with ``.../``, which names a
            directory literally called ``...`` - confirm the intended path.

    Returns:
        ``None`` on success. If the credential token cannot be obtained, a
        ``(message, 500)`` tuple describing the failure.
    """
    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Python 3 exceptions have no `.message` attribute; formatting the
        # exception itself avoids an AttributeError inside this handler.
        error_msg = f"Something went wrong regarding authentication. Returned error is: {ex}"
        log.error(error_msg)
        return (error_msg, 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path, encoding="utf-8") as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)

    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )

    #
    # Declare Parallel task to perform speaker diarization
    # For detailed info, check: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-parallel-job-in-pipeline?view=azureml-api-2&tabs=python
    #
    diar_component = parallel_run_function(
        name="pMSDD",
        display_name="Parallel diarization",
        description="Parallel component to perform speaker diarization on a large amount of audios",
        inputs=dict(
            input_audio_path=Input(type=AssetTypes.URI_FOLDER, description="Audios to be diarized"),
            input_asr_path=Input(type=AssetTypes.URI_FOLDER, description="Transcriptions of those audios"),
            keyvault_name=Input(type="string"),
            secret_tenant_sp=Input(type="string"),
            secret_client_sp=Input(type="string"),
            secret_sp=Input(type="string"),
            pk_secret=Input(type="string"),
            pk_pass_secret=Input(type="string"),
            pubk_secret=Input(type="string"),
            event_type=Input(type="string", default=config_dct['diarization']['event_type'], optional=True),
            max_num_speakers=Input(type="integer", default=config_dct['diarization']['max_num_speakers'], optional=True),
            min_window_length=Input(type="number", default=config_dct['diarization']['min_window_length'], optional=True),
            overlap_threshold=Input(type="number", default=config_dct['diarization']['overlap_threshold'], optional=True)
        ),
        outputs=dict(output_diar_path=Output(type=AssetTypes.URI_FOLDER)),
        # The ASR output folder is the input that gets split into mini-batches.
        input_data="${{inputs.input_asr_path}}",
        instance_count=config_dct['job']['instance_count_small'],
        max_concurrency_per_instance=config_dct['job']['max_concurrency_per_instance'],
        mini_batch_size=config_dct['job']['mini_batch_size'],
        mini_batch_error_threshold=config_dct['job']['mini_batch_error_threshold'],
        logging_level="DEBUG",
        error_threshold=config_dct['job']['error_threshold'],
        retry_settings=RetrySettings(
            max_retries=config_dct['job']['max_retries'],
            timeout=config_dct['job']['timeout']
        ),
        task=RunFunction(
            # NOTE(review): same literal ".../" prefix as the default config path - confirm.
            code=".../components/diar/src",
            entry_script="main.py",
            environment=ml_client.environments.get(name="diar_env", version="1"),
            # The "${{...}}" / "$[[...]]" pieces are plain strings left for AzureML
            # to resolve; only the trailing job settings are f-strings filled in
            # here from the configuration.
            program_arguments="--input_audio_path ${{inputs.input_audio_path}} "
                              "--input_asr_path ${{inputs.input_asr_path}} "
                              "--keyvault_name ${{inputs.keyvault_name}} "
                              "--secret_tenant_sp ${{inputs.secret_tenant_sp}} "
                              "--secret_client_sp ${{inputs.secret_client_sp}} "
                              "--secret_sp ${{inputs.secret_sp}} "
                              "--pk_secret ${{inputs.pk_secret}} "
                              "--pk_pass_secret ${{inputs.pk_pass_secret}} "
                              "--pubk_secret ${{inputs.pubk_secret}} "
                              "$[[--event_type ${{inputs.event_type}}]] "
                              "$[[--max_num_speakers ${{inputs.max_num_speakers}}]] "
                              "$[[--min_window_length ${{inputs.min_window_length}}]] "
                              "$[[--overlap_threshold ${{inputs.overlap_threshold}}]] "
                              "--output_diar_path ${{outputs.output_diar_path}} "
                              f"--allowed_failed_percent {config_dct['job']['allowed_failed_percent']} "
                              f"--task_overhead_timeout {config_dct['job']['task_overhead_timeout']} "
                              f"--first_task_creation_timeout {config_dct['job']['first_task_creation_timeout']} "
                              f"--resource_monitor_interval {config_dct['job']['resource_monitor_interval']} ",
            # All values output by run() method invocations will be aggregated into one unique file which is created in the output location.
            # If it is not set, 'summary_only' would be invoked, which means the user script is expected to store the output itself.
            #append_row_to="${{outputs.output_path}}"
        ),
    )

    #
    # Create pipeline
    #

    @pipeline(default_compute=config_dct['aml']['computing']['gpu_cluster_a100'])
    def diar(
        input_dts:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        input_asr:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type="string"),
        aml_compute:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        event_type:Input(type="string", default=config_dct['diarization']['event_type'], optional=True),
        max_num_speakers:Input(type="integer", default=config_dct['diarization']['max_num_speakers'], optional=True),
        min_window_length:Input(type="number", default=config_dct['diarization']['min_window_length'], optional=True),
        overlap_threshold:Input(type="number", default=config_dct['diarization']['overlap_threshold'], optional=True)
    ):
        # Single-node pipeline: forwards its inputs to the parallel component,
        # writes the node output to `output_dts` and runs on `aml_compute`.
        diar_node = diar_component(
            input_audio_path = input_dts,
            input_asr_path = input_asr,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            event_type = event_type,
            max_num_speakers = max_num_speakers,
            min_window_length = min_window_length,
            overlap_threshold = overlap_threshold
        )
        diar_node.outputs.output_diar_path = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        diar_node.compute = aml_compute

        return {'output_dts': diar_node.outputs.output_diar_path}

    # Create a pipeline
    pipeline_job = diar()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
def main(
    config_path='.../config/parallel_job.yaml'
):
    """Register the parallel merge-and-alignment pipeline component in AzureML.

    Authenticates, reads the YAML configuration, declares a parallel component
    (``pMA``) that runs ``main.py`` from the merge/align source folder, wraps it
    in a single-node pipeline named ``merge_align`` and registers that
    pipeline's component as version ``"1"``.

    Args:
        config_path: Path to the YAML file. The code reads the sections
            ``aml``, ``align`` and ``job``.
            NOTE(review): the default starts with ``.../``, which names a
            directory literally called ``...`` - confirm the intended path.

    Returns:
        ``None`` on success. If the credential token cannot be obtained, a
        ``(message, 500)`` tuple describing the failure.
    """
    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Python 3 exceptions have no `.message` attribute; formatting the
        # exception itself avoids an AttributeError inside this handler.
        error_msg = f"Something went wrong regarding authentication. Returned error is: {ex}"
        log.error(error_msg)
        return (error_msg, 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path, encoding="utf-8") as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)

    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )

    #
    # Declare Parallel task to perform merge and alignment
    # For detailed info, check: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-parallel-job-in-pipeline?view=azureml-api-2&tabs=python
    #
    ma_component = parallel_run_function(
        name="pMA",
        display_name="Parallel merge & alignment",
        description="Parallel component to align transcriptions and diarization",
        inputs=dict(
            # Descriptions corrected: the previous texts were copied from the
            # diarization component and did not describe these two inputs.
            input_asr_path=Input(type=AssetTypes.URI_FOLDER, description="Transcriptions to be merged and aligned"),
            input_diar_path=Input(type=AssetTypes.URI_FOLDER, description="Diarization results of those audios"),
            keyvault_name=Input(type="string"),
            secret_tenant_sp=Input(type="string"),
            secret_client_sp=Input(type="string"),
            secret_sp=Input(type="string"),
            pk_secret=Input(type="string"),
            pk_pass_secret=Input(type="string"),
            pubk_secret=Input(type="string"),
            cosmosdb_name=Input(type="string"),
            cosmosdb_collection=Input(type="string"),
            cosmosdb_cs_secret=Input(type="string"),
            ner_chunk_size=Input(type="integer", default=config_dct['align']['ner_chunk_size'], optional=True),
            ner_stride=Input(type="integer", default=config_dct['align']['ner_stride'], optional=True),
            max_words_in_sentence=Input(type="integer", default=config_dct['align']['max_words_in_sentence'], optional=True)
        ),
        outputs=dict(output_sm_path=Output(type=AssetTypes.URI_FOLDER)),
        # The ASR output folder is the input that gets split into mini-batches.
        input_data="${{inputs.input_asr_path}}",
        instance_count=config_dct['job']['instance_count_large'],
        max_concurrency_per_instance=config_dct['job']['max_concurrency_per_instance'],
        mini_batch_size=config_dct['job']['mini_batch_size'],
        mini_batch_error_threshold=config_dct['job']['mini_batch_error_threshold'],
        logging_level="DEBUG",
        error_threshold=config_dct['job']['error_threshold'],
        retry_settings=RetrySettings(
            max_retries=config_dct['job']['max_retries'],
            timeout=config_dct['job']['timeout']
        ),
        task=RunFunction(
            # NOTE(review): same literal ".../" prefix as the default config path - confirm.
            code=".../components/merge_align/src",
            entry_script="main.py",
            environment=ml_client.environments.get(name="merge_align_env", version="1"),
            # The "${{...}}" / "$[[...]]" pieces are plain strings left for AzureML
            # to resolve; only the trailing job settings are f-strings filled in
            # here from the configuration.
            program_arguments="--input_asr_path ${{inputs.input_asr_path}} "
                              "--input_diar_path ${{inputs.input_diar_path}} "
                              "--keyvault_name ${{inputs.keyvault_name}} "
                              "--secret_tenant_sp ${{inputs.secret_tenant_sp}} "
                              "--secret_client_sp ${{inputs.secret_client_sp}} "
                              "--secret_sp ${{inputs.secret_sp}} "
                              "--pk_secret ${{inputs.pk_secret}} "
                              "--pk_pass_secret ${{inputs.pk_pass_secret}} "
                              "--pubk_secret ${{inputs.pubk_secret}} "
                              "--cosmosdb_name ${{inputs.cosmosdb_name}} "
                              "--cosmosdb_collection ${{inputs.cosmosdb_collection}} "
                              "--cosmosdb_cs_secret ${{inputs.cosmosdb_cs_secret}} "
                              "$[[--ner_chunk_size ${{inputs.ner_chunk_size}}]] "
                              "$[[--ner_stride ${{inputs.ner_stride}}]] "
                              "$[[--max_words_in_sentence ${{inputs.max_words_in_sentence}}]] "
                              "--output_sm_path ${{outputs.output_sm_path}} "
                              f"--allowed_failed_percent {config_dct['job']['allowed_failed_percent']} "
                              f"--task_overhead_timeout {config_dct['job']['task_overhead_timeout']} "
                              f"--first_task_creation_timeout {config_dct['job']['first_task_creation_timeout']} "
                              f"--resource_monitor_interval {config_dct['job']['resource_monitor_interval']} ",
            # All values output by run() method invocations will be aggregated into one unique file which is created in the output location.
            # If it is not set, 'summary_only' would be invoked, which means the user script is expected to store the output itself.
            #append_row_to="${{outputs.output_path}}"
        ),
    )

    #
    # Create pipeline
    #

    @pipeline(default_compute=config_dct['aml']['computing']['gpu_cluster_t4'])
    def merge_align(
        input_asr:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        input_diar:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type="string"),
        aml_compute:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        cosmosdb_name:Input(type="string"),
        cosmosdb_collection:Input(type="string"),
        cosmosdb_cs_secret:Input(type="string"),
        ner_chunk_size:Input(type="integer", default=config_dct['align']['ner_chunk_size'], optional=True),
        ner_stride:Input(type="integer", default=config_dct['align']['ner_stride'], optional=True),
        max_words_in_sentence:Input(type="integer", default=config_dct['align']['max_words_in_sentence'], optional=True)
    ):
        # Single-node pipeline: forwards its inputs to the parallel component,
        # writes the node output to `output_dts` and runs on `aml_compute`.
        ma_node = ma_component(
            input_asr_path = input_asr,
            input_diar_path = input_diar,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            cosmosdb_name=cosmosdb_name,
            cosmosdb_collection=cosmosdb_collection,
            cosmosdb_cs_secret=cosmosdb_cs_secret,
            ner_chunk_size = ner_chunk_size,
            ner_stride = ner_stride,
            max_words_in_sentence = max_words_in_sentence
        )
        ma_node.outputs.output_sm_path = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        ma_node.compute = aml_compute

        return {'output_dts': ma_node.outputs.output_sm_path}

    # Create a pipeline
    pipeline_job = merge_align()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
We will explore the underlying technologies, including neural networks and their variants, and provide code examples and tutorials to help developers and researchers get started with building their own speech recognition and speaker diarization systems by making minimal changes to the Azure Pipelines implementation provided. 20 | 21 | 22 | ## Description 23 | 24 | Speech recognition is the task of automatically transcribing spoken language into text. It involves developing algorithms and models that can analyze audio recordings and identify the words and phrases spoken by a user. In recent years, deep learning models have shown great success in improving speech recognition accuracy, making it a hot topic in the field of machine learning. 25 | 26 | Autoregressive models such as [Whisper](https://openai.com/research/whisper) provide exceptional transcriptions when combined with some additional preprocessing features, which we have picked up from [this excellent repo](https://github.com/guillaumekln/faster-whisper), although the quality of the timestamps returned for each audio segment is rather poor. In this direction, phoneme-based speech recognition tools like [wav2vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/) handle timestamps perfectly, as these are finetuned to recognise the smallest unit of speech distinguishing one word from another. 27 | 28 | A technique that makes both ends meet is [forced alignment](https://linguistics.berkeley.edu/plab/guestwiki/index.php?title=Forced_alignment#:~:text=Forced%20alignment%20refers%20to%20the,automatically%20generate%20phone%20level%20segmentation.); a good introduction to this topic can be found [here](https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html), and our implementation relies on the [NeMo repo](https://nvidia.github.io/NeMo/blogs/2023/2023-08-nfa/). 29 | 30 | 31 | 32 | Alignment can only be performed, however, when the vocabularies of both models match. 
To overcome this issue, more recent approaches apply algorithms such as Dynamic Time Warping (DTW) to the cross-attention weights, working directly on Whisper features to enhance word-level timestamps. 33 | 34 | 35 | 36 | Speaker diarization, on the other hand, is the process of separating multiple speakers in an audio recording and assigning each speaker to their respective segments. It involves analyzing the audio signal to identify the unique characteristics of each speaker, such as their voice, intonation, and speaking style. Speaker diarization is essential in applications such as call center analytics, meeting transcription, and language learning, where it is necessary to distinguish between different speakers in a conversation. 37 | 38 | In this direction, [Multi-scale systems](https://developer.nvidia.com/blog/dynamic-scale-weighting-through-multiscale-speaker-diarization/) have emerged as a feasible solution to overcome traditional problems attached to time window selection. In our work, we make use of previous steps to enhance and optimise diarization runtime by providing VAD segments and accurate, word-level timestamps, which is particularly relevant in long audios. 39 | 40 | Finally, in order to enhance the quality and readability of natural language transcriptions, a punctuation-based sentence alignment strategy has been implemented after both ASR and diarization steps. 41 | 42 | 43 | 44 | When it comes to scalability, parallelisation plays a pivotal role. In this direction, we adopted the `parallel_job` solution included in the *AML* toolkit, which allows you to distribute your input across a defined number of devices asynchronously to speed up inference. Notice that there is not an immediate extrapolation of the code from standard pipelines to parallel components. 45 | 46 | 47 | 48 | 49 | 50 | ## Structure 51 | 52 | This service's main cornerstones are scalability, robustness and ease of deployment. 
In this direction, an API interface is provided to easily request batch processing jobs. While most of the parameters have default options, configuration related to storage paths, noSQL database credentials and secrets is required. 53 | 54 | 55 | ## Setup 56 | 57 | ## Storage 58 | Input and output containers must be defined as AzureML Datastores. The reason is that we manage intermediate data so as not to generate an excessive amount of residual files, which would lead to greater costs; this is particularly relevant because processed audios are among those files. It can be achieved by following [these steps](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-datastore?view=azureml-api-2&tabs=sdk-identity-based-access%2Csdk-adls-identity-access%2Csdk-azfiles-accountkey%2Csdk-adlsgen1-identity-access%2Csdk-onelake-identity-access). 59 | 60 | ## Tracking 61 | Ideally, input and output blob paths inside those containers should vary, or processed data should be moved/deleted after each job. If this task is not handled, it will not raise any errors or cause duplicate processing, as we register every `unique_id` in a cosmosDB database, but potentially many unnecessary requests to that cosmosDB database and job inputs to each component will be ingested, leading to a suboptimal performance of your services. 62 | 63 | ## Keyvaults 64 | An extensive usage of Azure Keyvault resource is made throughout the process. To be more precise, an asymmetric encryption protocol (PGP) is used to ensure that anyone can encrypt data, but a limited number of profiles can decrypt it: 65 | 66 | * `pubk_secret`: Holds the secret to the public key. 67 | * `pk_secret`: Holds the secret to the private key. 68 | * `pk_pass_secret`: Holds the secret to the private key's password. 69 | 70 | These last two, however, are disabled; i.e., they must be enabled beforehand to access the secret.
To that end, the credentials of a service principal account are also stored as secrets in the same keyvault, including (respectively) tenant, client and passwords under the identifiers: 71 | 72 | * `secret_tenant_sp` 73 | * `secret_client_sp` 74 | * `secret_sp` 75 | 76 | ## IAM 77 | AML computing clusters, together with the AzureML endpoint, will use a service account to which we must assign a series of roles in order to execute these processes successfully: 78 | * Storage Blob Data Contributor (in storage account resource) 79 | * Storage Queue Data Contributor (in storage account resource) 80 | * AzureML Data Scientist (in AML resource) 81 | * Access to [KeyVaults](https://learn.microsoft.com/en-us/azure/key-vault/general/assign-access-policy?tabs=azure-portal) 82 | * Read and write (documents) roles in cosmosDB resource. Database and collection creation is not necessary (see `Setup`). 83 | 84 | 85 | ## Quickstart 86 | 87 | Once the environment has been created, permissions for service account have been granted and you filled the configuration file with your own data, the fastest way to run AML pipelines is by opening a terminal and running the provided script to start an AzureML job: 88 | 89 | ``` 90 | cd stt_aml_deploy 91 | python online_endpoint.py --config_path ./config/online_endpoint.yaml 92 | ``` 93 | 94 | 95 | ## Call for contributions 96 | 97 | Despite including an end-to-end solution to model design in AML, the following additional features are expected to be developed: 98 | 99 | - [X] Speed up diarization step by using aligned ASR output. 100 | - [X] Include CTranslate2 engine in ASR components. 101 | - [X] Improve preprocessing techniques in an individual component to enhance stability. 102 | - [X] Parallelise processing using distributed, asynchronous clusters. 103 | - [X] Serialise pipeline implementation to avoid Microsoft bugs on `parallel_run_function`. 104 | - [X] End-to-end, monolithic implementation using keyvaults and security protocols.
# Main method. Fire automatically maps console arguments onto the parameters of this function
def main(
    config_path='.../config/parallel_job.yaml'
):
    """Register the parallel audio-preprocessing pipeline as an AzureML component.

    The function authenticates with `DefaultAzureCredential`, reads the YAML
    configuration at `config_path`, declares the `pPrep` parallel component,
    wraps it in the `prep` pipeline and registers that pipeline's component
    under version "1".

    Args:
        config_path: Path to the YAML configuration file. The keys read here are
            `aml`, `preprocessing` and `job`.

    Returns:
        None on success. If authentication fails, a tuple
        `(error_message, 500)` is returned instead of raising.
    """
    # NOTE(review): the leading '.../' in the default `config_path` (and in the
    # `code=` path below) is not a parent-directory reference; it looks like it
    # was meant to be '../../' -- confirm against the intended working directory.

    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Plain Python exceptions have no `.message` attribute; reading it
        # unconditionally would raise AttributeError here and hide the real
        # authentication failure. Fall back to the string form when it is absent.
        error_detail = getattr(ex, "message", str(ex))
        log.error(f"Something went wrong regarding authentication. Returned error is: {error_detail}")
        return (f"Something went wrong regarding authentication. Returned error is: {error_detail}", 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path) as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)
    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )


    #
    # Declare Parallel task to perform preprocessing
    # For detailed info, check: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-parallel-job-in-pipeline?view=azureml-api-2&tabs=python
    #
    prep_component = parallel_run_function(
        name="pPrep",
        display_name="Parallel preprocessing",
        description="Parallel component to perform audio preprocessing",
        inputs=dict(
            input_path=Input(type=AssetTypes.URI_FOLDER, description="Audios to be preprocessed"),
            keyvault_name=Input(type="string"),
            secret_tenant_sp=Input(type="string"),
            secret_client_sp=Input(type="string"),
            secret_sp=Input(type="string"),
            pk_secret=Input(type="string"),
            pk_pass_secret=Input(type="string"),
            pubk_secret=Input(type="string"),
            cosmosdb_name=Input(type="string"),
            cosmosdb_collection=Input(type="string"),
            cosmosdb_cs_secret=Input(type="string"),
            vad_threshold=Input(type="number", default=config_dct['preprocessing']['vad_threshold'], optional=True),
            min_speech_duration_ms=Input(type="integer", default=config_dct['preprocessing']['min_speech_duration_ms'], optional=True),
            min_silence_duration_ms=Input(type="integer", default=config_dct['preprocessing']['min_silence_duration_ms'], optional=True),
            demucs_model=Input(type="string", default=config_dct['preprocessing']['demucs_model'], optional=True)
        ),
        outputs=dict(
            output_prep_path=Output(type=AssetTypes.URI_FOLDER)
        ),
        input_data="${{inputs.input_path}}",
        instance_count=config_dct['job']['instance_count_large'],
        max_concurrency_per_instance=config_dct['job']['max_concurrency_per_instance'],
        mini_batch_size=config_dct['job']['mini_batch_size'],
        mini_batch_error_threshold=config_dct['job']['mini_batch_error_threshold'],
        logging_level="DEBUG",
        error_threshold=config_dct['job']['error_threshold'],
        retry_settings=RetrySettings(
            max_retries=config_dct['job']['max_retries'],
            timeout=config_dct['job']['timeout']
        ),
        task=RunFunction(
            code=".../components/prep/src",
            entry_script="main.py",
            environment=ml_client.environments.get(name="prep_env", version="1"),
            # Optional inputs are wrapped in $[[...]] so that the flag is only
            # emitted when a value is supplied.
            program_arguments="--input_path ${{inputs.input_path}} "
                              "--keyvault_name ${{inputs.keyvault_name}} "
                              "--secret_tenant_sp ${{inputs.secret_tenant_sp}} "
                              "--secret_client_sp ${{inputs.secret_client_sp}} "
                              "--secret_sp ${{inputs.secret_sp}} "
                              "--pk_secret ${{inputs.pk_secret}} "
                              "--pk_pass_secret ${{inputs.pk_pass_secret}} "
                              "--pubk_secret ${{inputs.pubk_secret}} "
                              "--cosmosdb_name ${{inputs.cosmosdb_name}} "
                              "--cosmosdb_collection ${{inputs.cosmosdb_collection}} "
                              "--cosmosdb_cs_secret ${{inputs.cosmosdb_cs_secret}} "
                              "$[[--vad_threshold ${{inputs.vad_threshold}}]] "
                              "$[[--min_speech_duration_ms ${{inputs.min_speech_duration_ms}}]] "
                              "$[[--min_silence_duration_ms ${{inputs.min_silence_duration_ms}}]] "
                              "$[[--demucs_model ${{inputs.demucs_model}}]] "
                              "--output_prep_path ${{outputs.output_prep_path}} "
                              f"--allowed_failed_percent {config_dct['job']['allowed_failed_percent']} "
                              f"--task_overhead_timeout {config_dct['job']['task_overhead_timeout']} "
                              f"--first_task_creation_timeout {config_dct['job']['first_task_creation_timeout']} "
                              f"--resource_monitor_interval {config_dct['job']['resource_monitor_interval']} ",
            # All values output by run() method invocations will be aggregated into one unique file which is created in the output location.
            # If it is not set, 'summary_only' would be invoked, which means the user script is expected to store the output itself.
            #append_row_to="${{outputs.output_path}}"
        ),
    )

    #
    # Create pipeline
    #


    @pipeline(default_compute=config_dct['aml']['computing']['gpu_cluster_t4'])
    def prep(
        input_dts:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type="string"),
        aml_compute:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        cosmosdb_name:Input(type="string"),
        cosmosdb_collection:Input(type="string"),
        cosmosdb_cs_secret:Input(type="string"),
        vad_threshold:Input(type="number", default=config_dct['preprocessing']['vad_threshold'], optional=True),
        min_speech_duration_ms:Input(type="integer", default=config_dct['preprocessing']['min_speech_duration_ms'], optional=True),
        min_silence_duration_ms:Input(type="integer", default=config_dct['preprocessing']['min_silence_duration_ms'], optional=True),
        demucs_model:Input(type="string", default=config_dct['preprocessing']['demucs_model'], optional=True)
    ):

        # Preprocessing
        prep_node = prep_component(
            input_path=input_dts,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            cosmosdb_name=cosmosdb_name,
            cosmosdb_collection=cosmosdb_collection,
            cosmosdb_cs_secret=cosmosdb_cs_secret,
            vad_threshold=vad_threshold,
            min_speech_duration_ms=min_speech_duration_ms,
            min_silence_duration_ms=min_silence_duration_ms,
            demucs_model=demucs_model
        )
        # Redirect the node output to the caller-provided location and run the
        # node on the caller-provided compute target.
        prep_node.outputs.output_prep_path = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        prep_node.compute = aml_compute

        return {'output_dts': prep_node.outputs.output_prep_path}

    # Create a pipeline
    pipeline_job = prep()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
# Fire automatically maps console arguments onto the parameters of this function
def main(
    config_path='.../config/parallel_job.yaml'
):
    """Register the parallel ASR pipeline as an AzureML component.

    The function authenticates with `DefaultAzureCredential`, reads the YAML
    configuration at `config_path`, declares the `pASR` parallel component,
    wraps it in the `asr` pipeline and registers that pipeline's component
    under version "1".

    Args:
        config_path: Path to the YAML configuration file. The keys read here are
            `aml`, `asr` and `job`.

    Returns:
        None on success. If authentication fails, a tuple
        `(error_message, 500)` is returned instead of raising.
    """
    # NOTE(review): the leading '.../' in the default `config_path` (and in the
    # `code=` path below) is not a parent-directory reference; it looks like it
    # was meant to be '../../' -- confirm against the intended working directory.

    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Plain Python exceptions have no `.message` attribute; reading it
        # unconditionally would raise AttributeError here and hide the real
        # authentication failure. Fall back to the string form when it is absent.
        error_detail = getattr(ex, "message", str(ex))
        log.error(f"Something went wrong regarding authentication. Returned error is: {error_detail}")
        return (f"Something went wrong regarding authentication. Returned error is: {error_detail}", 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path) as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)
    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )

    #
    # Declare Parallel task to perform ASR
    # For detailed info, check: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-parallel-job-in-pipeline?view=azureml-api-2&tabs=python
    #
    asr_component = parallel_run_function(
        name="pASR",
        display_name="Parallel ASR",
        description="Parallel component to perform ASR on a large amount of audios",
        inputs=dict(
            input_path=Input(type=AssetTypes.URI_FOLDER, description="Audios to be transcribed and metadata attached to those audios."),
            keyvault_name=Input(type="string"),
            secret_tenant_sp=Input(type="string"),
            secret_client_sp=Input(type="string"),
            secret_sp=Input(type="string"),
            pk_secret=Input(type="string"),
            pk_pass_secret=Input(type="string"),
            pubk_secret=Input(type="string"),
            whisper_model_name=Input(type="string", default=config_dct['asr']['model_name'], optional=True),
            num_workers=Input(type="integer", default=config_dct['asr']['num_workers'], optional=True),
            beam_size=Input(type="integer", default=config_dct['asr']['beam_size'], optional=True),
            word_level_timestamps=Input(type="boolean", default=config_dct['asr']['word_level_timestamps'], optional=True),
            condition_on_previous_text=Input(type="boolean", default=config_dct['asr']['condition_on_previous_text'], optional=True),
            compute_type=Input(type="string", default=config_dct['asr']['compute_type'], optional=True),
            language_code=Input(type="string", default=config_dct['asr']['language_code'], optional=True)
        ),
        outputs=dict(output_asr_path=Output(type=AssetTypes.URI_FOLDER)),
        input_data="${{inputs.input_path}}",
        instance_count=config_dct['job']['instance_count_large'],
        max_concurrency_per_instance=config_dct['job']['max_concurrency_per_instance'],
        mini_batch_size=config_dct['job']['mini_batch_size'],
        mini_batch_error_threshold=config_dct['job']['mini_batch_error_threshold'],
        logging_level="DEBUG",
        error_threshold=config_dct['job']['error_threshold'],
        retry_settings=RetrySettings(
            max_retries=config_dct['job']['max_retries'],
            timeout=config_dct['job']['timeout']
        ),
        task=RunFunction(
            code=".../components/asr/src",
            entry_script="main.py",
            environment=ml_client.environments.get(name="asr_env", version="1"),
            # Optional inputs are wrapped in $[[...]] so that the flag is only
            # emitted when a value is supplied.
            program_arguments="--input_path ${{inputs.input_path}} "
                              "--keyvault_name ${{inputs.keyvault_name}} "
                              "--secret_tenant_sp ${{inputs.secret_tenant_sp}} "
                              "--secret_client_sp ${{inputs.secret_client_sp}} "
                              "--secret_sp ${{inputs.secret_sp}} "
                              "--pk_secret ${{inputs.pk_secret}} "
                              "--pk_pass_secret ${{inputs.pk_pass_secret}} "
                              "--pubk_secret ${{inputs.pubk_secret}} "
                              "$[[--whisper_model_name ${{inputs.whisper_model_name}}]] "
                              "$[[--num_workers ${{inputs.num_workers}}]] "
                              "$[[--beam_size ${{inputs.beam_size}}]] "
                              "$[[--word_level_timestamps ${{inputs.word_level_timestamps}}]] "
                              "$[[--condition_on_previous_text ${{inputs.condition_on_previous_text}}]] "
                              "$[[--compute_type ${{inputs.compute_type}}]] "
                              "$[[--language_code ${{inputs.language_code}}]] "
                              "--output_asr_path ${{outputs.output_asr_path}} "
                              f"--allowed_failed_percent {config_dct['job']['allowed_failed_percent']} "
                              f"--task_overhead_timeout {config_dct['job']['task_overhead_timeout']} "
                              f"--first_task_creation_timeout {config_dct['job']['first_task_creation_timeout']} "
                              f"--resource_monitor_interval {config_dct['job']['resource_monitor_interval']} ",
            # All values output by run() method invocations will be aggregated into one unique file which is created in the output location.
            # If it is not set, 'summary_only' would be invoked, which means the user script is expected to store the output itself.
            #append_row_to="${{outputs.output_path}}"
        ),
    )

    #
    # Create pipeline
    #

    @pipeline(default_compute=config_dct['aml']['computing']['gpu_cluster_t4'])
    def asr(
        input_dts:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type="string"),
        aml_compute:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        model_name:Input(type="string", default=config_dct['asr']['model_name'], optional=True),
        num_workers:Input(type="integer", default=config_dct['asr']['num_workers'], optional=True),
        beam_size:Input(type="integer", default=config_dct['asr']['beam_size'], optional=True),
        word_level_timestamps:Input(type="boolean", default=config_dct['asr']['word_level_timestamps'], optional=True),
        condition_on_previous_text:Input(type="boolean", default=config_dct['asr']['condition_on_previous_text'], optional=True),
        compute_type:Input(type="string", default=config_dct['asr']['compute_type'], optional=True),
        language_code:Input(type="string", default=config_dct['asr']['language_code'], optional=True),
    ):

        # The pipeline-level `model_name` input feeds the component-level
        # `whisper_model_name` input.
        asr_node = asr_component(
            input_path = input_dts,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            whisper_model_name = model_name,
            num_workers = num_workers,
            beam_size = beam_size,
            word_level_timestamps = word_level_timestamps,
            condition_on_previous_text = condition_on_previous_text,
            compute_type = compute_type,
            language_code = language_code
        )
        # Redirect the node output to the caller-provided location and run the
        # node on the caller-provided compute target.
        asr_node.outputs.output_asr_path = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        asr_node.compute = aml_compute

        return {'output_dts': asr_node.outputs.output_asr_path}


    # Create a pipeline
    pipeline_job = asr()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
# Main method. Fire automatically maps console arguments onto the parameters of this function
def main(
    config_path='../config/parallel_job.yaml'
):
    """Register the end-to-end STT pipeline as an AzureML component.

    The function authenticates with `DefaultAzureCredential`, reads the YAML
    configuration at `config_path`, loads version "1" of the `prep`, `asr`,
    `nfa`, `diar`, `merge_align` and `remove_stt_data` components from the
    workspace, chains them inside the `stt` pipeline and registers that
    pipeline's component under version "1".

    Args:
        config_path: Path to the YAML configuration file. The keys read here are
            `aml`, `preprocessing`, `asr`, `fa`, `diarization` and `align`.

    Returns:
        None on success. If authentication fails, a tuple
        `(error_message, 500)` is returned instead of raising.
    """

    # Get credential token
    log.info("Get credential token:")
    try:
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Plain Python exceptions have no `.message` attribute; reading it
        # unconditionally would raise AttributeError here and hide the real
        # authentication failure. Fall back to the string form when it is absent.
        error_detail = getattr(ex, "message", str(ex))
        log.error(f"Something went wrong regarding authentication. Returned error is: {error_detail}")
        return (f"Something went wrong regarding authentication. Returned error is: {error_detail}", 500)

    # Fetch configuration file
    log.info("Fetch configuration file:")
    with open(config_path) as file:
        config_dct = yaml.load(file, Loader=yaml.FullLoader)
    # Get a handle to workspace
    log.info("Set up ML Client:")
    ml_client = MLClient(
        credential=credential,
        subscription_id=config_dct['aml']['subscription_id'],
        resource_group_name=config_dct['aml']['resource_group'],
        workspace_name=config_dct['aml']['workspace_name'],
    )


    # Fetch components
    prep_comp = load_component(client=ml_client, name="prep", version="1")
    asr_comp = load_component(client=ml_client, name="asr", version="1")
    nfa_comp = load_component(client=ml_client, name="nfa", version="1")
    diar_comp = load_component(client=ml_client, name="diar", version="1")
    ma_comp = load_component(client=ml_client, name="merge_align", version="1")
    rsttd_comp = load_component(client=ml_client, name="remove_stt_data", version="1")

    #
    # Create pipeline
    #


    @pipeline()
    def stt(
        input_dts:Input(type=AssetTypes.URI_FOLDER, mode=InputOutputModes.RO_MOUNT),
        output_dts:Input(type='string'),
        storage_account_name:Input(type="string"),
        container_name:Input(type="string"),
        blob_filepath:Input(type="string"),
        aml_cpu_cluster:Input(type="string"),
        aml_t4_cluster:Input(type="string"),
        aml_a100_cluster:Input(type="string"),
        keyvault_name:Input(type="string"),
        secret_tenant_sp:Input(type="string"),
        secret_client_sp:Input(type="string"),
        secret_sp:Input(type="string"),
        pk_secret:Input(type="string"),
        pk_pass_secret:Input(type="string"),
        pubk_secret:Input(type="string"),
        cosmosdb_name:Input(type="string"),
        cosmosdb_collection:Input(type="string"),
        cosmosdb_cs_secret:Input(type="string"),
        vad_threshold:Input(type="number", default=config_dct['preprocessing']['vad_threshold'], optional=True),
        min_speech_duration_ms:Input(type="integer", default=config_dct['preprocessing']['min_speech_duration_ms'], optional=True),
        min_silence_duration_ms:Input(type="integer", default=config_dct['preprocessing']['min_silence_duration_ms'], optional=True),
        demucs_model:Input(type="string", default=config_dct['preprocessing']['demucs_model'], optional=True),
        asr_model_name:Input(type="string", default=config_dct['asr']['model_name'], optional=True),
        asr_num_workers:Input(type="integer", default=config_dct['asr']['num_workers'], optional=True),
        asr_beam_size:Input(type="integer", default=config_dct['asr']['beam_size'], optional=True),
        word_level_timestamps:Input(type="boolean", default=config_dct['asr']['word_level_timestamps'], optional=True),
        condition_on_previous_text:Input(type="boolean", default=config_dct['asr']['condition_on_previous_text'], optional=True),
        asr_compute_type:Input(type="string", default=config_dct['asr']['compute_type'], optional=True),
        asr_language_code:Input(type="string", default=config_dct['asr']['language_code'], optional=True),
        nfa_model_name:Input(type="string", default=config_dct['fa']['model_name'], optional=True),
        nfa_batch_size:Input(type="integer", default=config_dct['fa']['batch_size'], optional=True),
        diar_event_type:Input(type="string", default=config_dct['diarization']['event_type'], optional=True),
        diar_max_num_speakers:Input(type="integer", default=config_dct['diarization']['max_num_speakers'], optional=True),
        diar_min_window_length:Input(type="number", default=config_dct['diarization']['min_window_length'], optional=True),
        diar_overlap_threshold:Input(type="number", default=config_dct['diarization']['overlap_threshold'], optional=True),
        ma_ner_chunk_size:Input(type="integer", default=config_dct['align']['ner_chunk_size'], optional=True),
        ma_ner_stride:Input(type="integer", default=config_dct['align']['ner_stride'], optional=True),
        ma_max_words_in_sentence:Input(type="integer", default=config_dct['align']['max_words_in_sentence'], optional=True)
    ):

        # Preprocessing
        prep_node = prep_comp(
            input_dts=input_dts,
            output_dts=output_dts,
            aml_compute=aml_t4_cluster,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            cosmosdb_name=cosmosdb_name,
            cosmosdb_collection=cosmosdb_collection,
            cosmosdb_cs_secret=cosmosdb_cs_secret,
            vad_threshold=vad_threshold,
            min_speech_duration_ms=min_speech_duration_ms,
            min_silence_duration_ms=min_silence_duration_ms,
            demucs_model=demucs_model
        )
        prep_node.outputs.output_dts = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        prep_node.compute = aml_t4_cluster

        # ASR (consumes the preprocessing output)
        asr_node = asr_comp(
            input_dts=prep_node.outputs.output_dts,
            output_dts=output_dts,
            aml_compute=aml_t4_cluster,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            model_name=asr_model_name,
            num_workers=asr_num_workers,
            beam_size=asr_beam_size,
            word_level_timestamps=word_level_timestamps,
            condition_on_previous_text=condition_on_previous_text,
            compute_type=asr_compute_type,
            language_code=asr_language_code
        )
        asr_node.outputs.output_dts = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        asr_node.compute = aml_t4_cluster

        # NFA (consumes the preprocessing and ASR outputs)
        nfa_node = nfa_comp(
            input_dts=prep_node.outputs.output_dts,
            input_asr=asr_node.outputs.output_dts,
            output_dts=output_dts,
            aml_compute=aml_t4_cluster,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            model_name=nfa_model_name,
            batch_size=nfa_batch_size
        )
        nfa_node.outputs.output_dts = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        nfa_node.compute = aml_t4_cluster

        # Diarization (consumes the preprocessing and NFA outputs; runs on the A100 cluster)
        diar_node = diar_comp(
            input_dts=prep_node.outputs.output_dts,
            input_asr=nfa_node.outputs.output_dts,
            output_dts=output_dts,
            aml_compute=aml_a100_cluster,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            event_type=diar_event_type,
            max_num_speakers=diar_max_num_speakers,
            min_window_length=diar_min_window_length,
            overlap_threshold=diar_overlap_threshold
        )
        diar_node.outputs.output_dts = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        diar_node.compute = aml_a100_cluster

        # Merge&Align (consumes the NFA and diarization outputs)
        ma_node = ma_comp(
            input_asr = nfa_node.outputs.output_dts,
            input_diar = diar_node.outputs.output_dts,
            output_dts=output_dts,
            aml_compute=aml_t4_cluster,
            keyvault_name=keyvault_name,
            secret_tenant_sp=secret_tenant_sp,
            secret_client_sp=secret_client_sp,
            secret_sp=secret_sp,
            pk_secret=pk_secret,
            pk_pass_secret=pk_pass_secret,
            pubk_secret=pubk_secret,
            cosmosdb_name=cosmosdb_name,
            cosmosdb_collection=cosmosdb_collection,
            cosmosdb_cs_secret=cosmosdb_cs_secret,
            ner_chunk_size=ma_ner_chunk_size,
            ner_stride=ma_ner_stride,
            max_words_in_sentence=ma_max_words_in_sentence
        )
        ma_node.outputs.output_dts = Output(
            path=output_dts,
            type=AssetTypes.URI_FOLDER,
            mode=InputOutputModes.RW_MOUNT
        )
        ma_node.compute = aml_t4_cluster

        # Remove STT data (consumes the Merge&Align output; runs on the CPU cluster)
        rsttd_node = rsttd_comp(
            input_path = ma_node.outputs.output_dts,
            storage_id = storage_account_name,
            container_name = container_name,
            blob_filepath = blob_filepath
        )
        rsttd_node.compute = aml_cpu_cluster


    # Create a pipeline
    pipeline_job = stt()

    # Component register
    ml_client.components.create_or_update(pipeline_job.component, version="1")
# Keyvault handling class
class CredentialManager:
    """Load a PGP key pair from Azure Key Vault and use it to encrypt/decrypt local files.

    The public key is fetched with the default Azure credential. The private key
    and its passphrase are fetched with a service-principal credential, after
    both secrets have been switched to enabled. All key material is cached on
    the instance, so `encrypt` and `decrypt` do not need Key Vault access.
    """

    def __init__(
        self,
        keyvault_name: str,
        secret_tenant_sp: Optional[str] = None,
        secret_client_sp: Optional[str] = None,
        secret_sp: Optional[str] = None,
        puk_secret_name: Optional[str] = None,
        prk_secret_name: Optional[str] = None,
        prk_password_secret_name: Optional[str] = None,
    ) -> None:
        """Base class to handle PGP encryption system in Azure.

        Args:
            keyvault_name (str): KeyVault resource where secrets are stored.
            secret_tenant_sp (str): Service principal tenant_id secret, stored in KeyVault 'keyvault_name'.
            secret_client_sp (str): Service principal client_id secret, stored in KeyVault 'keyvault_name'.
            secret_sp (str): Service principal secret_client secret, stored in KeyVault 'keyvault_name'.
            puk_secret_name (str): Secret holding the PGP public key.
            prk_secret_name (str): Secret holding the PGP private key.
            prk_password_secret_name (str): Secret holding the private key passphrase.

        Raises:
            ValueError: If any of the service principal secret names is missing.
        """
        self.keyvault_name = keyvault_name
        self.secret_tenant_sp = secret_tenant_sp
        self.secret_client_sp = secret_client_sp
        self.secret_sp = secret_sp
        self.login = None
        # Import public key from PGP
        puk_secret_value = self.fetch_secret(self.default_login(), puk_secret_name)
        self.public_key, _ = pgpy.PGPKey.from_blob(puk_secret_value)
        # Retrieve pk secrets
        sc = self.sp_login()
        self.enable_secret(sc, prk_secret_name, True)
        self.enable_secret(sc, prk_password_secret_name, True)
        pk_secret_value = self.fetch_secret(sc, prk_secret_name)
        self.pk_pass_secret_value = self.fetch_secret(sc, prk_password_secret_name)
        # Fetch pk key
        self.private_key, _ = pgpy.PGPKey.from_blob(pk_secret_value)

    def encrypt(
        self,
        input_path: str,
        output_path: str,
        filenames: List[str],
        remove_input: bool = False,
        secret_client: Optional["SecretClient"] = None,
    ) -> None:
        """Encrypt files with the cached public key, writing `<filename>.pgp` to `output_path`.

        Args:
            input_path (str): Folder containing the plain files.
            output_path (str): Folder where the `.pgp` files are written.
            filenames (List[str]): File names (a single string is also accepted).
            remove_input (bool): Delete each plain file once it has been encrypted.
            secret_client (SecretClient): Unused; kept so existing callers keep working.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # No Key Vault login here: the public key was cached by __init__.
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            output_filepath = os.path.join(output_path, filename) + '.pgp'
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext == '.pgp':
                log.warning(f"File {fn} is already encrypted. Skipping...")
                continue
            with open(input_filepath, 'rb') as f:
                message = pgpy.PGPMessage.new(f.read())
            encrypted_message = str(self.public_key.encrypt(message))
            with open(output_filepath, "w") as f:
                f.write(encrypted_message)
            log.info(f"File {fn+ext+'.pgp'} has been generated in {output_path}.")
            # Only report a removal that actually happened.
            if remove_input:
                os.remove(input_filepath)
                log.info(f"File {fn+ext} has been removed.")

    def decrypt(
        self,
        input_path: str,
        output_path: str,
        filenames: List[str],
        remove_input: bool = False,
        secret_client: Optional["SecretClient"] = None,
    ) -> None:
        """Decrypt `.pgp`/`.enc` files with the cached private key into `output_path`.

        The output file name is the input name without its last extension.

        Args:
            input_path (str): Folder containing the encrypted files.
            output_path (str): Folder where the decrypted files are written.
            filenames (List[str]): File names (a single string is also accepted).
            remove_input (bool): Delete each encrypted file once it has been decrypted.
            secret_client (SecretClient): Unused; kept so existing callers keep working.

        Raises:
            ValueError: If the private key cannot be unlocked, or the decrypted
                payload is neither text nor bytes.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # No Key Vault login here: private key and passphrase were cached by __init__.
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext not in ('.pgp', '.enc'):
                log.warning(f"File {fn} is already decrypted. Skipping...")
                continue
            with self.private_key.unlock(self.pk_pass_secret_value) as ukey:
                if not ukey:
                    log.error("Private key password is not correct.")
                    raise ValueError("Private key password is not correct.")
                encrypted_message = pgpy.PGPMessage.from_file(input_filepath)
                decrypted_message = ukey.decrypt(encrypted_message).message
            output_filepath = os.path.join(output_path, fn)
            if isinstance(decrypted_message, str):
                with open(output_filepath, "w", encoding="utf-8") as f:
                    f.write(decrypted_message)
            elif isinstance(decrypted_message, (bytes, bytearray)):
                with open(output_filepath, "wb") as f:
                    f.write(decrypted_message)
            else:
                log.error(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
                raise ValueError(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
            log.info(f"File {fn} has been generated in {output_path}.")
            # Only report a removal that actually happened.
            if remove_input:
                os.remove(input_filepath)
                log.info(f"File {fn+ext} has been removed.")

    def default_login(self) -> "SecretClient":
        """Return a Key Vault client built on `DefaultAzureCredential` and mark the login as 'default'."""
        credential = DefaultAzureCredential()
        # Request a token up front so credential problems surface here.
        credential.get_token("https://management.azure.com/.default")
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'default'
        return secret_client

    def sp_login(self) -> "SecretClient":
        """Return a Key Vault client authenticated as the service principal and mark the login as 'sp'.

        The service principal identifiers are themselves read from Key Vault
        using the default credential.

        Raises:
            ValueError: If any service principal secret name is missing.
        """
        # Make sure all parameters are in place
        if None in (self.secret_tenant_sp, self.secret_client_sp, self.secret_sp):
            log.error("Service principal credentials have not been set up properly.")
            raise ValueError("Service principal credentials have not been set up properly.")
        # Get secret client
        secret_client = self.default_login()
        tenant_id = secret_client.get_secret(name=self.secret_tenant_sp).value
        client_id = secret_client.get_secret(name=self.secret_client_sp).value
        client_secret = secret_client.get_secret(name=self.secret_sp).value
        credential = ClientSecretCredential(tenant_id, client_id, client_secret)
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'sp'
        return secret_client

    def enable_secret(
        self,
        secret_client: Optional["SecretClient"] = None,
        secret_name: Optional[str] = None,
        enable: bool = False,
    ) -> None:
        """Set the `enabled` flag of a secret, skipping the update when it already matches.

        Args:
            secret_client (SecretClient): Client to use; replaced by a fresh
                service-principal client when missing or when the last login
                was not a service-principal one.
            secret_name (str): Name of the secret to update.
            enable (bool): Desired state of the secret.
        """
        # Get the right login for the operation
        if secret_client is None or self.login != 'sp':
            secret_client = self.sp_login()
        # Check secret current status. A failed read is treated as "disabled";
        # catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        try:
            secret_status = secret_client.get_secret(secret_name).properties.enabled
        except Exception as err:
            log.debug(f"Secret {secret_name} could not be read ({err}); assuming it is disabled.")
            secret_status = False
        # Compare with input action
        s = 'enabled' if enable else 'disabled'
        if secret_status == enable:
            log.info(f"Secret {secret_name} is already {s}.")
        else:
            secret_client.update_secret_properties(secret_name, enabled=enable)
            log.info(f"Secret {secret_name} is now {s}.")

    def fetch_secret(
        self,
        secret_client: "SecretClient",
        secret_name: str,
    ) -> str:
        """Return the value of `secret_name` read through `secret_client`."""
        return secret_client.get_secret(secret_name).value

    @staticmethod
    def get_file_attr(
        filepath: str,
    ) -> Tuple[str, str, str]:
        """Helper function to consistently split a filepath into folder path, filename and extension.

        Args:
            filepath (str): '/'-separated path where file is stored.

        Returns:
            Tuple[str, str, str]: Folder path, file name and (last) file extension.
        """
        folder_path, _, basename = filepath.rpartition('/')
        fn, ext = os.path.splitext(basename)
        return folder_path, fn, ext
def _parse_bool(value):
    """Convert a command-line value to bool.

    `argparse` with `type=bool` maps every non-empty string, including
    "False", to True. This converter understands the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was with `type=bool`.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _tidy_punctuation(text):
    """Remove the space after opening marks (¿ ¡ ") and before closing marks (? . ! ")."""
    text = re.sub(r'([¿¡"])\s+', r'\1', text)
    return re.sub(r'\s([?.!"](?:\s|$))', r'\1', text)


def _collect_segments(segments, with_words):
    """Turn the segments yielded by the Whisper model into JSON-serialisable dictionaries."""
    if not with_words:
        return [{'start': x.start, 'end': x.end, 'text': _tidy_punctuation(x.text.strip())} for x in segments]
    segs = []
    for x in segments:
        if not x.words:
            continue  # So that global stats based on word ts are not messed up
        words = [
            {
                'start': word.start,
                'end': word.end,
                'text': _tidy_punctuation(word.word.strip()),
                'confidence': word.probability
            }
            for word in x.words
        ]
        segs.append(
            {
                'start': words[0]['start'],
                'end': words[-1]['end'],
                'text': ' '.join(w['text'] for w in words),
                'confidence': sum(w['confidence'] for w in words)/len(words),
                'words': words
            }
        )
    return segs


def _transcribe_file(input_path, fn):
    """Decrypt the inputs of audio `fn`, transcribe it and write the encrypted `<fn>_asr.json`."""
    # Fetch metadata
    log.info(f"Processing file {fn}:")
    cm.decrypt(input_path, './decrypted_files', f"{fn}_prep.json.pgp")
    with open(f'./decrypted_files/{fn}_prep.json', 'r', encoding='utf-8') as f:
        metadata_dct = json.load(f)

    # Ensure audio contains activity
    if len(metadata_dct['vad_timestamps'])==0:
        log.info(f"Audio {fn} does not contain any activity. Generating dummy metadata:")
        with open(f"./decrypted_files/{fn}_asr.json", 'w', encoding='utf8') as f:
            json.dump(
                {
                    'vad_timestamps': metadata_dct['vad_timestamps'], # List of dictionaries with keys 'start', 'end'
                    'segments': []
                },
                f,
                indent=4,
                ensure_ascii=False
            )
        cm.encrypt('./decrypted_files', output_asr_path, f"{fn}_asr.json", True)
        return

    # Decrypt (if needed)
    if os.path.isfile(f"{input_path}/{fn}.wav.pgp"):
        log.info("Decrypt:")
        cm.decrypt(input_path, './decrypted_files', f"{fn}.wav.pgp")
        filepath = f"./decrypted_files/{fn}.wav"
    elif os.path.isfile(f"{input_path}/{fn}.wav"):
        filepath = f"{input_path}/{fn}.wav"
    else:
        # Fail loudly instead of hitting a NameError or silently reusing the
        # previous element's audio path.
        raise FileNotFoundError(f"Neither {fn}.wav.pgp nor {fn}.wav was found in {input_path}.")

    # Transcription
    log.info("\tASR:")
    transcription_time = time.time()
    segments, _ = whisper_model.transcribe(
        filepath,
        beam_size=beam_size,
        language=language_code,
        condition_on_previous_text=condition_on_previous_text,
        vad_filter=False,
        word_timestamps=word_level_timestamps
    )
    # Iterating the segments is part of the timed section.
    segs = _collect_segments(segments, word_level_timestamps)
    transcription_time = time.time()-transcription_time
    log.info(f"\t\tTranscription time: {transcription_time}")

    # Build metadata
    mtd = {
        "transcription_time": transcription_time
    }
    # Save output
    with open(f'./decrypted_files/{fn}_asr.json', 'w', encoding='utf8') as f:
        json.dump(
            {
                'segments': segs,
                'duration': metadata_dct['duration'],
                'vad_timestamps': metadata_dct['vad_timestamps'], # List of dictionaries with keys 'start', 'end'
                'metadata': {**metadata_dct['metadata'], **mtd}
            },
            f,
            indent=4,
            ensure_ascii=False
        )
    cm.encrypt('./decrypted_files', output_asr_path, f"{fn}_asr.json", True)


def init():
    """Parse the component arguments and load the credential manager and the Whisper model.

    Everything `run` and `shutdown` need is published as module globals.
    """
    global device, cm, whisper_model
    global keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pk_secret, pk_pass_secret, pubk_secret
    global beam_size, word_level_timestamps, condition_on_previous_text, language_code, output_asr_path

    # Managed output path to control where objects are returned
    parser = argparse.ArgumentParser(
        allow_abbrev=False, description="ParallelRunStep Agent"
    )
    parser.add_argument("--keyvault_name", type=str)
    parser.add_argument("--secret_tenant_sp", type=str)
    parser.add_argument("--secret_client_sp", type=str)
    parser.add_argument("--secret_sp", type=str)
    parser.add_argument("--pk_secret", type=str)
    parser.add_argument("--pk_pass_secret", type=str)
    parser.add_argument("--pubk_secret", type=str)
    parser.add_argument("--whisper_model_name", type=str, default='large-v3')
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--beam_size", type=int, default=5)
    # `type=bool` would turn "False" into True; use an explicit converter.
    parser.add_argument("--word_level_timestamps", type=_parse_bool, default=True)
    parser.add_argument("--condition_on_previous_text", type=_parse_bool, default=True)
    parser.add_argument("--compute_type", type=str, default='float16')
    parser.add_argument("--language_code", type=str, default='es')
    parser.add_argument("--output_asr_path", type=str)
    args, _ = parser.parse_known_args()

    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Encrypt params
    keyvault_name = args.keyvault_name
    secret_tenant_sp = args.secret_tenant_sp
    secret_client_sp = args.secret_client_sp
    secret_sp = args.secret_sp
    pk_secret = args.pk_secret
    pk_pass_secret = args.pk_pass_secret
    pubk_secret = args.pubk_secret

    # Instantiate credential manager
    cm = CredentialManager(keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pubk_secret, pk_secret, pk_pass_secret)

    # ASR params
    beam_size = args.beam_size
    word_level_timestamps = args.word_level_timestamps
    condition_on_previous_text = args.condition_on_previous_text
    language_code = args.language_code
    output_asr_path = args.output_asr_path

    # Folder structure
    Path('./decrypted_files').mkdir(parents=True, exist_ok=True)
    Path(output_asr_path).mkdir(parents=True, exist_ok=True)

    # ASR models
    whisper_model = WhisperModel(
        model_size_or_path=args.whisper_model_name,
        device=device,
        compute_type=args.compute_type,
        cpu_threads=os.cpu_count(),
        num_workers=args.num_workers
    )


def run(mini_batch):
    """Transcribe every `<id>_prep.json.pgp` entry of the mini-batch.

    Entries with another name are skipped. The decrypted working folder is
    emptied after each processed entry, including when processing ends early
    or fails.

    Returns:
        The received `mini_batch`, unchanged.
    """
    suffix = '_prep.json.pgp'
    for elem in mini_batch:
        # Read file and filter if necessary (we are only looking for files with pattern '<id>_prep.json.pgp')
        elem_path = str(Path(elem))
        if not elem_path.endswith(suffix):
            log.info(f"File {elem_path} does not contain metadata from preprocessing. Skipping...")
            continue
        input_path, _, basename = elem_path.rpartition('/')
        # Strip the exact suffix so an id that itself contains '_prep' is kept whole.
        fn = basename[:-len(suffix)]
        try:
            _transcribe_file(input_path, fn)
        finally:
            # Cleanup resources
            delete_files_in_directory_and_subdirectories('./decrypted_files')

    return mini_batch


def shutdown():
    """Disable the private key and passphrase secrets once the worker is done."""
    sc = cm.sp_login()
    cm.enable_secret(sc, pk_secret, False)
    cm.enable_secret(sc, pk_pass_secret, False)
# Keyvault handling class
class CredentialManager:
    """Load a PGP key pair from Azure Key Vault and use it to encrypt/decrypt local files.

    The public key is fetched with the default Azure credential. The private key
    and its passphrase are fetched with a service-principal credential, after
    both secrets have been switched to enabled. All key material is cached on
    the instance, so `encrypt` and `decrypt` do not need Key Vault access.
    """

    def __init__(
        self,
        keyvault_name: str,
        secret_tenant_sp: Optional[str] = None,
        secret_client_sp: Optional[str] = None,
        secret_sp: Optional[str] = None,
        puk_secret_name: Optional[str] = None,
        prk_secret_name: Optional[str] = None,
        prk_password_secret_name: Optional[str] = None,
    ) -> None:
        """Base class to handle PGP encryption system in Azure.

        Args:
            keyvault_name (str): KeyVault resource where secrets are stored.
            secret_tenant_sp (str): Service principal tenant_id secret, stored in KeyVault 'keyvault_name'.
            secret_client_sp (str): Service principal client_id secret, stored in KeyVault 'keyvault_name'.
            secret_sp (str): Service principal secret_client secret, stored in KeyVault 'keyvault_name'.
            puk_secret_name (str): Secret holding the PGP public key.
            prk_secret_name (str): Secret holding the PGP private key.
            prk_password_secret_name (str): Secret holding the private key passphrase.

        Raises:
            ValueError: If any of the service principal secret names is missing.
        """
        self.keyvault_name = keyvault_name
        self.secret_tenant_sp = secret_tenant_sp
        self.secret_client_sp = secret_client_sp
        self.secret_sp = secret_sp
        self.login = None
        # Import public key from PGP
        puk_secret_value = self.fetch_secret(self.default_login(), puk_secret_name)
        self.public_key, _ = pgpy.PGPKey.from_blob(puk_secret_value)
        # Retrieve pk secrets
        sc = self.sp_login()
        self.enable_secret(sc, prk_secret_name, True)
        self.enable_secret(sc, prk_password_secret_name, True)
        pk_secret_value = self.fetch_secret(sc, prk_secret_name)
        self.pk_pass_secret_value = self.fetch_secret(sc, prk_password_secret_name)
        # Fetch pk key
        self.private_key, _ = pgpy.PGPKey.from_blob(pk_secret_value)

    def encrypt(
        self,
        input_path: str,
        output_path: str,
        filenames: List[str],
        remove_input: bool = False,
        secret_client: Optional["SecretClient"] = None,
    ) -> None:
        """Encrypt files with the cached public key, writing `<filename>.pgp` to `output_path`.

        Args:
            input_path (str): Folder containing the plain files.
            output_path (str): Folder where the `.pgp` files are written.
            filenames (List[str]): File names (a single string is also accepted).
            remove_input (bool): Delete each plain file once it has been encrypted.
            secret_client (SecretClient): Unused; kept so existing callers keep working.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # No Key Vault login here: the public key was cached by __init__.
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            output_filepath = os.path.join(output_path, filename) + '.pgp'
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext == '.pgp':
                log.warning(f"File {fn} is already encrypted. Skipping...")
                continue
            with open(input_filepath, 'rb') as f:
                message = pgpy.PGPMessage.new(f.read())
            encrypted_message = str(self.public_key.encrypt(message))
            with open(output_filepath, "w") as f:
                f.write(encrypted_message)
            log.info(f"File {fn+ext+'.pgp'} has been generated in {output_path}.")
            # Only report a removal that actually happened.
            if remove_input:
                os.remove(input_filepath)
                log.info(f"File {fn+ext} has been removed.")

    def decrypt(
        self,
        input_path: str,
        output_path: str,
        filenames: List[str],
        remove_input: bool = False,
        secret_client: Optional["SecretClient"] = None,
    ) -> None:
        """Decrypt `.pgp`/`.enc` files with the cached private key into `output_path`.

        The output file name is the input name without its last extension.

        Args:
            input_path (str): Folder containing the encrypted files.
            output_path (str): Folder where the decrypted files are written.
            filenames (List[str]): File names (a single string is also accepted).
            remove_input (bool): Delete each encrypted file once it has been decrypted.
            secret_client (SecretClient): Unused; kept so existing callers keep working.

        Raises:
            ValueError: If the private key cannot be unlocked, or the decrypted
                payload is neither text nor bytes.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # No Key Vault login here: private key and passphrase were cached by __init__.
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext not in ('.pgp', '.enc'):
                log.warning(f"File {fn} is already decrypted. Skipping...")
                continue
            with self.private_key.unlock(self.pk_pass_secret_value) as ukey:
                if not ukey:
                    log.error("Private key password is not correct.")
                    raise ValueError("Private key password is not correct.")
                encrypted_message = pgpy.PGPMessage.from_file(input_filepath)
                decrypted_message = ukey.decrypt(encrypted_message).message
            output_filepath = os.path.join(output_path, fn)
            if isinstance(decrypted_message, str):
                with open(output_filepath, "w", encoding="utf-8") as f:
                    f.write(decrypted_message)
            elif isinstance(decrypted_message, (bytes, bytearray)):
                with open(output_filepath, "wb") as f:
                    f.write(decrypted_message)
            else:
                log.error(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
                raise ValueError(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
            log.info(f"File {fn} has been generated in {output_path}.")
            # Only report a removal that actually happened.
            if remove_input:
                os.remove(input_filepath)
                log.info(f"File {fn+ext} has been removed.")

    def default_login(self) -> "SecretClient":
        """Return a Key Vault client built on `DefaultAzureCredential` and mark the login as 'default'."""
        credential = DefaultAzureCredential()
        # Request a token up front so credential problems surface here.
        credential.get_token("https://management.azure.com/.default")
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'default'
        return secret_client

    def sp_login(self) -> "SecretClient":
        """Return a Key Vault client authenticated as the service principal and mark the login as 'sp'.

        The service principal identifiers are themselves read from Key Vault
        using the default credential.

        Raises:
            ValueError: If any service principal secret name is missing.
        """
        # Make sure all parameters are in place
        if None in (self.secret_tenant_sp, self.secret_client_sp, self.secret_sp):
            log.error("Service principal credentials have not been set up properly.")
            raise ValueError("Service principal credentials have not been set up properly.")
        # Get secret client
        secret_client = self.default_login()
        tenant_id = secret_client.get_secret(name=self.secret_tenant_sp).value
        client_id = secret_client.get_secret(name=self.secret_client_sp).value
        client_secret = secret_client.get_secret(name=self.secret_sp).value
        credential = ClientSecretCredential(tenant_id, client_id, client_secret)
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'sp'
        return secret_client

    def enable_secret(
        self,
        secret_client: Optional["SecretClient"] = None,
        secret_name: Optional[str] = None,
        enable: bool = False,
    ) -> None:
        """Set the `enabled` flag of a secret, skipping the update when it already matches.

        Args:
            secret_client (SecretClient): Client to use; replaced by a fresh
                service-principal client when missing or when the last login
                was not a service-principal one.
            secret_name (str): Name of the secret to update.
            enable (bool): Desired state of the secret.
        """
        # Get the right login for the operation
        if secret_client is None or self.login != 'sp':
            secret_client = self.sp_login()
        # Check secret current status. A failed read is treated as "disabled";
        # catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        try:
            secret_status = secret_client.get_secret(secret_name).properties.enabled
        except Exception as err:
            log.debug(f"Secret {secret_name} could not be read ({err}); assuming it is disabled.")
            secret_status = False
        # Compare with input action
        s = 'enabled' if enable else 'disabled'
        if secret_status == enable:
            log.info(f"Secret {secret_name} is already {s}.")
        else:
            secret_client.update_secret_properties(secret_name, enabled=enable)
            log.info(f"Secret {secret_name} is now {s}.")

    def fetch_secret(
        self,
        secret_client: "SecretClient",
        secret_name: str,
    ) -> str:
        """Return the value of `secret_name` read through `secret_client`."""
        return secret_client.get_secret(secret_name).value

    @staticmethod
    def get_file_attr(
        filepath: str,
    ) -> Tuple[str, str, str]:
        """Helper function to consistently split a filepath into folder path, filename and extension.

        Args:
            filepath (str): '/'-separated path where file is stored.

        Returns:
            Tuple[str, str, str]: Folder path, file name and (last) file extension.
        """
        folder_path, _, basename = filepath.rpartition('/')
        fn, ext = os.path.splitext(basename)
        return folder_path, fn, ext
# Helper function to create voice activity detection manifest
def create_asr_vad_config(segments: List[Dict], filepath: str, manifest_path: str = "./input/asr_vad_manifest.jsonl"):
    """Write one JSON line per segment describing its speech interval in `filepath`.

    Args:
        segments (List[Dict]): Items with 'start' and 'end' values convertible to float.
        filepath (str): '/'-separated path of the audio file; its base name
            (without extension) is stored as `uniq_id`.
        manifest_path (str): Destination file. It is overwritten if it exists.
    """
    fn, _ = os.path.splitext(filepath.split('/')[-1])
    # Mode "w" truncates an existing manifest, so no prior removal is needed.
    with open(manifest_path, "w") as fp:
        for x in segments:
            start = float(x['start'])
            json.dump(
                {
                    "audio_filepath": filepath,
                    "offset": start,
                    "duration": float(x['end'])-start,
                    "label": "UNK",
                    "uniq_id": fn
                },
                fp
            )
            fp.write('\n')


def _parse_diar_line(line):
    """Parse one 'start end speaker_label' line into a dictionary."""
    start, end, label = line.split()[:3]
    # Keep the whole trailing number so 'speaker_10' is not confused with 'speaker_0';
    # labels without trailing digits fall back to their last character.
    trailing_number = re.search(r'\d+$', label)
    speaker = trailing_number.group() if trailing_number else label[-1]
    return {'start': float(start), 'end': float(end), 'speaker': speaker}


# Helper function to process diarization output from method output
def process_diar_output(diar_output):
    """Convert `{key: ['start end speaker_label', ...]}` into `{key: [{'start', 'end', 'speaker'}, ...]}`."""
    return {fp: [_parse_diar_line(x) for x in segments] for fp, segments in diar_output.items()}


# Helper function to cleanup audios directory
def delete_files_in_directory_and_subdirectories(directory_path):
    """Remove every file below `directory_path`, keeping the folder structure.

    A file that cannot be removed is logged and does not stop the cleanup of
    the remaining files.
    """
    failures = 0
    for dirpath, _, filenames in os.walk(directory_path):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            try:
                os.remove(file_path)
            except OSError as err:
                failures += 1
                log.warning(f"File {file_path} could not be deleted: {err}")
    if failures:
        log.error(f"Error occurred while deleting files in {directory_path}: {failures} file(s) remain.")
    else:
        log.info(f"All files in {directory_path} and its subdirectories deleted successfully.")
def _download_config(url, destination):
    """Download `url` into `destination`, failing on HTTP errors instead of saving the error page."""
    response = requests.get(url, params={"downloadformat": "yaml"}, timeout=60)
    response.raise_for_status()
    with open(destination, mode="wb") as f:
        f.write(response.content)


def _diarize_file(input_path, fn):
    """Decrypt the inputs of audio `fn`, run diarization and write the encrypted `<fn>_diar.json`."""
    log.info(f"Processing file {fn}:")
    # Read word-level transcription to fetch timestamps
    cm.decrypt(input_path, './decrypted_files', f"{fn}_nfa.json.pgp")
    with open(f"./decrypted_files/{fn}_nfa.json", 'r', encoding='utf-8') as f:
        x = json.load(f)['segments']
    # Ensure audio contains activity
    if len(x)==0:
        log.info(f"Audio {fn} does not contain any activity. Generating dummy metadata:")
        with open(f"./decrypted_files/{fn}_diar.json", 'w', encoding='utf8') as f:
            json.dump(
                {
                    'vad_timestamps': [], # List of dictionaries with keys 'start', 'end'
                    'segments': []
                },
                f,
                indent=4,
                ensure_ascii=False
            )
        cm.encrypt('./decrypted_files', output_diar_path, f"{fn}_diar.json", True)
        return
    word_ts = [[w['start'], w['end']] for segment in x for w in segment['words']]

    # Decrypt (if needed)
    if os.path.isfile(f"{input_audio_path}/{fn}.wav.pgp"):
        log.info("Decrypt:")
        # Decrypt from the folder where the file was found (the audio folder),
        # not from the folder holding the NFA metadata.
        cm.decrypt(input_audio_path, './decrypted_files', f"{fn}.wav.pgp")
        filepath = f"./decrypted_files/{fn}.wav"
    elif os.path.isfile(f"{input_audio_path}/{fn}.wav"):
        filepath = f"{input_audio_path}/{fn}.wav"
    else:
        # Fail loudly instead of hitting a NameError or silently reusing the
        # previous element's audio path.
        raise FileNotFoundError(f"Neither {fn}.wav.pgp nor {fn}.wav was found in {input_audio_path}.")

    # Create the ASR-based VAD manifest
    create_asr_vad_config(x, filepath)

    # Speaker diarization
    log.info("Run diarization")
    diar_time = time.time()
    create_msdd_config([filepath]) # initialise msdd cfg
    msdd_model.audio_file_list = [filepath] # update audios list
    diar_hyp, _ = msdd_model.run_diarization(msdd_cfg, {fn:word_ts})
    diar_time = time.time() - diar_time
    log.info(f"\tDiarization time: {diar_time}")
    # Process diarization output
    log.info("Save outputs")
    segments = process_diar_output(diar_hyp)[fn]
    with open(f"./decrypted_files/{fn}_diar.json", 'w', encoding='utf8') as f:
        json.dump(
            {
                'segments': segments,
                'metadata': {
                    'diarization_time': diar_time
                }
            },
            f,
            indent=4,
            ensure_ascii=False
        )
    cm.encrypt('./decrypted_files', output_diar_path, f"{fn}_diar.json", True)


def init():
    """Parse the component arguments, load the credential manager and build the NeMo diarizer.

    Everything `run` and `shutdown` need is published as module globals.
    """
    global cm, msdd_model, msdd_cfg, input_audio_path, output_diar_path
    global keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pk_secret, pk_pass_secret, pubk_secret

    # Managed output path to control where objects are returned
    parser = argparse.ArgumentParser(
        allow_abbrev=False, description="ParallelRunStep Agent"
    )
    parser.add_argument("--input_audio_path", type=str)
    parser.add_argument("--keyvault_name", type=str)
    parser.add_argument("--secret_tenant_sp", type=str)
    parser.add_argument("--secret_client_sp", type=str)
    parser.add_argument("--secret_sp", type=str)
    parser.add_argument("--pk_secret", type=str)
    parser.add_argument("--pk_pass_secret", type=str)
    parser.add_argument("--pubk_secret", type=str)
    # Only these two configuration files are downloaded below.
    parser.add_argument("--event_type", type=str, default='telephonic', choices=('telephonic', 'meeting'))
    parser.add_argument("--max_num_speakers", type=int, default=3)
    parser.add_argument("--min_window_length", type=float, default=0.2)
    parser.add_argument("--overlap_threshold", type=float, default=0.8)
    parser.add_argument("--output_diar_path", type=str)
    args, _ = parser.parse_known_args()

    # Encrypt params
    keyvault_name = args.keyvault_name
    secret_tenant_sp = args.secret_tenant_sp
    secret_client_sp = args.secret_client_sp
    secret_sp = args.secret_sp
    pk_secret = args.pk_secret
    pk_pass_secret = args.pk_pass_secret
    pubk_secret = args.pubk_secret

    # Instantiate credential manager
    cm = CredentialManager(keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pubk_secret, pk_secret, pk_pass_secret)

    # Diarization parameters
    input_audio_path = args.input_audio_path
    output_diar_path = args.output_diar_path

    # Folder structure
    for folder in ('./input', './decrypted_files', './nemo_diar_output', output_diar_path):
        Path(folder).mkdir(parents=True, exist_ok=True)

    # Config files (pinned to a fixed commit of the repository)
    base_url = 'https://raw.githubusercontent.com/hedrergudene/asr-sd-pipeline/41ea72a4efde3ce4c20bf7e78f135684dd4a0b55/stt_aml_deploy/components/diar/src/input'
    for event_type in ('telephonic', 'meeting'):
        _download_config(f'{base_url}/diar_infer_{event_type}.yaml', f"./input/diar_infer_{event_type}.yaml")

    # Read NeMo MSDD configuration file
    def round_digits(number, digits):
        # Truncates (does not round) to `digits` decimals.
        return int(number*10**digits)/10**digits

    scales = (6, 4, 3, 2, 1)  # multiples of the minimum window, longest first
    msdd_cfg = OmegaConf.load(f'./input/diar_infer_{args.event_type}.yaml')
    msdd_cfg.diarizer.clustering.parameters.max_num_speakers = args.max_num_speakers
    # NOTE(review): the manifest helper in this file writes './input/asr_vad_manifest.jsonl'
    # while this setting points at '.json' — confirm which path the diarizer really reads.
    msdd_cfg.diarizer.vad.external_vad_manifest='./input/asr_vad_manifest.json'
    msdd_cfg.diarizer.asr.parameters.asr_based_vad = True
    msdd_cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = [
        round_digits(k*args.min_window_length,2) for k in scales
    ]
    # Each shift is half of the matching window.
    msdd_cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [
        round_digits(k*args.min_window_length/2,3) for k in scales
    ]
    msdd_cfg.diarizer.msdd_model.parameters.sigmoid_threshold = [args.overlap_threshold]
    create_msdd_config(['sample_audio.wav']) # initialise msdd cfg
    # Initialize NeMo MSDD diarization model
    msdd_model = OfflineDiarWithASR(msdd_cfg.diarizer)


def run(mini_batch):
    """Diarize every `<id>_nfa.json.pgp` entry of the mini-batch.

    Entries with another name are skipped. The decrypted and NeMo working
    folders are emptied after each processed entry, including when processing
    ends early or fails.

    Returns:
        The received `mini_batch`, unchanged.
    """
    suffix = '_nfa.json.pgp'
    for elem in mini_batch:
        # Read file
        elem_path = str(Path(elem))
        if not elem_path.endswith(suffix):
            log.info(f"File {elem_path} does not contain metadata from NFA. Skipping...")
            continue
        input_path, _, basename = elem_path.rpartition('/')
        # Strip the exact '_nfa.json.pgp' suffix to get the unique id, so an id
        # that itself contains '_nfa' is kept whole.
        fn = basename[:-len(suffix)]
        try:
            _diarize_file(input_path, fn)
        finally:
            log.info("Cleanup resources")
            delete_files_in_directory_and_subdirectories('./decrypted_files')
            delete_files_in_directory_and_subdirectories('./nemo_diar_output')

    return mini_batch


def shutdown():
    """Disable the private key and passphrase secrets once the worker is done."""
    sc = cm.sp_login()
    cm.enable_secret(sc, pk_secret, False)
    cm.enable_secret(sc, pk_pass_secret, False)
50 | """ 51 | self.keyvault_name = keyvault_name 52 | self.secret_tenant_sp = secret_tenant_sp 53 | self.secret_client_sp = secret_client_sp 54 | self.secret_sp = secret_sp 55 | self.login = None 56 | # Import public key from PGP 57 | puk_secret_value = self.fetch_secret(self.default_login(), puk_secret_name) 58 | self.public_key, _ = pgpy.PGPKey.from_blob(puk_secret_value) 59 | # Retrieve pk secrets 60 | sc = self.sp_login() 61 | self.enable_secret(sc, prk_secret_name, True) 62 | self.enable_secret(sc, prk_password_secret_name, True) 63 | pk_secret_value = self.fetch_secret(sc, prk_secret_name) 64 | self.pk_pass_secret_value = self.fetch_secret(sc, prk_password_secret_name) 65 | # Fetch pk key 66 | self.private_key, _ = pgpy.PGPKey.from_blob(pk_secret_value) 67 | 68 | 69 | def encrypt( 70 | self, 71 | input_path:str, 72 | output_path:str, 73 | filenames:List[str], 74 | remove_input:bool=False, 75 | secret_client:SecretClient=None 76 | ) -> None: 77 | # Check input is a list 78 | if isinstance(filenames, str): 79 | filenames = [filenames] 80 | # Default login 81 | if ((self.login!='default') | (secret_client is None)): 82 | secret_client = self.default_login() 83 | # Loop 84 | for filename in filenames: 85 | input_filepath = os.path.join(input_path, filename) 86 | output_filepath = os.path.join(output_path, filename) 87 | folder_path, fn, ext = self.get_file_attr(input_filepath) 88 | if ext=='.pgp': 89 | log.warning(f"File {fn} is already encrypted. 
Skipping...") 90 | continue 91 | with open(input_filepath, 'rb') as f: 92 | message = pgpy.PGPMessage.new(f.read()) 93 | encrypted_message = self.public_key.encrypt(message) 94 | encrypted_message = str(encrypted_message) 95 | with open(output_filepath+'.pgp', "w") as f: 96 | f.write(encrypted_message) 97 | log.info(f"File {fn+ext+'.pgp'} has been generated in {folder_path}.") 98 | if remove_input: os.remove(input_filepath) 99 | log.info(f"File {fn+ext} has been removed.") 100 | 101 | 102 | def decrypt( 103 | self, 104 | input_path: str, 105 | output_path: str, 106 | filenames:List[str], 107 | remove_input:bool=False, 108 | secret_client:SecretClient=None 109 | ) -> None: 110 | # Check input is a list 111 | if isinstance(filenames, str): 112 | filenames = [filenames] 113 | # Service principal login 114 | if ((self.login!='sp') | (secret_client is None)): 115 | secret_client = self.sp_login() 116 | # Loop 117 | for filename in filenames: 118 | input_filepath = os.path.join(input_path, filename) 119 | folder_path, fn, ext = self.get_file_attr(input_filepath) 120 | if ext not in ['.pgp', '.enc']: 121 | log.warning(f"File {fn} is already decrypted. Skipping...") 122 | continue 123 | with self.private_key.unlock(self.pk_pass_secret_value) as ukey: 124 | if ukey: 125 | encrypted_message = pgpy.PGPMessage.from_file(input_filepath) 126 | decrypted_message = ukey.decrypt(encrypted_message).message 127 | if isinstance(decrypted_message, str): 128 | with open(os.path.join(output_path, fn), "w") as f: 129 | f.write(decrypted_message) 130 | elif isinstance(decrypted_message, bytearray): 131 | with open(os.path.join(output_path, fn), "wb") as f: 132 | f.write(decrypted_message) 133 | else: 134 | log.error(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.") 135 | raise ValueError(f"File {fn} returned a decrypted message that it's not either str not bytearray. 
Please check.") 136 | log.info(f"File {fn} has been generated in {folder_path}.") 137 | if remove_input: os.remove(input_filepath) 138 | log.info(f"File {fn+ext} has been removed.") 139 | else: 140 | log.error(f"Private key password is not correct.") 141 | raise ValueError(f"Private key password is not correct.") 142 | 143 | 144 | def default_login(self) -> SecretClient: 145 | credential = DefaultAzureCredential() 146 | credential.get_token("https://management.azure.com/.default") 147 | secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential) 148 | self.login = 'default' 149 | return secret_client 150 | 151 | 152 | def sp_login(self) -> SecretClient: 153 | # Make sure all parameters are in place 154 | if ((self.secret_tenant_sp is None) | (self.secret_client_sp is None) | (self.secret_sp is None)): 155 | log.error(f"Service principal credentials have not been set up properly.") 156 | raise ValueError(f"Service principal credentials have not been set up properly.") 157 | # Get secret client 158 | secret_client = self.default_login() 159 | tenant_id = secret_client.get_secret(name=self.secret_tenant_sp).value 160 | client_id = secret_client.get_secret(name=self.secret_client_sp).value 161 | client_secret = secret_client.get_secret(name=self.secret_sp).value 162 | credential = ClientSecretCredential(tenant_id, client_id, client_secret) 163 | secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential) 164 | self.login = 'sp' 165 | return secret_client 166 | 167 | 168 | def enable_secret( 169 | self, 170 | secret_client:SecretClient=None, 171 | secret_name:str=None, 172 | enable:bool=False 173 | ) -> None: 174 | # Get the right login for the operation 175 | if self.login!='sp': 176 | secret_client = self.sp_login() 177 | # Check secret current status 178 | try: 179 | secret_status = secret_client.get_secret(secret_name).properties.enabled 180 | except: 181 | 
secret_status = False 182 | # Compare with input action 183 | if secret_status==enable: 184 | s = 'enabled' if enable else 'disabled' 185 | log.info(f"Secret {secret_name} is already {s}.") 186 | else: 187 | s = 'enabled' if enable else 'disabled' 188 | secret_client.update_secret_properties(secret_name, enabled=enable) 189 | log.info(f"Secret {secret_name} is now {s}.") 190 | 191 | 192 | def fetch_secret( 193 | self, 194 | secret_client:SecretClient, 195 | secret_name:str 196 | ) -> str: 197 | secret_value = secret_client.get_secret(secret_name).value 198 | return secret_value 199 | 200 | 201 | @staticmethod 202 | def get_file_attr( 203 | filepath:str 204 | ) -> List[str]: 205 | """Helper function to consistently split a filepath into folder path, filename and extension. 206 | 207 | Args: 208 | filepath (str): Path where file is stored. 209 | 210 | Returns: 211 | List[str]: Folder path, file name and file extension. 212 | """ 213 | folder_path = '/'.join(filepath.split('/')[:-1]) 214 | fn, ext = os.path.splitext(filepath.split('/')[-1]) 215 | return folder_path, fn, ext 216 | 217 | 218 | # Helper method to decode an audio 219 | def preprocess_audio(input_filepath, output_filepath, filename): 220 | """Method to preprocess audios with ffmpeg, using the following configuration: 221 | * '-acodec': Specifies the audio codec to be used. In this case, it's set to 'pcm_s16le', 222 | which stands for 16-bit little-endian PCM (Pulse Code Modulation). 223 | This is a standard audio format. 224 | * '-ac' '1': Sets the number of audio channels to 1, which is mono audio. 225 | * '-ar' '16000': Sets the audio sample rate to 16 kHz. 226 | 227 | Args: 228 | input_filepath (str): Folder where audio lies 229 | output_filepath (str): Folder where audio is to be stored after processing 230 | filename (str): Name of the file (with extension) you are processing. 
# Helper method to cleanup audios directory
def delete_files_in_directory_and_subdirectories(directory_path):
    """Delete every file found under 'directory_path', recursively.

    Directories themselves are kept, so the working folders remain usable.
    The operation is best-effort: an ``OSError`` is logged, not raised.

    Args:
        directory_path (str): Root folder whose files are removed.
    """
    try:
        for current_dir, _subdirs, filenames in os.walk(directory_path):
            for filename in filenames:
                os.remove(os.path.join(current_dir, filename))
    except OSError as err:
        # Keep the failure non-fatal, but make it visible and diagnosable.
        log.error(f"Error occurred while deleting files in {directory_path}: {err}")
    else:
        log.info(f"All files in {directory_path} and its subdirectories deleted successfully.")
262 | # 263 | 264 | def init(): 265 | """Init""" 266 | # Managed output path to control where objects are returned 267 | parser = argparse.ArgumentParser( 268 | allow_abbrev=False, description="ParallelRunStep Agent" 269 | ) 270 | parser.add_argument("--keyvault_name", type=str) 271 | parser.add_argument("--secret_tenant_sp", type=str) 272 | parser.add_argument("--secret_client_sp", type=str) 273 | parser.add_argument("--secret_sp", type=str) 274 | parser.add_argument("--pk_secret", type=str) 275 | parser.add_argument("--pk_pass_secret", type=str) 276 | parser.add_argument("--pubk_secret", type=str) 277 | parser.add_argument("--cosmosdb_name", type=str) 278 | parser.add_argument("--cosmosdb_collection", type=str) 279 | parser.add_argument("--cosmosdb_cs_secret", type=str) 280 | parser.add_argument("--vad_threshold", type=float, default=0.75) 281 | parser.add_argument("--min_speech_duration_ms", type=int, default=250) 282 | parser.add_argument("--min_silence_duration_ms", type=int, default=500) 283 | parser.add_argument("--demucs_model", type=str, default='htdemucs') 284 | parser.add_argument("--output_prep_path", type=str) 285 | args, _ = parser.parse_known_args() 286 | 287 | # Device 288 | global device 289 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 290 | 291 | # Encrypt params 292 | global keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pk_secret, pk_pass_secret, pubk_secret 293 | keyvault_name = args.keyvault_name 294 | secret_tenant_sp = args.secret_tenant_sp 295 | secret_client_sp = args.secret_client_sp 296 | secret_sp = args.secret_sp 297 | pk_secret = args.pk_secret 298 | pk_pass_secret = args.pk_pass_secret 299 | pubk_secret = args.pubk_secret 300 | 301 | 302 | # Preprocess params 303 | global vad_threshold, min_speech_duration_ms, min_silence_duration_ms, demucs_model, output_prep_path 304 | vad_threshold = args.vad_threshold 305 | min_speech_duration_ms = args.min_speech_duration_ms 306 | min_silence_duration_ms = 
args.min_silence_duration_ms 307 | demucs_model = args.demucs_model 308 | output_prep_path = args.output_prep_path 309 | 310 | # Folder structure 311 | Path('./decrypted_files').mkdir(parents=True, exist_ok=True) 312 | Path('./prep_audios').mkdir(parents=True, exist_ok=True) 313 | Path('./trimmed_audios').mkdir(parents=True, exist_ok=True) 314 | Path(output_prep_path).mkdir(parents=True, exist_ok=True) 315 | 316 | # VAD model 317 | global vad_model, get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks 318 | vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', 319 | model='silero_vad', 320 | force_reload=True 321 | ) 322 | (get_speech_timestamps, 323 | save_audio, 324 | read_audio, 325 | VADIterator, 326 | collect_chunks) = utils 327 | 328 | # Instantiate credential manager 329 | global cm 330 | cm = CredentialManager(keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pubk_secret, pk_secret, pk_pass_secret) 331 | 332 | # MongoDB client 333 | credential = DefaultAzureCredential() 334 | credential.get_token("https://management.azure.com/.default") 335 | sc = SecretClient(vault_url=f"https://{keyvault_name}.vault.azure.net/", credential=credential) 336 | connection_string = sc.get_secret(name=args.cosmosdb_cs_secret).value 337 | mongodb_client = pymongo.MongoClient(connection_string) 338 | # DB connection 339 | if args.cosmosdb_name not in mongodb_client.list_database_names(): 340 | log.error(f"Database {args.cosmosdb_name} not found.") 341 | raise ValueError(f"Database {args.cosmosdb_name} not found.") 342 | else: 343 | cosmosdb_db = mongodb_client[args.cosmosdb_name] 344 | log.info(f"Database {args.cosmosdb_name} connected.") 345 | # Collection connection 346 | if args.cosmosdb_collection not in cosmosdb_db.list_collection_names(): 347 | log.error(f"Collection {args.cosmosdb_collection} not found.") 348 | raise ValueError(f"Collection {args.cosmosdb_collection} not found.") 349 | else: 350 | global cosmosdb_client 351 
def run(mini_batch):
    """Preprocess every audio file contained in ``mini_batch``.

    Accepted inputs are ``.wav``/``.mp3`` files, either plain or encrypted
    (extra ``.pgp``/``.enc`` suffix). Files already registered in the
    CosmosDB collection are skipped. For the remaining ones the audio is
    decrypted if needed, converted with ffmpeg, filtered with the VAD model,
    cleaned with demucs and finally written, encrypted, to
    ``output_prep_path`` together with a ``<id>_prep.json`` metadata file.

    The function relies on the module-level objects created by ``init()``.

    Args:
        mini_batch: Iterable of file paths handed over by the parallel job.

    Returns:
        The received ``mini_batch``, unchanged.

    Raises:
        ValueError: If an input path does not exist.
        RuntimeError: If an ffmpeg conversion fails.
    """
    for elem in mini_batch:
        # Read file
        pathdir = Path(elem)
        input_folder = '/'.join(str(pathdir).split('/')[:-1])
        fn, ext_enc = os.path.splitext(str(pathdir).split('/')[-1])
        is_encrypted = ext_enc in ['.pgp', '.enc']
        if is_encrypted:
            # Only encrypted names carry a second extension (<id>.wav.pgp).
            # Splitting plain names again would truncate ids containing dots.
            fn, ext_file = os.path.splitext(fn)
            if ext_file not in ['.wav', '.mp3']:
                log.info(f"Skipping file {fn}, file extension not valid ('{ext_file}')")
                continue
        elif ext_enc in ['.wav', '.mp3']:
            ext_file = ext_enc
        else:
            log.info(f"Skipping file {fn}, encoding extension not valid ('{ext_enc}')")
            continue
        # The duplicate check applies to plain and encrypted inputs alike.
        if cosmosdb_client.find_one({"_id": fn}) is not None:
            log.info(f"Skipping file {fn}, record already found in cosmosDB collection.")
            continue
        if is_encrypted:
            log.info(f"Processing file {fn}:")
        else:
            log.warning(f"Processing unencrypted file {fn}.")
        prep_time = time.time()

        # try/finally guarantees that decrypted audio never outlives the
        # item, whatever path (early `continue`, exception, success) is taken.
        try:
            # Standarise format
            if not os.path.isfile(str(pathdir)):
                raise ValueError(f"Filepath does not exist.")
            if is_encrypted:
                log.info(f"Decrypting and preprocessing:")
                cm.decrypt(input_folder, './decrypted_files', [f"{fn}{ext_file}{ext_enc}"])
                preprocess_audio('./decrypted_files', './prep_audios', f"{fn}{ext_file}")
            else:
                log.info(f"Preprocessing:")
                preprocess_audio(input_folder, './prep_audios', f"{fn}{ext_file}")

            # VAD
            log.info(f"Get speech activity timestamps:")
            wav = read_audio(f"./prep_audios/{fn}.wav", sampling_rate=16000)
            speech_timestamps = get_speech_timestamps(
                wav,
                vad_model,
                threshold=vad_threshold,
                sampling_rate=16000,
                min_speech_duration_ms=min_speech_duration_ms,
                min_silence_duration_ms=min_silence_duration_ms
            )
            if not speech_timestamps:
                log.info(f"Audio {fn} does not contain any activity. Generating dummy metadata:")
                with open(f"./decrypted_files/{fn}_prep.json", 'w') as f:
                    json.dump(
                        {
                            'vad_timestamps': speech_timestamps,  # List of dictionaries with keys 'start', 'end'
                        },
                        f,
                        indent=4,
                        ensure_ascii=False
                    )
                cm.encrypt('./decrypted_files', output_prep_path, [f"{fn}_prep.json"], True)
                continue
            save_audio(f"./prep_audios/{fn}_vad.wav",
                       collect_chunks(speech_timestamps, wav), sampling_rate=16000)
            audio_length_s = len(wav)/16000
            vad_length_s = sum(s['end']-s['start'] for s in speech_timestamps)/16000
            # Report the share that was removed, not the share that was kept.
            filtered_pct = np.round((1 - vad_length_s/audio_length_s)*100, 2)
            log.info(f"\tVAD filtered {filtered_pct}% of audio. Remaining audio length: {np.round(vad_length_s,2)}s")

            # Demucs
            log.info(f"Apply demucs:")
            # Arguments are passed as a list so that file names with spaces
            # or quotes cannot break the command line.
            demucs.separate.main([
                '--two-stems', 'vocals',
                '-o', './prep_audios',
                '-n', demucs_model,
                f"./prep_audios/{fn}_vad.wav",
            ])

            # Convert demucs output to mono signal
            log.info(f"Standardise demucs output:")
            command = [
                'ffmpeg',
                '-i',
                f"./prep_audios/{demucs_model}/{fn}_vad/vocals.wav",
                '-ac',
                '1',
                '-ar',
                '16000',
                f"./trimmed_audios/{fn}.wav"
            ]
            # stderr must be captured, otherwise the error below prints None.
            out = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
            if out.returncode != 0:
                raise RuntimeError(f"An error occured during audio preprocessing. Logs are: {out.stderr}")

            # Encrypt output
            log.info(f"Encrypt audio output:")
            cm.encrypt('./trimmed_audios', output_prep_path, [f"{fn}.wav"], True)

            prep_time = time.time() - prep_time
            log.info(f"\tRuntime: {prep_time}")

            # Write metadata file
            with open(f"./decrypted_files/{fn}_prep.json", 'w') as f:
                json.dump(
                    {
                        'vad_timestamps': speech_timestamps,  # List of dictionaries with keys 'start', 'end'
                        'duration': audio_length_s,
                        'metadata': {
                            'preprocessing_time': prep_time
                        }
                    },
                    f,
                    indent=4,
                    ensure_ascii=False
                )
            cm.encrypt('./decrypted_files', output_prep_path, [f"{fn}_prep.json"], True)
        finally:
            # Cleanup resources
            delete_files_in_directory_and_subdirectories('./decrypted_files')
            delete_files_in_directory_and_subdirectories('./prep_audios')
            delete_files_in_directory_and_subdirectories('./trimmed_audios')

    return mini_batch


def shutdown():
    """Disable the private-key secrets again once the job is finished."""
    # One service-principal login is enough for both updates.
    secret_client = cm.sp_login()
    cm.enable_secret(secret_client, pk_secret, False)
    cm.enable_secret(secret_client, pk_pass_secret, False)
# Keyvault handling class
class CredentialManager():
    """PGP encryption/decryption helper backed by secrets stored in Azure Key Vault."""

    def __init__(
        self,
        keyvault_name:str,
        secret_tenant_sp:str=None,
        secret_client_sp:str=None,
        secret_sp:str=None,
        puk_secret_name:str=None,
        prk_secret_name:str=None,
        prk_password_secret_name:str=None,
    ) -> None:
        """Base class to handle PGP encryption system in Azure.

        The public key is read with the default identity. The private key
        and its password are enabled and read with the service principal.

        Args:
            keyvault_name (str): KeyVault resource where secrets are stored.
            secret_tenant_sp (str): Service principal tenant_id secret, stored in KeyVault 'keyvault_name'.
            secret_client_sp (str): Service principal client_id secret, stored in KeyVault 'keyvault_name'.
            secret_sp (str): Service principal secret_client secret, stored in KeyVault 'keyvault_name'.
            puk_secret_name (str): Name of the secret holding the public PGP key.
            prk_secret_name (str): Name of the secret holding the private PGP key.
            prk_password_secret_name (str): Name of the secret holding the private key password.
        """
        self.keyvault_name = keyvault_name
        self.secret_tenant_sp = secret_tenant_sp
        self.secret_client_sp = secret_client_sp
        self.secret_sp = secret_sp
        self.login = None
        # Import public key from PGP
        puk_secret_value = self.fetch_secret(self.default_login(), puk_secret_name)
        self.public_key, _ = pgpy.PGPKey.from_blob(puk_secret_value)
        # Retrieve pk secrets
        sc = self.sp_login()
        self.enable_secret(sc, prk_secret_name, True)
        self.enable_secret(sc, prk_password_secret_name, True)
        pk_secret_value = self.fetch_secret(sc, prk_secret_name)
        self.pk_pass_secret_value = self.fetch_secret(sc, prk_password_secret_name)
        # Fetch pk key
        self.private_key, _ = pgpy.PGPKey.from_blob(pk_secret_value)


    def encrypt(
        self,
        input_path:str,
        output_path:str,
        filenames:List[str],
        remove_input:bool=False,
        secret_client:SecretClient=None
    ) -> None:
        """Encrypt files with the public PGP key and write them as ``<name>.pgp``.

        Args:
            input_path (str): Folder holding the plain files.
            output_path (str): Folder where the ``.pgp`` files are written.
            filenames (List[str]): File name, or list of file names, inside
                'input_path'. Names already ending in '.pgp' are skipped.
            remove_input (bool): Delete each plain file once it is encrypted.
            secret_client (SecretClient): Accepted for API symmetry with
                'decrypt'; the encryption itself only needs the public key.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # Default login (kept so that `self.login` is updated as before)
        if self.login != 'default' or secret_client is None:
            secret_client = self.default_login()
        # Loop
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            output_filepath = os.path.join(output_path, filename)
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext == '.pgp':
                log.warning(f"File {fn} is already encrypted. Skipping...")
                continue
            with open(input_filepath, 'rb') as f:
                message = pgpy.PGPMessage.new(f.read())
            encrypted_message = str(self.public_key.encrypt(message))
            with open(output_filepath + '.pgp', "w") as f:
                f.write(encrypted_message)
            # Report the folder the file was actually written to.
            log.info(f"File {fn+ext+'.pgp'} has been generated in {output_path}.")
            # Only claim the removal when it really happened.
            if remove_input:
                os.remove(input_filepath)
                log.info(f"File {fn+ext} has been removed.")


    def decrypt(
        self,
        input_path: str,
        output_path: str,
        filenames:List[str],
        remove_input:bool=False,
        secret_client:SecretClient=None
    ) -> None:
        """Decrypt ``.pgp``/``.enc`` files with the private PGP key.

        Each output is written to 'output_path' under the input name minus
        its last extension.

        Args:
            input_path (str): Folder holding the encrypted files.
            output_path (str): Folder where decrypted files are written.
            filenames (List[str]): File name, or list of file names, inside
                'input_path'. Names without an encryption extension are skipped.
            remove_input (bool): Delete each encrypted file once decrypted.
            secret_client (SecretClient): Optional service principal client.

        Raises:
            ValueError: If the private key cannot be unlocked or the payload
                is neither text nor bytes.
        """
        # Check input is a list
        if isinstance(filenames, str):
            filenames = [filenames]
        # Service principal login
        if self.login != 'sp' or secret_client is None:
            secret_client = self.sp_login()
        # Loop
        for filename in filenames:
            input_filepath = os.path.join(input_path, filename)
            _, fn, ext = self.get_file_attr(input_filepath)
            if ext not in ['.pgp', '.enc']:
                log.warning(f"File {fn} is already decrypted. Skipping...")
                continue
            with self.private_key.unlock(self.pk_pass_secret_value) as ukey:
                if not ukey:
                    log.error(f"Private key password is not correct.")
                    raise ValueError(f"Private key password is not correct.")
                encrypted_message = pgpy.PGPMessage.from_file(input_filepath)
                decrypted_message = ukey.decrypt(encrypted_message).message
                output_filepath = os.path.join(output_path, fn)
                if isinstance(decrypted_message, str):
                    with open(output_filepath, "w") as f:
                        f.write(decrypted_message)
                elif isinstance(decrypted_message, (bytes, bytearray)):
                    with open(output_filepath, "wb") as f:
                        f.write(decrypted_message)
                else:
                    log.error(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
                    raise ValueError(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.")
                # Report the folder the file was actually written to.
                log.info(f"File {fn} has been generated in {output_path}.")
                # Only claim the removal when it really happened.
                if remove_input:
                    os.remove(input_filepath)
                    log.info(f"File {fn+ext} has been removed.")


    def default_login(self) -> SecretClient:
        """Return a Key Vault client built on ``DefaultAzureCredential``; sets ``self.login`` to 'default'."""
        credential = DefaultAzureCredential()
        credential.get_token("https://management.azure.com/.default")
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'default'
        return secret_client


    def sp_login(self) -> SecretClient:
        """Return a Key Vault client authenticated as the service principal; sets ``self.login`` to 'sp'.

        Raises:
            ValueError: If any service principal secret name is missing.
        """
        # Make sure all parameters are in place
        if self.secret_tenant_sp is None or self.secret_client_sp is None or self.secret_sp is None:
            log.error(f"Service principal credentials have not been set up properly.")
            raise ValueError(f"Service principal credentials have not been set up properly.")
        # Get secret client
        secret_client = self.default_login()
        tenant_id = secret_client.get_secret(name=self.secret_tenant_sp).value
        client_id = secret_client.get_secret(name=self.secret_client_sp).value
        client_secret = secret_client.get_secret(name=self.secret_sp).value
        credential = ClientSecretCredential(tenant_id, client_id, client_secret)
        secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential)
        self.login = 'sp'
        return secret_client


    def enable_secret(
        self,
        secret_client:SecretClient=None,
        secret_name:str=None,
        enable:bool=False
    ) -> None:
        """Enable or disable secret 'secret_name', skipping the update when it is already in that state."""
        # Get the right login for the operation
        if self.login != 'sp':
            secret_client = self.sp_login()
        # Check secret current status
        try:
            secret_status = secret_client.get_secret(secret_name).properties.enabled
        except Exception:
            # A secret that cannot be read is treated as disabled.
            # `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            secret_status = False
        # Compare with input action
        s = 'enabled' if enable else 'disabled'
        if secret_status == enable:
            log.info(f"Secret {secret_name} is already {s}.")
        else:
            secret_client.update_secret_properties(secret_name, enabled=enable)
            log.info(f"Secret {secret_name} is now {s}.")


    def fetch_secret(
        self,
        secret_client:SecretClient,
        secret_name:str
    ) -> str:
        """Return the value of secret 'secret_name' read through 'secret_client'."""
        return secret_client.get_secret(secret_name).value


    @staticmethod
    def get_file_attr(
        filepath:str
    ) -> Tuple[str, str, str]:
        """Helper function to consistently split a filepath into folder path, filename and extension.

        Args:
            filepath (str): Path where file is stored.

        Returns:
            Tuple[str, str, str]: Folder path, file name and file extension.
        """
        folder_path = '/'.join(filepath.split('/')[:-1])
        fn, ext = os.path.splitext(filepath.split('/')[-1])
        return folder_path, fn, ext
# Helper function to prepare ASR output to be aligned. Segments separators are '|'
def create_nfa_config(segments:List[Dict], filepath:str):
    """Write the single-entry manifest read by the forced aligner.

    The manifest is stored in ``input/nfa_manifest.jsonl`` and any previous
    content is overwritten.

    Args:
        segments (List[Dict]): ASR segments; only their 'text' key is used.
        filepath (str): Path of the audio file to align.
    """
    manifest_entry = {
        "audio_filepath": filepath,
        "text": ' | '.join(segment['text'] for segment in segments),
    }
    # Mode "w" truncates an existing manifest, no explicit removal needed.
    with open("input/nfa_manifest.jsonl", "w") as fp:
        json.dump(manifest_entry, fp)
        fp.write('\n')


def _parse_ctm_line(line):
    """Turn one space-delimited CTM line into a start/end/text dictionary."""
    fields = line.split(' ')
    start = float(fields[2])
    return {'start': start, 'end': start + float(fields[3]), 'text': fields[-1]}


# Helper function to process forced alignment output
def process_nfa_output(filename):
    """Merge segment-level and word-level CTM files into a list of segments.

    Reads ``./nemo_nfa_output/ctm/segments/<filename>.ctm`` and
    ``./nemo_nfa_output/ctm/words/<filename>.ctm``. Words are assigned to
    segments in order, each segment taking as many words as its text has.

    Args:
        filename (str): CTM file name without extension.

    Returns:
        list: One dictionary per segment with keys 'start', 'end', 'text'
        and 'words' (a list of word dictionaries with 'start', 'end', 'text').
    """
    with open(f"./nemo_nfa_output/ctm/segments/{filename}.ctm", 'r') as f:
        sentence_level_ts = [_parse_ctm_line(y) for y in f.read().split('\n') if y]
    with open(f"./nemo_nfa_output/ctm/words/{filename}.ctm", 'r') as f:
        word_level_ts = [_parse_ctm_line(y) for y in f.read().split('\n') if y]
    sg = []
    shift = 0
    for h in sentence_level_ts:
        # Segment CTMs encode blanks as '<space>' to stay space-delimited.
        # Restore them before counting the words of the segment.
        text = h['text'].replace('<space>', ' ')
        n_words = len(text.split(' '))
        sg.append(
            {
                'start': h['start'],
                'end': h['end'],
                'text': text,
                'words': word_level_ts[shift:shift+n_words]
            }
        )
        shift += n_words
    return sg


# Helper function to cleanup audios directory
def delete_files_in_directory_and_subdirectories(directory_path):
    """Delete every file found under 'directory_path', recursively.

    Directories themselves are kept. The operation is best-effort: an
    ``OSError`` is logged, not raised.

    Args:
        directory_path (str): Root folder whose files are removed.
    """
    try:
        for current_dir, _subdirs, filenames in os.walk(directory_path):
            for filename in filenames:
                os.remove(os.path.join(current_dir, filename))
    except OSError as err:
        # Keep the failure non-fatal, but make it visible and diagnosable.
        log.error(f"Error occurred while deleting files in {directory_path}: {err}")
    else:
        log.info(f"All files in {directory_path} and its subdirectories deleted successfully.")
def init():
    """Prepare the worker: parse arguments, log in, create folders and fetch NeMo.

    Parsed values are published as module-level globals, a
    ``CredentialManager`` is instantiated as ``cm``, the working folders are
    created and the NeMo repository (tag v1.20.0) is cloned into ``./NeMo``
    unless the aligner script is already there.

    Raises:
        RuntimeError: If cloning the NeMo repository fails.
    """
    global keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pk_secret, pk_pass_secret, pubk_secret
    global cm
    global input_audio_path, nfa_model_name, batch_size, output_fa_path

    # Managed output path to control where objects are returned
    parser = argparse.ArgumentParser(
        allow_abbrev=False, description="ParallelRunStep Agent"
    )
    parser.add_argument("--input_audio_path", type=str)
    parser.add_argument("--input_asr_path", type=str)
    parser.add_argument("--keyvault_name", type=str)
    parser.add_argument("--secret_tenant_sp", type=str)
    parser.add_argument("--secret_client_sp", type=str)
    parser.add_argument("--secret_sp", type=str)
    parser.add_argument("--pk_secret", type=str)
    parser.add_argument("--pk_pass_secret", type=str)
    parser.add_argument("--pubk_secret", type=str)
    parser.add_argument("--nfa_model_name", type=str, default='stt_es_fastconformer_hybrid_large_pc')
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--output_fa_path", type=str)
    args, _ = parser.parse_known_args()

    # Encrypt params
    keyvault_name = args.keyvault_name
    secret_tenant_sp = args.secret_tenant_sp
    secret_client_sp = args.secret_client_sp
    secret_sp = args.secret_sp
    pk_secret = args.pk_secret
    pk_pass_secret = args.pk_pass_secret
    pubk_secret = args.pubk_secret

    # Instantiate credential manager
    cm = CredentialManager(keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pubk_secret, pk_secret, pk_pass_secret)

    # Params
    input_audio_path = args.input_audio_path
    nfa_model_name = args.nfa_model_name
    batch_size = args.batch_size
    output_fa_path = args.output_fa_path

    # Folder structure
    Path('./decrypted_files').mkdir(parents=True, exist_ok=True)
    Path('./NeMo').mkdir(parents=True, exist_ok=True)
    Path('./input').mkdir(parents=True, exist_ok=True)
    Path('./nemo_nfa_output').mkdir(parents=True, exist_ok=True)
    Path(output_fa_path).mkdir(parents=True, exist_ok=True)

    # Clone repo. `git clone` refuses a non-empty destination, so a second
    # init() in the same working directory must reuse the existing checkout.
    if Path('./NeMo/tools/nemo_forced_aligner/align.py').is_file():
        log.info("NeMo repo already available. Skipping clone.")
        return
    result = sp.run(
        [
            'git',
            'clone',
            'https://github.com/NVIDIA/NeMo',
            '-b',
            'v1.20.0',
            './NeMo'
        ],
        capture_output=True,
        text=True
    )
    # Check return code
    if result.returncode != 0:
        log.error(f"NeMo repo cloning raised an exception: {result.stderr}")
        raise RuntimeError(f"NeMo repo cloning raised an exception: {result.stderr}")
Skipping...") 349 | continue 350 | input_path = '/'.join(str(pathdir).split('/')[:-1]) 351 | fn, _ = os.path.splitext(str(pathdir).split('/')[-1]) 352 | fn = re.findall('(.*?)_asr', fn)[0] # remove '_prep' from filename to get unique_id 353 | log.info(f"Processing file {fn}:") 354 | 355 | # Read word-level transcription to fetch timestamps 356 | cm.decrypt(input_path, './decrypted_files', f"{fn}_asr.json.pgp") 357 | with open(f"./decrypted_files/{fn}_asr.json", 'r', encoding='utf-8') as f: 358 | asr_dct = json.load(f) 359 | # Ensure audio contains activity 360 | if len(asr_dct['segments'])==0: 361 | log.info(f"Audio {fn} does not contain any activity. Generating dummy metadata:") 362 | with open(f"./decrypted_files/{fn}_nfa.json", 'w') as f: 363 | json.dump( 364 | { 365 | 'vad_timestamps': [], # List of dictionaries with keys 'start', 'end' 366 | 'segments': [] 367 | }, 368 | f, 369 | indent=4, 370 | ensure_ascii=False 371 | ) 372 | cm.encrypt('./decrypted_files', output_fa_path, f"{fn}_nfa.json", True) 373 | continue 374 | 375 | # 376 | # Decrypt (if needed) 377 | # 378 | if os.path.isfile(f"{input_audio_path}/{fn}.wav.pgp"): 379 | log.info(f"Decrypt:") 380 | cm.decrypt(input_path, './decrypted_files', f"{fn}.wav.pgp") 381 | filepath = f"./decrypted_files/{fn}.wav" 382 | elif os.path.isfile(f"{input_audio_path}/{fn}.wav"): 383 | filepath = f"{input_audio_path}/{fn}.wav" 384 | 385 | # Create config 386 | create_nfa_config(asr_dct['segments'], filepath) 387 | 388 | # 389 | # Forced alignment 390 | # 391 | log.info(f"Run alignment") 392 | # Run script 393 | align_time = time.time() 394 | result = sp.run( 395 | [ 396 | sys.executable, 397 | 'NeMo/tools/nemo_forced_aligner/align.py', 398 | f'pretrained_name="{nfa_model_name}"', 399 | 'manifest_filepath="./input/nfa_manifest.jsonl"', 400 | 'output_dir="./nemo_nfa_output"', 401 | f'batch_size={batch_size}', 402 | 'additional_segment_grouping_separator="|"' 403 | ], 404 | capture_output=True, 405 | text=True, 406 | 
encoding='utf-8', 407 | errors='xmlcharrefreplace' 408 | ) 409 | align_time = time.time() - align_time 410 | # Check return code 411 | check_run = ((result.returncode==0) & (os.path.isfile(f"./nemo_nfa_output/ctm/segments/{fn}.ctm")) & (os.path.isfile(f"./nemo_nfa_output/ctm/words/{fn}.ctm"))) 412 | if ((not check_run) & (asr_dct['segments'][0].get('words') is None)): 413 | log.error(f"Alignment raised an exception and there are no timestamps available from ASR: {result.stderr}") 414 | raise RuntimeError(f"Alignment raised an exception and there are no timestamps available from ASR: {result.stderr}") 415 | elif ((not check_run) & (asr_dct['segments'][0].get('words') is not None)): 416 | log.warning(f"Alignment raised an exception; using ASR word-level timestamps: {result.stderr}") 417 | # Process output 418 | with open(f"./decrypted_files/{fn}_nfa.json", 'w', encoding='utf8') as f: 419 | json.dump( 420 | { 421 | 'segments': asr_dct['segments'], 422 | 'duration': asr_dct['duration'], 423 | 'vad_timestamps': asr_dct['vad_timestamps'], # List of dictionaries with keys 'start', 'end' 424 | 'metadata': {**asr_dct['metadata'], **{'alignment_time': align_time}} 425 | }, 426 | f, 427 | indent=4, 428 | ensure_ascii=False 429 | ) 430 | cm.encrypt('./decrypted_files', output_fa_path, f"{fn}_nfa.json", True) 431 | elif ((check_run) & (asr_dct['segments'][0].get('words') is None)): 432 | log.info(f"Alignment run successfully. 
Including word-level timestamps.") 433 | # Update timestamps from both segment-level and word-level information 434 | segments = process_nfa_output(fn) 435 | # Process output 436 | with open(f"./decrypted_files/{fn}_nfa.json", 'w', encoding='utf8') as f: 437 | json.dump( 438 | { 439 | 'segments': segments, 440 | 'duration': asr_dct['duration'], 441 | 'vad_timestamps': asr_dct['vad_timestamps'], # List of dictionaries with keys 'start', 'end' 442 | 'metadata': {**asr_dct['metadata'], **{'alignment_time': align_time}} 443 | }, 444 | f, 445 | indent=4, 446 | ensure_ascii=False 447 | ) 448 | cm.encrypt('./decrypted_files', output_fa_path, f"{fn}_nfa.json", True) 449 | else: 450 | # Update timestamps from both segment-level and word-level information 451 | log.info(f"Alignment run successfully. Updating word-level timestamps.") 452 | segments = process_nfa_output(fn) 453 | 454 | # Keep confidence results from ASR 455 | for asr_seg, nfa_seg in zip(asr_dct['segments'], segments): 456 | c = [] 457 | for asr_word, nfa_word in zip(asr_seg['words'], nfa_seg['words']): 458 | nfa_word['confidence'] = asr_word['confidence'] 459 | c.append(asr_word['confidence']) 460 | nfa_seg['confidence'] = sum(c)/len(c) 461 | 462 | # Process output 463 | with open(f"./decrypted_files/{fn}_nfa.json", 'w', encoding='utf8') as f: 464 | json.dump( 465 | { 466 | 'segments': segments, 467 | 'duration': asr_dct['duration'], 468 | 'vad_timestamps': asr_dct['vad_timestamps'], # List of dictionaries with keys 'start', 'end' 469 | 'metadata': {**asr_dct['metadata'], **{'alignment_time': align_time}} 470 | }, 471 | f, 472 | indent=4, 473 | ensure_ascii=False 474 | ) 475 | cm.encrypt('./decrypted_files', output_fa_path, f"{fn}_nfa.json", True) 476 | log.info(f"Cleanup resources") 477 | delete_files_in_directory_and_subdirectories('./decrypted_files') 478 | delete_files_in_directory_and_subdirectories('./nemo_nfa_output') 479 | 480 | return mini_batch 481 | 482 | 483 | def shutdown(): 484 | 
cm.enable_secret(cm.sp_login(), pk_secret, False) 485 | cm.enable_secret(cm.sp_login(), pk_pass_secret, False) -------------------------------------------------------------------------------- /stt_aml_deploy/components/merge_align/src/main.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import argparse 3 | import sys 4 | import logging as log 5 | from transformers import pipeline 6 | import torch 7 | import re 8 | import json 9 | import os 10 | import re 11 | import time 12 | import bisect 13 | from typing import List, Dict, Tuple, Optional 14 | from azure.identity import DefaultAzureCredential, ClientSecretCredential 15 | from azure.keyvault.secrets import SecretClient 16 | import pgpy 17 | import pymongo 18 | from datetime import datetime 19 | from pathlib import Path 20 | 21 | # Setup logs 22 | root = log.getLogger() 23 | root.setLevel(log.DEBUG) 24 | handler = log.StreamHandler(sys.stdout) 25 | handler.setLevel(log.DEBUG) 26 | formatter = log.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 27 | handler.setFormatter(formatter) 28 | root.addHandler(handler) 29 | 30 | 31 | # Helper functions to align sentences with punctuation signs 32 | def get_word_ts_anchor(s, e, option="mid"): 33 | if option == "end": 34 | return e 35 | elif option == "mid": 36 | return (s + e) / 2 37 | return s 38 | 39 | def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="mid"): 40 | s, e, sp = spk_ts[0] 41 | wrd_pos, turn_idx = 0, 0 42 | wrd_spk_mapping = [] 43 | for wrd_dict in wrd_ts: 44 | ws, we, wc, wrd = ( 45 | wrd_dict["start"], 46 | wrd_dict["end"], 47 | wrd_dict["confidence"], 48 | wrd_dict["text"], 49 | ) 50 | wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option) 51 | while wrd_pos > float(e): 52 | turn_idx += 1 53 | turn_idx = min(turn_idx, len(spk_ts) - 1) 54 | s, e, sp = spk_ts[turn_idx] 55 | if turn_idx == len(spk_ts) - 1: 56 | e = get_word_ts_anchor(ws, we, option="end") 57 | 
wrd_spk_mapping.append( 58 | { 59 | "word": wrd, 60 | "start_time": ws, 61 | "end_time": we, 62 | "confidence": wc, 63 | "speaker": sp 64 | } 65 | ) 66 | return wrd_spk_mapping 67 | 68 | def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words, sentence_ending_punctuations): 69 | is_word_sentence_end = ( 70 | lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations 71 | ) 72 | left_idx = word_idx 73 | while ( 74 | left_idx > 0 75 | and word_idx - left_idx < max_words 76 | and speaker_list[left_idx - 1] == speaker_list[left_idx] 77 | and not is_word_sentence_end(left_idx - 1) 78 | ): 79 | left_idx -= 1 80 | 81 | return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1 82 | 83 | def get_last_word_idx_of_sentence(word_idx, word_list, max_words, sentence_ending_punctuations): 84 | is_word_sentence_end = ( 85 | lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations 86 | ) 87 | right_idx = word_idx 88 | while ( 89 | right_idx < len(word_list) 90 | and right_idx - word_idx < max_words 91 | and not is_word_sentence_end(right_idx) 92 | ): 93 | right_idx += 1 94 | 95 | return ( 96 | right_idx 97 | if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx) 98 | else -1 99 | ) 100 | 101 | def get_realigned_ws_mapping_with_punctuation( 102 | word_speaker_mapping, max_words_in_sentence, sentence_ending_punctuations 103 | ): 104 | is_word_sentence_end = ( 105 | lambda x: ((x >= 0) 106 | & (word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations)) 107 | ) 108 | wsp_len = len(word_speaker_mapping) 109 | 110 | words_list, speaker_list = [], [] 111 | for k, line_dict in enumerate(word_speaker_mapping): 112 | word, speaker = line_dict["word"], line_dict["speaker"] 113 | words_list.append(word) 114 | speaker_list.append(speaker) 115 | 116 | k = 0 117 | while k < len(word_speaker_mapping): 118 | line_dict = word_speaker_mapping[k] 119 | if ( 120 | k < wsp_len - 1 121 | and speaker_list[k] != 
speaker_list[k + 1] 122 | and not is_word_sentence_end(k) 123 | ): 124 | left_idx = get_first_word_idx_of_sentence( 125 | k, words_list, speaker_list, max_words_in_sentence, sentence_ending_punctuations 126 | ) 127 | right_idx = ( 128 | get_last_word_idx_of_sentence( 129 | k, words_list, max_words_in_sentence - k + left_idx - 1, sentence_ending_punctuations 130 | ) 131 | if left_idx > -1 132 | else -1 133 | ) 134 | if min(left_idx, right_idx) == -1: 135 | k += 1 136 | continue 137 | 138 | spk_labels = speaker_list[left_idx : right_idx + 1] 139 | mod_speaker = max(set(spk_labels), key=spk_labels.count) 140 | if spk_labels.count(mod_speaker) < len(spk_labels) // 2: 141 | k += 1 142 | continue 143 | 144 | speaker_list[left_idx : right_idx + 1] = [mod_speaker] * ( 145 | right_idx - left_idx + 1 146 | ) 147 | k = right_idx 148 | 149 | k += 1 150 | 151 | k, realigned_list = 0, [] 152 | while k < len(word_speaker_mapping): 153 | line_dict = word_speaker_mapping[k].copy() 154 | line_dict["speaker"] = speaker_list[k] 155 | realigned_list.append(line_dict) 156 | k += 1 157 | 158 | return realigned_list 159 | 160 | def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts): 161 | s, e, spk = spk_ts[0] 162 | prev_spk = spk 163 | 164 | snts = [] 165 | cf = [] 166 | words = [] 167 | 168 | snt = { 169 | "speaker": f"Speaker {spk}", 170 | "start_time": s, 171 | "end_time": e, 172 | "confidence": 0, 173 | "text": "", 174 | "words": [] 175 | } 176 | 177 | for wrd_dict in word_speaker_mapping: 178 | wrd, spk = wrd_dict["word"], wrd_dict["speaker"] 179 | if spk != prev_spk: 180 | snts.append(snt) 181 | snt = { 182 | "speaker": f"Speaker {spk}", 183 | "start_time": wrd_dict["start_time"], 184 | "end_time": wrd_dict["end_time"], 185 | "confidence": 0, 186 | "text": "", 187 | "words": [], 188 | } 189 | cf = [] 190 | words = [] 191 | else: 192 | snt["end_time"] = wrd_dict["end_time"] 193 | cf.append(wrd_dict['confidence']) 194 | snt['confidence'] = sum(cf)/len(cf) 195 | 
words.append(wrd_dict) 196 | snt['words'] = words 197 | snt["text"] += wrd + " " 198 | prev_spk = spk 199 | snt['text'] = snt['text'].strip() 200 | snts.append(snt) 201 | return snts 202 | 203 | 204 | # Class to align VAD timestamps 205 | class SpeechTimestampsMap: 206 | """Helper class to restore original speech timestamps.""" 207 | 208 | def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): 209 | self.sampling_rate = sampling_rate 210 | self.time_precision = time_precision 211 | self.chunk_end_sample = [] 212 | self.total_silence_before = [] 213 | 214 | previous_end = 0 215 | silent_samples = 0 216 | 217 | for chunk in chunks: 218 | silent_samples += chunk["start"] - previous_end 219 | previous_end = chunk["end"] 220 | 221 | self.chunk_end_sample.append(chunk["end"] - silent_samples) 222 | self.total_silence_before.append(silent_samples / sampling_rate) 223 | 224 | def get_original_time( 225 | self, 226 | time: float, 227 | chunk_index: Optional[int] = None, 228 | ) -> float: 229 | if chunk_index is None: 230 | chunk_index = self.get_chunk_index(time) 231 | 232 | total_silence_before = self.total_silence_before[chunk_index] 233 | return round(total_silence_before + time, self.time_precision) 234 | 235 | def get_chunk_index(self, time: float) -> int: 236 | sample = int(time * self.sampling_rate) 237 | return min( 238 | bisect.bisect(self.chunk_end_sample, sample), 239 | len(self.chunk_end_sample) - 1, 240 | ) 241 | 242 | 243 | # Punctuation model 244 | class PunctuationModel(): 245 | def __init__( 246 | self, 247 | model:str ="kredor/punctuate-all" 248 | ) -> None: 249 | if torch.cuda.is_available(): 250 | self.pipe = pipeline("ner",model, aggregation_strategy="none", device=0) 251 | else: 252 | self.pipe = pipeline("ner",model, aggregation_strategy="none") 253 | 254 | def preprocess(self,text): 255 | #remove markers except for markers in numbers 256 | text = re.sub(r"(? 
result[result_index]["end"] : 298 | label = result[result_index]['entity'] 299 | score = result[result_index]['score'] 300 | result_index += 1 301 | tagged_words.append([word,label, score]) 302 | 303 | assert len(tagged_words) == len(words) 304 | return tagged_words 305 | 306 | def prediction_to_text(self,prediction): 307 | result = "" 308 | for word, label, _ in prediction: 309 | result += word 310 | if label == "0": 311 | result += " " 312 | if label in ".,?-:": 313 | result += label+" " 314 | return result.strip() 315 | 316 | 317 | # Keyvault handling class 318 | class CredentialManager(): 319 | def __init__( 320 | self, 321 | keyvault_name:str, 322 | secret_tenant_sp:str=None, 323 | secret_client_sp:str=None, 324 | secret_sp:str=None, 325 | puk_secret_name:str=None, 326 | prk_secret_name:str=None, 327 | prk_password_secret_name:str=None, 328 | ) -> None: 329 | """Base class to handle PGP encryption system in Azure. 330 | 331 | Args: 332 | keyvault_name (str): KeyVault resource where secrets are stored. 333 | secret_tenant_sp (str): Service principal tenant_id secret, stored in KeyVault 'keyvault_name'. 334 | secret_client_sp (str): Service principal client_id secret, stored in KeyVault 'keyvault_name'. 335 | secret_sp (str): Service principal secret_client secret, stored in KeyVault 'keyvault_name'. 
336 | """ 337 | self.keyvault_name = keyvault_name 338 | self.secret_tenant_sp = secret_tenant_sp 339 | self.secret_client_sp = secret_client_sp 340 | self.secret_sp = secret_sp 341 | self.login = None 342 | # Import public key from PGP 343 | puk_secret_value = self.fetch_secret(self.default_login(), puk_secret_name) 344 | self.public_key, _ = pgpy.PGPKey.from_blob(puk_secret_value) 345 | # Retrieve pk secrets 346 | sc = self.sp_login() 347 | self.enable_secret(sc, prk_secret_name, True) 348 | self.enable_secret(sc, prk_password_secret_name, True) 349 | pk_secret_value = self.fetch_secret(sc, prk_secret_name) 350 | self.pk_pass_secret_value = self.fetch_secret(sc, prk_password_secret_name) 351 | # Fetch pk key 352 | self.private_key, _ = pgpy.PGPKey.from_blob(pk_secret_value) 353 | 354 | 355 | def encrypt( 356 | self, 357 | input_path:str, 358 | output_path:str, 359 | filenames:List[str], 360 | remove_input:bool=False, 361 | secret_client:SecretClient=None 362 | ) -> None: 363 | # Check input is a list 364 | if isinstance(filenames, str): 365 | filenames = [filenames] 366 | # Default login 367 | if ((self.login!='default') | (secret_client is None)): 368 | secret_client = self.default_login() 369 | # Loop 370 | for filename in filenames: 371 | input_filepath = os.path.join(input_path, filename) 372 | output_filepath = os.path.join(output_path, filename) 373 | folder_path, fn, ext = self.get_file_attr(input_filepath) 374 | if ext=='.pgp': 375 | log.warning(f"File {fn} is already encrypted. 
Skipping...") 376 | continue 377 | with open(input_filepath, 'rb') as f: 378 | message = pgpy.PGPMessage.new(f.read()) 379 | encrypted_message = self.public_key.encrypt(message) 380 | encrypted_message = str(encrypted_message) 381 | with open(output_filepath+'.pgp', "w") as f: 382 | f.write(encrypted_message) 383 | log.info(f"File {fn+ext+'.pgp'} has been generated in {folder_path}.") 384 | if remove_input: os.remove(input_filepath) 385 | log.info(f"File {fn+ext} has been removed.") 386 | 387 | 388 | def decrypt( 389 | self, 390 | input_path: str, 391 | output_path: str, 392 | filenames:List[str], 393 | remove_input:bool=False, 394 | secret_client:SecretClient=None 395 | ) -> None: 396 | # Check input is a list 397 | if isinstance(filenames, str): 398 | filenames = [filenames] 399 | # Service principal login 400 | if ((self.login!='sp') | (secret_client is None)): 401 | secret_client = self.sp_login() 402 | # Loop 403 | for filename in filenames: 404 | input_filepath = os.path.join(input_path, filename) 405 | folder_path, fn, ext = self.get_file_attr(input_filepath) 406 | if ext not in ['.pgp', '.enc']: 407 | log.warning(f"File {fn} is already decrypted. Skipping...") 408 | continue 409 | with self.private_key.unlock(self.pk_pass_secret_value) as ukey: 410 | if ukey: 411 | encrypted_message = pgpy.PGPMessage.from_file(input_filepath) 412 | decrypted_message = ukey.decrypt(encrypted_message).message 413 | if isinstance(decrypted_message, str): 414 | with open(os.path.join(output_path, fn), "w") as f: 415 | f.write(decrypted_message) 416 | elif isinstance(decrypted_message, bytearray): 417 | with open(os.path.join(output_path, fn), "wb") as f: 418 | f.write(decrypted_message) 419 | else: 420 | log.error(f"File {fn} returned a decrypted message that it's not either str not bytearray. Please check.") 421 | raise ValueError(f"File {fn} returned a decrypted message that it's not either str not bytearray. 
Please check.") 422 | log.info(f"File {fn} has been generated in {folder_path}.") 423 | if remove_input: os.remove(input_filepath) 424 | log.info(f"File {fn+ext} has been removed.") 425 | else: 426 | log.error(f"Private key password is not correct.") 427 | raise ValueError(f"Private key password is not correct.") 428 | 429 | 430 | def default_login(self) -> SecretClient: 431 | credential = DefaultAzureCredential() 432 | credential.get_token("https://management.azure.com/.default") 433 | secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential) 434 | self.login = 'default' 435 | return secret_client 436 | 437 | 438 | def sp_login(self) -> SecretClient: 439 | # Make sure all parameters are in place 440 | if ((self.secret_tenant_sp is None) | (self.secret_client_sp is None) | (self.secret_sp is None)): 441 | log.error(f"Service principal credentials have not been set up properly.") 442 | raise ValueError(f"Service principal credentials have not been set up properly.") 443 | # Get secret client 444 | secret_client = self.default_login() 445 | tenant_id = secret_client.get_secret(name=self.secret_tenant_sp).value 446 | client_id = secret_client.get_secret(name=self.secret_client_sp).value 447 | client_secret = secret_client.get_secret(name=self.secret_sp).value 448 | credential = ClientSecretCredential(tenant_id, client_id, client_secret) 449 | secret_client = SecretClient(vault_url=f"https://{self.keyvault_name}.vault.azure.net/", credential=credential) 450 | self.login = 'sp' 451 | return secret_client 452 | 453 | 454 | def enable_secret( 455 | self, 456 | secret_client:SecretClient=None, 457 | secret_name:str=None, 458 | enable:bool=False 459 | ) -> None: 460 | # Get the right login for the operation 461 | if self.login!='sp': 462 | secret_client = self.sp_login() 463 | # Check secret current status 464 | try: 465 | secret_status = secret_client.get_secret(secret_name).properties.enabled 466 | except: 467 | 
secret_status = False 468 | # Compare with input action 469 | if secret_status==enable: 470 | s = 'enabled' if enable else 'disabled' 471 | log.info(f"Secret {secret_name} is already {s}.") 472 | else: 473 | s = 'enabled' if enable else 'disabled' 474 | secret_client.update_secret_properties(secret_name, enabled=enable) 475 | log.info(f"Secret {secret_name} is now {s}.") 476 | 477 | 478 | def fetch_secret( 479 | self, 480 | secret_client:SecretClient, 481 | secret_name:str 482 | ) -> str: 483 | secret_value = secret_client.get_secret(secret_name).value 484 | return secret_value 485 | 486 | 487 | @staticmethod 488 | def get_file_attr( 489 | filepath:str 490 | ) -> List[str]: 491 | """Helper function to consistently split a filepath into folder path, filename and extension. 492 | 493 | Args: 494 | filepath (str): Path where file is stored. 495 | 496 | Returns: 497 | List[str]: Folder path, file name and file extension. 498 | """ 499 | folder_path = '/'.join(filepath.split('/')[:-1]) 500 | fn, ext = os.path.splitext(filepath.split('/')[-1]) 501 | return folder_path, fn, ext 502 | 503 | 504 | # Helper function to wrap predict logic 505 | def inference_punct( 506 | model:PunctuationModel, 507 | asr_ip:Dict, 508 | diar_ip:Dict, 509 | ncs:int=None, 510 | ns:int=None, 511 | mws:int=None, 512 | ep:str=".?!", 513 | mp:str=".,;:!?" 514 | ): 515 | wsm = get_words_speaker_mapping(asr_ip, diar_ip) 516 | words_list = list(map(lambda x: x["word"], wsm)) 517 | try: 518 | labled_words = model.predict(words_list, ncs, ns) 519 | except: 520 | log.warning(f"Raised error using chunk_size={ncs}. Retrying with half chunk size.") 521 | try: 522 | labled_words = model.predict(words_list, ncs//2, ns) 523 | except: 524 | log.warning(f"Raised error using chunk_size={ncs//2}. 
Retrying with half chunk size.") 525 | labled_words = model.predict(words_list, ncs//4, ns) 526 | # Acronyms handling 527 | is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x) 528 | for word_dict, labeled_tuple in zip(wsm, labled_words): 529 | word = word_dict["word"] 530 | if ( 531 | word 532 | and labeled_tuple[1] in ep 533 | and (word[-1] not in mp or is_acronym(word)) 534 | ): 535 | word += labeled_tuple[1] 536 | if word.endswith(".."): 537 | word = word.rstrip(".") 538 | word_dict["word"] = word 539 | try: 540 | wsm_final = get_realigned_ws_mapping_with_punctuation(wsm, mws, ep) 541 | except: 542 | mws = max(mws-10,mws//2) 543 | log.warning(f"Raised error during punctuation alignment. Retrying with max_word_sentence={mws}.") 544 | try: 545 | wsm_final = get_realigned_ws_mapping_with_punctuation(wsm, mws, ep) 546 | except: 547 | mws = max(mws-10,mws//2) 548 | log.warning(f"Raised error during punctuation alignment. Retrying with max_word_sentence={mws}.") 549 | wsm_final = get_realigned_ws_mapping_with_punctuation(wsm, mws, ep) 550 | ssm = get_sentences_speaker_mapping(wsm_final, diar_ip) 551 | return ssm 552 | 553 | 554 | # Helper function to cleanup audios directory 555 | def delete_files_in_directory_and_subdirectories(directory_path): 556 | try: 557 | for root, dirs, files in os.walk(directory_path): 558 | for file in files: 559 | file_path = os.path.join(root, file) 560 | os.remove(file_path) 561 | log.info("All files and subdirectories deleted successfully.") 562 | except OSError: 563 | log.info("Error occurred while deleting files and subdirectories.") 564 | 565 | 566 | # 567 | # Scoring (entry) script: entry point for execution, scoring script should contain two functions: 568 | # * init(): this function should be used for any costly or common preparation for subsequent inferences, e.g., 569 | # deserializing and loading the model into a global object. 570 | # * run(mini_batch): The method to be parallelized. 
#       Each invocation will have one minibatch.
#       * mini_batch: Batch inference will invoke run method and pass either a list or Pandas DataFrame as an argument to the method.
#         Each entry in min_batch will be - a filepath if input is a FileDataset, a Pandas DataFrame if input is a TabularDataset.
#       * return value: run() method should return a Pandas DataFrame or an array.
#         For append_row output_action, these returned elements are appended into the common output file.
#         For summary_only, the contents of the elements are ignored.
#         For all output actions, each returned output element indicates one successful inference of input element in the input mini-batch.
#

# Sampling rate used to convert VAD chunk boundaries (expressed in samples) to seconds
_VAD_SAMPLING_RATE = 16000


def init():
    """Parse job arguments and build the shared, per-worker resources.

    Everything created here is published through module-level globals that
    ``run`` and ``shutdown`` read: the credential manager (``cm``), the
    punctuation model and its chunking parameters, the input/output paths,
    the CosmosDB (MongoDB API) collection handle and the run timestamp.

    Raises:
        ValueError: If the requested database or collection does not exist.
    """
    global keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp
    global pk_secret, pk_pass_secret, pubk_secret
    global cm
    global input_diar_path, max_words_in_sentence, output_sm_path
    global punct_model, ner_chunk_size, ner_stride
    global cosmosdb_client
    global ts

    # Managed output path to control where objects are returned
    parser = argparse.ArgumentParser(
        allow_abbrev=False, description="ParallelRunStep Agent"
    )
    parser.add_argument("--input_diar_path", type=str)
    parser.add_argument("--keyvault_name", type=str)
    parser.add_argument("--secret_tenant_sp", type=str)
    parser.add_argument("--secret_client_sp", type=str)
    parser.add_argument("--secret_sp", type=str)
    parser.add_argument("--pk_secret", type=str)
    parser.add_argument("--pk_pass_secret", type=str)
    parser.add_argument("--pubk_secret", type=str)
    parser.add_argument("--cosmosdb_name", type=str)
    parser.add_argument("--cosmosdb_collection", type=str)
    parser.add_argument("--cosmosdb_cs_secret", type=str)
    parser.add_argument("--ner_chunk_size", type=int, default=80)
    parser.add_argument("--ner_stride", type=int, default=5)
    parser.add_argument("--max_words_in_sentence", type=int, default=40)
    parser.add_argument("--output_sm_path", type=str)
    # Unknown arguments injected by the runner are ignored on purpose
    args, _ = parser.parse_known_args()

    # Encrypt params
    keyvault_name = args.keyvault_name
    secret_tenant_sp = args.secret_tenant_sp
    secret_client_sp = args.secret_client_sp
    secret_sp = args.secret_sp
    pk_secret = args.pk_secret
    pk_pass_secret = args.pk_pass_secret
    pubk_secret = args.pubk_secret

    # Instantiate credential manager
    cm = CredentialManager(keyvault_name, secret_tenant_sp, secret_client_sp, secret_sp, pubk_secret, pk_secret, pk_pass_secret)

    # Params
    input_diar_path = args.input_diar_path
    max_words_in_sentence = args.max_words_in_sentence
    output_sm_path = args.output_sm_path

    # Folder structure
    Path('./decrypted_files').mkdir(parents=True, exist_ok=True)
    Path(output_sm_path).mkdir(parents=True, exist_ok=True)

    # Punctuation model
    punct_model = PunctuationModel()
    ner_chunk_size = args.ner_chunk_size
    ner_stride = args.ner_stride

    # MongoDB client: the connection string is stored as a KeyVault secret
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
    sc = SecretClient(vault_url=f"https://{keyvault_name}.vault.azure.net/", credential=credential)
    connection_string = sc.get_secret(name=args.cosmosdb_cs_secret).value
    mongodb_client = pymongo.MongoClient(connection_string)

    # DB connection
    if args.cosmosdb_name not in mongodb_client.list_database_names():
        log.error(f"Database {args.cosmosdb_name} not found.")
        raise ValueError(f"Database {args.cosmosdb_name} not found.")
    cosmosdb_db = mongodb_client[args.cosmosdb_name]
    log.info(f"Database {args.cosmosdb_name} connected.")

    # Collection connection
    if args.cosmosdb_collection not in cosmosdb_db.list_collection_names():
        log.error(f"Collection {args.cosmosdb_collection} not found.")
        raise ValueError(f"Collection {args.cosmosdb_collection} not found.")
    cosmosdb_client = cosmosdb_db[args.cosmosdb_collection]
    log.info(f"Collection {args.cosmosdb_collection} connected.")

    # Timestamp shared by every record written by this worker
    ts = str(datetime.now())


# Helper function to find the folder holding the diarization metadata of a file
def _resolve_diar_dir(nfa_dir, fn):
    """Return the folder from which ``{fn}_diar.json.pgp`` should be decrypted.

    NOTE(review): ``--input_diar_path`` is parsed in ``init`` but was never
    consulted; it is presumably where the diarization output lives - confirm
    against the pipeline definition. The folder of the alignment file remains
    the fallback, so the previous behaviour is kept when nothing is found.
    """
    if input_diar_path and os.path.isfile(os.path.join(input_diar_path, f"{fn}_diar.json.pgp")):
        return input_diar_path
    return nfa_dir


# Helper function to write the (still unencrypted) output document
def _dump_output(fn, payload):
    """Serialise ``payload`` to ``./decrypted_files/{fn}.json``."""
    with open(os.path.join('./decrypted_files', f"{fn}.json"), 'w', encoding='utf8') as f:
        json.dump(payload, f, indent=4, ensure_ascii=False)


# Helper function to undo the time shift introduced by silence removal
def _restore_original_timestamps(ssm, vad_timestamps):
    """Map segment and word timestamps back onto the original audio timeline.

    Args:
        ssm: Sentence/speaker segments; each has ``'start_time'``,
            ``'end_time'`` and a ``'words'`` list with the same two keys.
            Updated in place.
        vad_timestamps: VAD chunks (``'start'``/``'end'`` in samples).
    """
    ts_map = SpeechTimestampsMap(vad_timestamps, _VAD_SAMPLING_RATE)

    def remap(item):
        # Resolve start and end against the same chunk (the one containing
        # the midpoint) so a word is never stretched across a silence gap.
        middle = (item['start_time'] + item['end_time']) / 2
        chunk_index = ts_map.get_chunk_index(middle)
        item['start_time'] = ts_map.get_original_time(item['start_time'], chunk_index)
        item['end_time'] = ts_map.get_original_time(item['end_time'], chunk_index)

    for segment in ssm:
        words = segment['words']
        if not words:
            # A segment may carry no words (e.g. the first speaker turn when
            # the first word is assigned to another speaker); remap its own
            # bounds instead of failing on words[0].
            remap(segment)
            continue
        for word in words:
            remap(word)
        segment['start_time'] = words[0]['start_time']
        segment['end_time'] = words[-1]['end_time']


# Helper function with the per-file logic of run()
def _process_file(input_path, fn):
    """Build, store and register the final transcription of one audio file.

    Args:
        input_path: Folder containing ``{fn}_nfa.json.pgp``.
        fn: Unique id of the audio file.
    """
    log.debug(f"Processing file {fn}:")
    cm.decrypt(input_path, './decrypted_files', f"{fn}_nfa.json.pgp")
    with open(f"./decrypted_files/{fn}_nfa.json", 'r', encoding='utf8') as f:
        asr_dct = json.load(f)

    # If file contains no segments, generate dummy metadata and stop here
    if not asr_dct['segments']:
        log.debug(f"Audio {fn} does not contain segments. Dumping dummy file and skipping:")
        _dump_output(fn, {'unique_id': fn, 'segments': []})
        cm.encrypt('./decrypted_files', output_sm_path, f"{fn}.json", True)
        # Generate record in cosmosDB
        cosmosdb_client.update_one(
            {"_id": fn}, {"$set": {"timestamp": ts}}, upsert=True
        )
        return
    asr_input = [w for s in asr_dct['segments'] for w in s['words']]

    # Diarization metadata
    cm.decrypt(_resolve_diar_dir(input_path, fn), './decrypted_files', f"{fn}_diar.json.pgp")
    with open(f"./decrypted_files/{fn}_diar.json", 'r', encoding='utf8') as f:
        diar_dct = json.load(f)
    diar_input = [[s['start'], s['end'], s['speaker']] for s in diar_dct['segments']]

    # Get labels for each piece of text from ASR.
    # The punctuation model is the first positional argument of
    # inference_punct; omitting it shifts every other argument by one.
    sm_time = time.time()
    ssm = inference_punct(
        punct_model,
        asr_input,
        diar_input,
        ner_chunk_size,
        ner_stride,
        max_words_in_sentence
    )
    sm_time = time.time() - sm_time
    log.info(f"\tSentence-mapping time: {sm_time}")

    # Adjust timestamps with VAD chunks
    log.info('\tMapping VAD timestamps with transcription')
    _restore_original_timestamps(ssm, asr_dct['vad_timestamps'])

    processing_time = {
        **asr_dct['metadata'],
        **diar_dct['metadata'],
        'sentence_mapping_time': sm_time
    }

    # Save output
    _dump_output(
        fn,
        {
            'unique_id': fn,
            'duration': asr_dct['duration'],
            'processing_time': processing_time,
            'segments': ssm
        }
    )

    # Generate record in cosmosDB
    cosmosdb_client.update_one(
        {"_id": fn},
        {
            "$set": {
                "timestamp": ts,
                'duration': asr_dct['duration'],
                'processing_time': processing_time
            }
        },
        upsert=True
    )

    # Encrypt output
    cm.encrypt('./decrypted_files', output_sm_path, f"{fn}.json", True)


def run(mini_batch):
    """Merge alignment and diarization metadata for every file of the batch.

    Only entries whose path ends in ``_nfa.json.pgp`` are processed; any
    other entry is logged and skipped.

    Args:
        mini_batch: Iterable of file paths handed over by the parallel runner.

    Returns:
        The received ``mini_batch``, unchanged.
    """
    for elem in mini_batch:
        # Read file
        nfa_path = str(Path(elem))
        if not re.search(r'(.*?).*_nfa\.json\.pgp$', nfa_path):
            log.info(f"File {nfa_path} does not contain metadata from diarization. Skipping...")
            continue
        input_path, _, basename = nfa_path.rpartition('/')
        fn, _ = os.path.splitext(basename)
        fn = re.findall('(.*?)_nfa', fn)[0]  # remove '_nfa' from filename to get unique_id

        try:
            _process_file(input_path, fn)
        finally:
            # Decrypted material must never outlive the file being processed,
            # including on the early-exit and error paths.
            log.info("Cleanup resources")
            delete_files_in_directory_and_subdirectories('./decrypted_files')

    return mini_batch


def shutdown():
    """Disable the private-key secrets once the worker finishes."""
    # One service-principal login is enough for both updates
    sc = cm.sp_login()
    cm.enable_secret(sc, pk_secret, False)
    cm.enable_secret(sc, pk_pass_secret, False)
# --------------------------------------------------------------------------------