diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@
+/config.local
+/tmp
+/cache
diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 0000000..3229a91 --- /dev/null +++ b/.dvc/config @@ -0,0 +1,4 @@
+[core]
+    remote = adr-ml-training-data
+['remote "adr-ml-training-data"']
+    url = s3://adr-ml-training-data/dvc
diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@
+# Add patterns of files dvc should ignore, which could improve
+# the performance. Learn more at
+# https://dvc.org/doc/user-guide/dvcignore
diff --git a/ml/_doc_ml.md b/ml/_doc_ml.md index 325c264..59e9676 100644 --- a/ml/_doc_ml.md +++ b/ml/_doc_ml.md @@ -1,10 +1,62 @@
-Speech Recognition Models
+# ML Pipeline Design — AdaptiveRemote (current implementation)
-=========================
+## Purpose and Scope
-## Folders
-/ml/scripts -> Python scripts for implementing speech recognition training and evaluation.
-/ml/data -> Contains datasets used for training and evaluating speech recognition models.
-/ml/notebooks -> Jupyter notebooks for experimenting with different speech recognition techniques.
-/ml/models -> Pre-trained speech recognition models and scripts for training new models.
+This document describes the ML pipeline in the `ml` folder. It covers the DVC stages, the scripts, their inputs and outputs, how to run the pipeline locally for development, and a short set of next steps. The pipeline is implemented for local Windows development and uses DVC to manage data and artifacts.
+
+## Repository layout (relevant paths)
+
+- `ml/dvc.yaml` — pipeline orchestration and stage definitions.
+- `ml/scripts/intent_prediction/` — scripts to generate intent phrase variations (`01_generate_phrases.py` and `01_input_phrases.csv`).
+- `ml/scripts/speech_to_text/` — speech sample generation, augmentation, featurization, training, and evaluation scripts (`01*`–`10*`).
+- `ml/data/` — DVC-tracked raw, intermediate, and output artifacts (manifests, spectrograms, models). This data is managed by DVC, not committed to the Git repo.
+
+## Implemented DVC stages and scripts
+
+The pipeline for both speech-to-text and intent prediction is defined with explicit stages in [`ml/dvc.yaml`](./dvc.yaml). See that file for stage names, inputs, and outputs.
+
+Each stage in `dvc.yaml` declares the exact command, dependencies, and outputs used by the pipeline. The scripts follow a consistent CLI convention (required `--input`/`--manifest`/`--output` arguments) and each stage saves outputs into its own folder in the `ml/data` tree.
+
+## Implementation details (summary)
+
+- Intent generation: `01_generate_phrases.py` reads `01_input_phrases.csv`, synthesizes surface-form variations (adding pleasantries, hesitations, spelling variants, and repeats), and writes `training_data.csv` used as input to both intent prediction and speech sample generation.
+- Speech sample generation and augmentation: scripts `01*`–`04*` in speech_to_text generate TTS or synthetic samples, randomize delays, add background and microphone noise, and write clean/noisy audio files into DVC-backed directories.
+- Manifests and vocab: `05_create_set_manifests.py` creates train/val/test CSV manifests referencing audio filepaths and expected transcripts; `06_create_vocab_list.py` builds `vocab_list.txt`.
+- Featurization: `07_compute_spectrograms.py` reads manifests and `vocab_list.txt`, computes fixed-size log-Mel spectrograms and token arrays (`*.npy`), and stores them in the spectrogram output directory.
+- Training: `08_train_model.py` loads the training manifest and spectrogram/token `.npy` files, constructs a Keras model (Conv2D → BiLSTM → Dense), trains with a CTC-style loss loop, and saves `speech_to_text_model.keras`.
+- Evaluation: `09_evaluate_model.py` loads the saved model, runs greedy CTC decoding on eval spectrograms, computes WER using `jiwer`, and writes an `evaluation_predictions.txt` report.
+
+## Dependencies and environment
+
+- See `ml/scripts/requirements.txt` for required Python packages.
+- The code is written to run on Windows for development (CPU TensorFlow). Model training and larger-scale runs can be moved to Linux GPU hosts with minimal changes.
+
+## How to run (developer quickstart)
+
+1. Ensure DVC is installed and configured for your environment.
+2. From the repository root:
+
+```powershell
+cd ml
+pip install -r scripts/requirements.txt
+dvc pull
+dvc repro
+```
+
+`dvc repro` will execute the defined stages in the correct order and populate `ml/data` with outputs. Inspect the `dvc.yaml` file for per-stage commands and I/O when you need to run or debug a particular step.
+
+## Observations and current limitations
+
+- `08_train_model.py` implements a simple custom training loop with CTC loss; model hyperparameters (epochs, batch_size, time_steps) are hardcoded and could be parameterized.
+- The dataset generation is entirely file-based; large intermediate artifacts (audio, spectrograms, models) are stored under `ml/data` and should be pushed to the DVC remote to share across machines.
+- There is a simple greedy CTC decoder in `09_evaluate_model.py`; for production accuracy reporting, a beam search decoder could be added.
+- Scripts assume certain manifest and file conventions (manifest columns include `filepath` and `speech_to_detect`). Changes to the manifest format will require updating the `07*`, `08*`, and `09*` scripts.
+
+## Next steps (practical, minimal changes)
+
+- Parameterize training and featurization hyperparameters via CLI args or a YAML config to avoid editing source for experiments.
+- Keep dependency versions pinned in `ml/scripts/requirements.txt` (or add a `requirements.lock`) for reproducibility.
+- Add a small-sample smoke dataset and a CI job that runs `dvc pull` + `dvc repro` on that sample to detect regressions.
+- Add a minimal `model_registry.json` (template) that records model metadata (train commit, metrics, DVC path) when a training run completes; a sketch of such a record appears in the appendix below.
+- Add a simple cleanup helper to remove local intermediate files not referenced by DVC.
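+
+## Appendix: illustrative sketches (non-normative)
+
+Greedy CTC decoding, as used for evaluation above, reduces to an argmax over time followed by merging repeated tokens and dropping blanks, with WER computed by `jiwer`. The sketch below illustrates the idea only; the function name, the blank-at-index-0 convention, and the word-level vocab offset are assumptions made for this document, not the actual API of `09_evaluate_model.py`:
+
+```python
+import numpy as np
+import jiwer
+
+def greedy_ctc_decode(logits: np.ndarray, vocab: list, blank: int = 0) -> str:
+    """Argmax per frame, then merge repeats and drop CTC blanks."""
+    best = logits.argmax(axis=-1)  # (time,) best class id per frame
+    words, prev = [], blank
+    for t in best:
+        if t != prev and t != blank:
+            words.append(vocab[t - 1])  # assumes id 0 is the blank; ids 1.. map onto vocab
+        prev = t
+    return " ".join(words)
+
+# wer = jiwer.wer(reference_transcript, greedy_ctc_decode(frame_logits, vocab_words))
+```
+
+A beam-search decoder (see limitations above) would replace the per-frame argmax with a search over prefix probabilities; the greedy version remains useful as a fast sanity check.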
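+
+Likewise, the `model_registry.json` suggested under next steps could start from a record as small as the one below. Every field name here is a suggestion rather than an existing format, and the snippet assumes it runs from the root of a Git checkout:
+
+```python
+import datetime
+import json
+import subprocess
+
+# Hypothetical registry record; fields mirror the metadata listed under next steps.
+record = {
+    "model": "speech_to_text_model.keras",
+    "dvc_path": "data/speech_to_text/08_train_model/speech_to_text_model.keras",
+    "train_commit": subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip(),
+    "metrics": {"wer": None},  # to be filled in from the 09_evaluate_model report
+    "created_utc": datetime.datetime.now(datetime.timezone.utc).isoformat(),
+}
+
+with open("model_registry.json", "w", encoding="utf-8") as f:
+    json.dump(record, f, indent=2)
+```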
diff --git a/ml/dvc.lock b/ml/dvc.lock new file mode 100644 index 0000000..9d1b093 --- /dev/null +++ b/ml/dvc.lock @@ -0,0 +1,413 @@ +schema: '2.0' +stages: + generate_intent_phrases: + cmd: python scripts/intent_prediction/01_generate_phrases.py + deps: + - path: scripts/intent_prediction/01_generate_phrases.py + hash: md5 + md5: bf93a13f89570424e339158dad0d6f84 + size: 16373 + - path: scripts/intent_prediction/01_input_phrases.csv + hash: md5 + md5: e51aa089ecca384f1de076f1cd1c043a + size: 882 + outs: + - path: data/intent_prediction/01_generate_phrases/training_data.csv + hash: md5 + md5: b2389396a57cb3e1743a90d440f1f3f0 + size: 1109671 + generate_speech_samples: + cmd: python scripts/speech_to_text/01_generate_speech_samples.py + --input-file + data/speech_to_text/01a_generate_speech_sample_variations/speech_sample_variations.csv + --output-dir data/speech_to_text/01_generate_speech_samples + deps: + - path: + data/speech_to_text/01a_generate_speech_sample_variations/speech_sample_variations.csv + hash: md5 + md5: 0b9519ac9376bc628bc9649d8023919c + size: 2253468 + - path: scripts/speech_to_text/01_generate_speech_samples.py + hash: md5 + md5: 336da58be6ced4d9cf00bfd9ed90129d + size: 2007 + outs: + - path: data/speech_to_text/01_generate_speech_samples + hash: md5 + md5: 6616c0712ae80d9294bac6e2b5b9113e.dir + size: 269995680 + nfiles: 19577 + add_delays: + cmd: python scripts/speech_to_text/02_add_delays.py --input-file + data/speech_to_text/02a_randomize_delay_variations/randomized_delay_variations.csv + --output-dir data/speech_to_text/02_add_delays + deps: + - path: data/speech_to_text/01_generate_speech_samples + hash: md5 + md5: 6616c0712ae80d9294bac6e2b5b9113e.dir + size: 269995680 + nfiles: 19577 + - path: + data/speech_to_text/02a_randomize_delay_variations/randomized_delay_variations.csv + hash: md5 + md5: 4cda75d46b6fc45fed5aee323689c1ae + size: 2880273 + - path: scripts/speech_to_text/02_add_delays.py + hash: md5 + md5: 23cb02386882de26d6ea1f824e6c8043 + size: 1804 + outs: + - path: data/speech_to_text/02_add_delays + hash: md5 + md5: 194046a1b6c711520073d25460c9dae0.dir + size: 2531359252 + nfiles: 19577 + add_noise: + cmd: python .\ml\scripts\speech_to_text\04_add_microphone_noise.py + deps: + - path: .\ml\data\speech_to_text\03_add_background_noise\ + hash: md5 + md5: cf74edb8d91353c347a2543d8207cc0d.dir + size: 2584542 + nfiles: 20 + - path: .\ml\scripts\speech_to_text\04_add_microphone_noise.py + hash: md5 + md5: f84815e2570664eeb5f91b46abf0bcf7 + size: 1924 + outs: + - path: .\ml\data\speech_to_text\04_add_microphone_noise\ + hash: md5 + md5: 15057a761839bfd352e6d70d6ba1a208.dir + size: 2584542 + nfiles: 20 + add_background_noise: + cmd: python scripts/speech_to_text/03_add_background_noise.py --input-dir + data/speech_to_text/02_add_delays --noise-dir + data/speech_to_text/03a_download_background_noise --output-dir + data/speech_to_text/03_add_background_noise + deps: + - path: data/speech_to_text/02_add_delays + hash: md5 + md5: 194046a1b6c711520073d25460c9dae0.dir + size: 2531359252 + nfiles: 19577 + - path: data/speech_to_text/03a_download_background_noise + hash: md5 + md5: 179e24549458b5a33140379d770da95d.dir + size: 117845790 + nfiles: 5 + - path: scripts/speech_to_text/03_add_background_noise.py + hash: md5 + md5: 0e5f45ac26aafe57218478a306a2c73a + size: 2822 + outs: + - path: data/speech_to_text/03_add_background_noise + hash: md5 + md5: da31d5faf1a8e622bddc7fd2dbf59444.dir + size: 2531359252 + nfiles: 19577 + add_microphone_noise: + cmd: python 
scripts/speech_to_text/04_add_microphone_noise.py --input-dir + data/speech_to_text/03_add_background_noise --output-dir + data/speech_to_text/04_add_microphone_noise + deps: + - path: data/speech_to_text/03_add_background_noise + hash: md5 + md5: da31d5faf1a8e622bddc7fd2dbf59444.dir + size: 2531359252 + nfiles: 19577 + - path: scripts/speech_to_text/04_add_microphone_noise.py + hash: md5 + md5: 35239ebc40b21fab29032f6f2832105c + size: 2055 + outs: + - path: data/speech_to_text/04_add_microphone_noise + hash: md5 + md5: c3142733b000aca3ce2a8207c9038de2.dir + size: 2531359252 + nfiles: 19577 + create_zipped_sets: + cmd: python .\ml\scripts\speech_to_text\05_create_zipped_sets.py + deps: + - path: .\ml\data\speech_to_text\01_generate_speech_samples\ + hash: md5 + md5: ac2b96405745c2c9b35f112b1ce15855.dir + size: 2680272 + nfiles: 200 + - path: .\ml\data\speech_to_text\04_add_microphone_noise\ + hash: md5 + md5: d18b8fe475e019b01232bc688302ecea.dir + size: 25520554 + nfiles: 200 + - path: .\ml\scripts\speech_to_text\05_create_zipped_sets.py + hash: md5 + md5: bfaa31ff93b60874c65281cf9d5ad602 + size: 1757 + outs: + - path: .\ml\data\speech_to_text\05_create_zipped_sets\ + hash: md5 + md5: df3f4e0539d52398b760c4bf1743f30a.dir + size: 61539 + nfiles: 3 + 05_create_set_manifests: + cmd: python .\ml\scripts\speech_to_text\05_create_set_manifests.py + deps: + - path: .\ml\data\speech_to_text\01_generate_speech_samples\ + hash: md5 + md5: c7112131b48176db775eebd5fc316a38.dir + size: 301536 + nfiles: 20 + - path: .\ml\data\speech_to_text\04_add_microphone_noise\ + hash: md5 + md5: 9e0a41a1cd2d094cb052a27005f7b414.dir + size: 2637688 + nfiles: 20 + - path: .\ml\scripts\speech_to_text\05_create_set_manifests.py + hash: md5 + md5: 658302695382e931722a3a0c94e438c8 + size: 2641 + outs: + - path: .\ml\data\speech_to_text\05_create_set_manifests\ + hash: md5 + md5: fbf8d994a38244302db466c731d82c04.dir + size: 5962 + nfiles: 3 + create_vocab_list: + cmd: python scripts/speech_to_text/06_create_vocab_list.py --input-file + data/intent_prediction/01_generate_phrases/training_data.csv --output-file + data/speech_to_text/06_create_vocab_list/vocab_list.txt + deps: + - path: data/intent_prediction/01_generate_phrases/training_data.csv + hash: md5 + md5: b2389396a57cb3e1743a90d440f1f3f0 + size: 1109671 + - path: scripts/speech_to_text/06_create_vocab_list.py + hash: md5 + md5: 1b75d62a497212e84e03f832201d9d9a + size: 1057 + outs: + - path: data/speech_to_text/06_create_vocab_list/vocab_list.txt + hash: md5 + md5: 58ed1a5697c5005cb315e69e39275886 + size: 381 + compute_spectrograms: + cmd: python scripts/speech_to_text/07_compute_spectrograms.py + --train-manifest + data/speech_to_text/05_create_set_manifests/train_manifest.csv + --eval-manifest + data/speech_to_text/05_create_set_manifests/val_manifest.csv + --test-manifest + data/speech_to_text/05_create_set_manifests/test_manifest.csv --vocab + data/speech_to_text/06_create_vocab_list/vocab_list.txt --output-dir + data/speech_to_text/07_compute_spectrograms + deps: + - path: data/speech_to_text/01_generate_speech_samples + hash: md5 + md5: 6616c0712ae80d9294bac6e2b5b9113e.dir + size: 269995680 + nfiles: 19577 + - path: data/speech_to_text/04_add_microphone_noise + hash: md5 + md5: c3142733b000aca3ce2a8207c9038de2.dir + size: 2531359252 + nfiles: 19577 + - path: data/speech_to_text/05_create_set_manifests + hash: md5 + md5: 03b34798299a09d888632a1a760ad05f.dir + size: 5719232 + nfiles: 3 + - path: data/speech_to_text/06_create_vocab_list/vocab_list.txt + hash: md5 
+ md5: 58ed1a5697c5005cb315e69e39275886 + size: 381 + - path: scripts/speech_to_text/07_compute_spectrograms.py + hash: md5 + md5: 7cde66959e2afe126bd6a73c92369c47 + size: 3424 + outs: + - path: data/speech_to_text/07_compute_spectrograms + hash: md5 + md5: 53d8122dd3dece2bf1b242768de67e9b.dir + size: 8307880416 + nfiles: 72012 + create_set_manifests: + cmd: python scripts/speech_to_text/05_create_set_manifests.py + --input-manifest + data/intent_prediction/01_generate_phrases/training_data.csv --clean-dir + data/speech_to_text/01_generate_speech_samples --noisy-dir + data/speech_to_text/04_add_microphone_noise --output-dir + data/speech_to_text/05_create_set_manifests + deps: + - path: data/speech_to_text/01_generate_speech_samples + hash: md5 + md5: 6616c0712ae80d9294bac6e2b5b9113e.dir + size: 269995680 + nfiles: 19577 + - path: data/speech_to_text/04_add_microphone_noise + hash: md5 + md5: c3142733b000aca3ce2a8207c9038de2.dir + size: 2531359252 + nfiles: 19577 + - path: scripts/speech_to_text/05_create_set_manifests.py + hash: md5 + md5: b681c834459e73f9eb2a9aa59e893c44 + size: 2842 + outs: + - path: data/speech_to_text/05_create_set_manifests + hash: md5 + md5: 03b34798299a09d888632a1a760ad05f.dir + size: 5719232 + nfiles: 3 + download_noise_samples: + cmd: python scripts/speech_to_text/03a_download_background_noise.py + --output-dir data/speech_to_text/03a_download_background_noise + deps: + - path: scripts/speech_to_text/03a_download_background_noise.py + hash: md5 + md5: ba94436e971fe5e0345054292d32e653 + size: 1755 + outs: + - path: data/speech_to_text/03a_download_background_noise + hash: md5 + md5: 179e24549458b5a33140379d770da95d.dir + size: 117845790 + nfiles: 5 + train_model: + cmd: python scripts/speech_to_text/08_train_model.py --manifest + data/speech_to_text/05_create_set_manifests/train_manifest.csv --vocab + data/speech_to_text/06_create_vocab_list/vocab_list.txt --spectrogram-dir + data/speech_to_text/07_compute_spectrograms --output-dir + data/speech_to_text/08_train_model + deps: + - path: data/speech_to_text/05_create_set_manifests/train_manifest.csv + hash: md5 + md5: 6dbbf1fccbc1f8edc8b91d71159111f2 + size: 4574393 + - path: data/speech_to_text/06_create_vocab_list/vocab_list.txt + hash: md5 + md5: 58ed1a5697c5005cb315e69e39275886 + size: 381 + - path: data/speech_to_text/07_compute_spectrograms + hash: md5 + md5: 53d8122dd3dece2bf1b242768de67e9b.dir + size: 8307880416 + nfiles: 72012 + - path: scripts/speech_to_text/08_train_model.py + hash: md5 + md5: 912021be45cbf6ca2dd84cc8ef013a9b + size: 4644 + outs: + - path: data/speech_to_text/08_train_model + hash: md5 + md5: a74f90bee40ac94cdcee7b35922e1ee3.dir + size: 72847379 + nfiles: 1 + generate_speech_sample_variations: + cmd: python scripts/speech_to_text/01a_generate_speech_sample_variations.py + --input-file data/intent_prediction/01_generate_phrases/training_data.csv + --samples-dir data/speech_to_text/01_generate_speech_samples --output-file + data/speech_to_text/01a_generate_speech_sample_variations/speech_sample_variations.csv + deps: + - path: data/intent_prediction/01_generate_phrases/training_data.csv + hash: md5 + md5: b2389396a57cb3e1743a90d440f1f3f0 + size: 1109671 + - path: scripts/speech_to_text/01a_generate_speech_sample_variations.py + hash: md5 + md5: 045cb3ba2fd061807c43887e38da47f8 + size: 4111 + outs: + - path: + data/speech_to_text/01a_generate_speech_sample_variations/speech_sample_variations.csv + hash: md5 + md5: 0b9519ac9376bc628bc9649d8023919c + size: 2253468 + randomize_delay_variations: 
+ cmd: python scripts/speech_to_text/02a_randomize_delay_variations.py + --input-dir data/speech_to_text/01_generate_speech_samples --output-file + data/speech_to_text/02a_randomize_delay_variations/randomized_delay_variations.csv + deps: + - path: data/speech_to_text/01_generate_speech_samples + hash: md5 + md5: 6616c0712ae80d9294bac6e2b5b9113e.dir + size: 269995680 + nfiles: 19577 + - path: scripts/speech_to_text/02a_randomize_delay_variations.py + hash: md5 + md5: 35310ecf60abc397344d985ef44d17fa + size: 1892 + outs: + - path: + data/speech_to_text/02a_randomize_delay_variations/randomized_delay_variations.csv + hash: md5 + md5: 4cda75d46b6fc45fed5aee323689c1ae + size: 2880273 + evaluate_model: + cmd: python scripts/speech_to_text/09_evaluate_model.py --manifest + data/speech_to_text/05_create_set_manifests/val_manifest.csv --vocab + data/speech_to_text/06_create_vocab_list/vocab_list.txt --model + data/speech_to_text/08_train_model/speech_to_text_model.keras + --spectrogram-dir data/speech_to_text/07_compute_spectrograms --output-dir + data/speech_to_text/09_evaluate_model + deps: + - path: data/speech_to_text/05_create_set_manifests/val_manifest.csv + hash: md5 + md5: 68a89c50bfe081693b7715cfd490ffd1 + size: 571668 + - path: data/speech_to_text/06_create_vocab_list/vocab_list.txt + hash: md5 + md5: 58ed1a5697c5005cb315e69e39275886 + size: 381 + - path: data/speech_to_text/07_compute_spectrograms + hash: md5 + md5: 53d8122dd3dece2bf1b242768de67e9b.dir + size: 8307880416 + nfiles: 72012 + - path: data/speech_to_text/08_train_model/speech_to_text_model.keras + hash: md5 + md5: b494dffe3206af7e321f75b87b08f3ea + size: 72847379 + - path: scripts/speech_to_text/09_evaluate_model.py + hash: md5 + md5: eaee356e238783ed5f194b5af7369ec6 + size: 4182 + outs: + - path: data/speech_to_text/09_evaluate_model + hash: md5 + md5: 1bfd161fe0a727a9468497e2a81a66cc.dir + size: 243510 + nfiles: 1 + evaluate_test_samples: + cmd: python scripts/speech_to_text/10_evaluate_test_samples.py --manifest + data/speech_to_text/05_create_set_manifests/test_manifest.csv --vocab + data/speech_to_text/06_create_vocab_list/vocab_list.txt --model + data/speech_to_text/08_train_model/speech_to_text_model.keras + --spectrogram-dir data/speech_to_text/07_compute_spectrograms --output-zip + data/speech_to_text/10_evaluate_test_samples/test_samples.zip + deps: + - path: data/speech_to_text/05_create_set_manifests/test_manifest.csv + hash: md5 + md5: 94f43bd4b3ce74d9f0b3f73731fa9838 + size: 573171 + - path: data/speech_to_text/06_create_vocab_list/vocab_list.txt + hash: md5 + md5: 58ed1a5697c5005cb315e69e39275886 + size: 381 + - path: data/speech_to_text/07_compute_spectrograms + hash: md5 + md5: 53d8122dd3dece2bf1b242768de67e9b.dir + size: 8307880416 + nfiles: 72012 + - path: data/speech_to_text/08_train_model/speech_to_text_model.keras + hash: md5 + md5: b494dffe3206af7e321f75b87b08f3ea + size: 72847379 + - path: scripts/speech_to_text/10_evaluate_test_samples.py + hash: md5 + md5: 811a617777b73e934f3c0caff48aea8d + size: 3873 + outs: + - path: data/speech_to_text/10_evaluate_test_samples/test_samples.zip + hash: md5 + md5: 6a24997521d72c2fcd29156b5913db28 + size: 30496252 diff --git a/ml/dvc.yaml b/ml/dvc.yaml new file mode 100644 index 0000000..103cf84 --- /dev/null +++ b/ml/dvc.yaml @@ -0,0 +1,206 @@ +vars: + - intent01: + name: "01_generate_phrases" + script: "scripts/intent_prediction/01_generate_phrases.py" + phrasefile: "scripts/intent_prediction/01_input_phrases.csv" + outputfile: 
"data/intent_prediction/01_generate_phrases/training_data.csv" + + - tts01a: + name: "01a_generate_speech_sample_variations" + script: "scripts/speech_to_text/01a_generate_speech_sample_variations.py" + speech_sample_variations: "data/speech_to_text/01a_generate_speech_sample_variations/speech_sample_variations.csv" + - tts01: + name: "01_generate_speech_samples" + script: "scripts/speech_to_text/01_generate_speech_samples.py" + clean_speech_samples: "data/speech_to_text/01_generate_speech_samples" + - tts02a: + name: "02a_randomize_delay_variations" + script: "scripts/speech_to_text/02a_randomize_delay_variations.py" + delay_variations: "data/speech_to_text/02a_randomize_delay_variations/randomized_delay_variations.csv" + - tts02: + name: "02_add_delays" + script: "scripts/speech_to_text/02_add_delays.py" + speech_samples_with_delays: "data/speech_to_text/02_add_delays" + - tts03: + name: "03_add_background_noise" + script: "scripts/speech_to_text/03_add_background_noise.py" + speech_samples_with_bg_noise: "data/speech_to_text/03_add_background_noise" + - tts03a: + name: "03a_download_background_noise" + script: "scripts/speech_to_text/03a_download_background_noise.py" + background_noise_samples: "data/speech_to_text/03a_download_background_noise" + - tts04: + name: "04_add_microphone_noise" + script: "scripts/speech_to_text/04_add_microphone_noise.py" + noisy_speech_samples: "data/speech_to_text/04_add_microphone_noise" + - tts05: + name: "05_create_set_manifests" + script: "scripts/speech_to_text/05_create_set_manifests.py" + speech_sample_manifests: "data/speech_to_text/05_create_set_manifests" + - tts06: + name: "06_create_vocab_list" + script: "scripts/speech_to_text/06_create_vocab_list.py" + vocab_list: "data/speech_to_text/06_create_vocab_list/vocab_list.txt" + - tts07: + name: "07_compute_spectrograms" + script: "scripts/speech_to_text/07_compute_spectrograms.py" + speech_sample_spectrograms: "data/speech_to_text/07_compute_spectrograms" + - tts08: + name: "08_train_model" + script: "scripts/speech_to_text/08_train_model.py" + speech_to_text_model: "data/speech_to_text/08_train_model" + - tts09: + name: "09_evaluate_model" + script: "scripts/speech_to_text/09_evaluate_model.py" + evaluation_results: "data/speech_to_text/09_evaluate_model" + - tts10: + name: "10_evaluate_test_samples" + script: "scripts/speech_to_text/10_evaluate_test_samples.py" + test_samples_zip: "data/speech_to_text/10_evaluate_test_samples/test_samples.zip" + +stages: + # Intent Prediction Data Preparation and Model Training Stages + generate_intent_phrases: + cmd: python ${intent01.script} + deps: + - ${intent01.script} + - ${intent01.phrasefile} + outs: + - ${intent01.outputfile}: + persist: true + + # Speech-to-Text Data Preparation and Model Training Stages + generate_speech_sample_variations: + cmd: python ${tts01a.script} --input-file ${intent01.outputfile} + --samples-dir ${tts01.clean_speech_samples} + --output-file ${tts01a.speech_sample_variations} + deps: + - ${tts01a.script} + - ${intent01.outputfile} + outs: + - ${tts01a.speech_sample_variations}: + persist: true + generate_speech_samples: + cmd: python ${tts01.script} --input-file ${tts01a.speech_sample_variations} + --output-dir ${tts01.clean_speech_samples} + deps: + - ${tts01a.speech_sample_variations} + - ${tts01.script} + outs: + - ${tts01.clean_speech_samples}: + persist: true + randomize_delay_variations: + cmd: python ${tts02a.script} --input-dir ${tts01.clean_speech_samples} + --output-file ${tts02a.delay_variations} + deps: + - 
${tts02a.script} + - ${tts01.clean_speech_samples} + outs: + - ${tts02a.delay_variations}: + persist: true + add_delays: + cmd: python ${tts02.script} --input-file ${tts02a.delay_variations} + --output-dir ${tts02.speech_samples_with_delays} + deps: + - ${tts01.clean_speech_samples} + - ${tts02.script} + - ${tts02a.delay_variations} + outs: + - ${tts02.speech_samples_with_delays} + add_microphone_noise: + cmd: python ${tts04.script} --input-dir ${tts03.speech_samples_with_bg_noise} + --output-dir ${tts04.noisy_speech_samples} + deps: + - ${tts03.speech_samples_with_bg_noise} + - ${tts04.script} + outs: + - ${tts04.noisy_speech_samples} + add_background_noise: + cmd: python ${tts03.script} --input-dir ${tts02.speech_samples_with_delays} + --noise-dir ${tts03a.background_noise_samples} + --output-dir ${tts03.speech_samples_with_bg_noise} + deps: + - ${tts02.speech_samples_with_delays} + - ${tts03a.background_noise_samples} + - ${tts03.script} + outs: + - ${tts03.speech_samples_with_bg_noise} + create_set_manifests: + cmd: python ${tts05.script} --input-manifest ${intent01.outputfile} + --clean-dir ${tts01.clean_speech_samples} + --noisy-dir ${tts04.noisy_speech_samples} + --output-dir ${tts05.speech_sample_manifests} + deps: + - ${tts01.clean_speech_samples} + - ${tts04.noisy_speech_samples} + - ${tts05.script} + outs: + - ${tts05.speech_sample_manifests} + create_vocab_list: + cmd: python ${tts06.script} --input-file ${intent01.outputfile} + --output-file ${tts06.vocab_list} + deps: + - ${intent01.outputfile} + - ${tts06.script} + outs: + - ${tts06.vocab_list} + compute_spectrograms: + cmd: python ${tts07.script} --train-manifest ${tts05.speech_sample_manifests}/train_manifest.csv + --eval-manifest ${tts05.speech_sample_manifests}/val_manifest.csv + --test-manifest ${tts05.speech_sample_manifests}/test_manifest.csv + --vocab ${tts06.vocab_list} + --output-dir ${tts07.speech_sample_spectrograms} + deps: + - ${tts01.clean_speech_samples} + - ${tts04.noisy_speech_samples} + - ${tts05.speech_sample_manifests} + - ${tts06.vocab_list} + - ${tts07.script} + outs: + - ${tts07.speech_sample_spectrograms} + download_noise_samples: + cmd: python ${tts03a.script} --output-dir ${tts03a.background_noise_samples} + deps: + - ${tts03a.script} + outs: + - ${tts03a.background_noise_samples} + train_model: + cmd: python ${tts08.script} --manifest ${tts05.speech_sample_manifests}/train_manifest.csv + --vocab ${tts06.vocab_list} + --spectrogram-dir ${tts07.speech_sample_spectrograms} + --output-dir ${tts08.speech_to_text_model} + deps: + - ${tts05.speech_sample_manifests}/train_manifest.csv + - ${tts06.vocab_list} + - ${tts07.speech_sample_spectrograms} + - ${tts08.script} + outs: + - ${tts08.speech_to_text_model} + evaluate_model: + cmd: python ${tts09.script} --manifest ${tts05.speech_sample_manifests}/val_manifest.csv + --vocab ${tts06.vocab_list} + --model ${tts08.speech_to_text_model}/speech_to_text_model.keras + --spectrogram-dir ${tts07.speech_sample_spectrograms} + --output-dir ${tts09.evaluation_results} + deps: + - ${tts05.speech_sample_manifests}/val_manifest.csv + - ${tts06.vocab_list} + - ${tts07.speech_sample_spectrograms} + - ${tts08.speech_to_text_model}/speech_to_text_model.keras + - ${tts09.script} + outs: + - ${tts09.evaluation_results} + evaluate_test_samples: + cmd: python ${tts10.script} --manifest ${tts05.speech_sample_manifests}/test_manifest.csv + --vocab ${tts06.vocab_list} + --model ${tts08.speech_to_text_model}/speech_to_text_model.keras + --spectrogram-dir 
${tts07.speech_sample_spectrograms} + --output-zip ${tts10.test_samples_zip} + deps: + - ${tts05.speech_sample_manifests}/test_manifest.csv + - ${tts06.vocab_list} + - ${tts07.speech_sample_spectrograms} + - ${tts08.speech_to_text_model}/speech_to_text_model.keras + - ${tts10.script} + outs: + - ${tts10.test_samples_zip} \ No newline at end of file diff --git a/ml/scripts/intent_prediction/01_generate_phrases.py b/ml/scripts/intent_prediction/01_generate_phrases.py new file mode 100644 index 0000000..e8af56a --- /dev/null +++ b/ml/scripts/intent_prediction/01_generate_phrases.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +""" +Training Data Generator for LLM Intent Classification + +This script generates training data variations from the 01_input_phrases.csv file +for fine-tuning an LLM to handle remote control commands as a fallback. +""" + +import csv +import random +import re +from pathlib import Path +from typing import List, Dict, Tuple + + +# ===================== +# CONFIGURATION SETTINGS +# ===================== +SCRIPT_DIR = Path(__file__).parent +DATA_FOLDER = SCRIPT_DIR / "../../data/intent_prediction/01_generate_phrases/" +INPUT_FILE = SCRIPT_DIR / "01_input_phrases.csv" +OUTPUT_FILE = DATA_FOLDER / "training_data.csv" + +# Number of total samples to generate +TARGET_SAMPLES = 10000 + +# Probability settings for variations +REPEAT_MODIFIER_CHANCE = 0.25 +PLEASANTRY_CHANCE = 0.3 +HESITATION_CHANCE = 0.3 +SPELLING_VARIANT_CHANCE = 0.3 +CASE_VARIANT_CHANCE = 0.3 + +# ===================== + +# Variation components +REPEAT_MODIFIERS = { + 1: ["", "once", "one", "one time", "one more time", "another one", "another time", "again"], + 2: ["twice", "two", "two times", "two more times", "another two", "another two times"], + 3: ["three", "three times", "three more times", "another three", "another three times"], + 4: ["four", "four times", "four more times", "another four", "another four times"], + 5: ["five", "five times", "five more times", "another five", "another five times"], + 6: ["six", "six times", "six more times", "another six", "another six times"], + 7: ["seven", "seven times", "seven more times", "another seven", "another seven times"], + 8: ["eight", "eight times", "eight more times", "another eight", "another eight times"], + 9: ["nine", "nine times", "nine more times", "another nine", "another nine times"], +} + +PLEASANTRIES = [ + "", + "please ", + ", please ", + "please, ", + ", please, ", + "could you ", + "can you ", + "would you ", + ", thank you ", + ", thanks " +] + +HESITATIONS = [ + "", + ", um, ", + ", uh, ", + ", umm, ", + ", err, ", + ", hmm, ", + " ... 
", +] + +# Homophone/spelling variations +SPELLING_VARIANTS = { + "one": ["one", "won", "1"], + "to": ["to", "too", "two", "2"], + "two": ["two", "to", "too", "2"], + "three": ["three", "3"], + "for": ["for", "four", "4"], + "four": ["four", "for", "4"], + "five": ["five", "5"], + "six": ["six", "6"], + "seven": ["seven", "7"], + "eight": ["eight", "ate", "8"], + "nine": ["nine", "9"], + "right": ["right", "rite", "write", "wright"], + "OK": ["OK", "okay", "ok"], + "pause": ["pause", "paws"] +} + +class VariationGenerator: + """Generates variations of command phrases.""" + + def __init__(self, target_samples: int): + self.target_samples = target_samples + (self.generated, self.existing_variations) = self.load_existing_variations(OUTPUT_FILE) + + def load_existing_variations(self, filepath: Path) -> Tuple[set, List[Dict[str, str]]]: + """Load existing variations from a CSV file to avoid duplicates.""" + variations = [] + generated_keys = set() + if not filepath.exists(): + return generated_keys, variations + with open(filepath, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + variations.append(row) + key = self._create_key(row) + generated_keys.add(key) + return (generated_keys, variations) + + def generate_variations(self, commands: List[Dict[str, str]]) -> List[Dict[str, str]]: + """Generate variations for all commands.""" + variations = self.existing_variations + samples_per_phrase = max(1, self.target_samples // len(commands)) + + for cmd in commands: + existing_for_phrase = len([v for v in variations if v['base_phrase'] == cmd['phrase']]) + cmd_variations = self._generate_for_command( + cmd['phrase'], + cmd['canonical'], + samples_per_phrase - existing_for_phrase + ) + variations.extend(cmd_variations) + + # If we haven't reached target, add more variations + while len(variations) < self.target_samples and len(commands) > 0: + cmd = random.choice(commands) + extra = self._generate_for_command( + cmd['phrase'], + cmd['canonical'], + 1 + ) + variations.extend(extra) + + return variations[:self.target_samples] + + def _generate_for_command(self, phrase: str, canonical: str, count: int) -> List[Dict[str, str]]: + """Generate variations for a single command.""" + variations = [] + attempts = 0 + max_attempts = count * 10 # Prevent infinite loops + + while len(variations) < count and attempts < max_attempts: + attempts += 1 + + # Generate a variation + variation = self._create_variation(phrase) + + # Check for duplicates + variant_key = self._create_key(variation) + if variant_key not in self.generated: + self.generated.add(variant_key) + variation['canonical_label'] = canonical + variations.append(variation) + + return variations + + def _create_key(self, row: Dict[str, str]) -> str: + """Create a unique key for a variation based on base phrase and canonical label.""" + return row['base_phrase'].lower() + "|" + row['transformations'] + + def _create_variation(self, phrase: str) -> Dict[str, str]: + """Create a single variation of a phrase.""" + transformations = [] + result = phrase + repeat_count_used = 1 + + # Add repeat modifier (for data variety only) + repeat_modifier = "" + if random.random() < REPEAT_MODIFIER_CHANCE: + repeat_count_used = random.choice(list(REPEAT_MODIFIERS.keys())) + modifiers = REPEAT_MODIFIERS.get(repeat_count_used, []) + if modifiers: + repeat_modifier = random.choice(modifiers) + else: + repeat_modifier = "" + repeat_count_used = 1 + if repeat_modifier: + result = f"{result} {repeat_modifier}" + 
transformations.append(f"repeat_modifier:{repeat_count_used}") + + # Add pleasantry + if random.random() < PLEASANTRY_CHANCE: + pleasantry = random.choice(PLEASANTRIES) + if pleasantry: + if pleasantry.startswith("could you") or pleasantry.startswith("can you") or pleasantry.startswith("would you"): + result = f"{pleasantry} {result}" + transformations.append(f"prefix_pleasantry:{pleasantry}") + elif pleasantry.startswith(","): + result = f"{result}{pleasantry}" + transformations.append(f"suffix_pleasantry:{pleasantry}") + else: + if random.random() > 0.5: + result = f"{result}, {pleasantry}" + else: + result = f"{pleasantry}, {result}" + transformations.append(f"pleasantry:{pleasantry}") + + # Save the speech to detect at this point. This is what the STT system should recognize. + speech_to_detect = self._normalize_commas_and_whitespace(result) + + # Add hesitation + if random.random() < HESITATION_CHANCE: + hesitation = random.choice([h for h in HESITATIONS if h]) + if hesitation: + # Insert hesitation at random position + words = result.split() + if len(words) > 1: + pos = random.randint(0, len(words)) + words.insert(pos, hesitation) + result = " ".join(words) + transformations.append(f"hesitation:{hesitation}") + + # Apply spelling variations + if random.random() < SPELLING_VARIANT_CHANCE: + for word, variants in SPELLING_VARIANTS.items(): + lower_result = result.lower() + if word in lower_result and len(variants) > 1: + variant = random.choice([v for v in variants if v != word]) + result_lower = lower_result.replace(word, variant, 1) + # Match original casing roughly + if result.isupper(): + result = result_lower.upper() + elif result.istitle(): + result = result_lower.title() + else: + result = result_lower + transformations.append(f"spelling_variant:{word}->{variant}") + + # Random case variations + if random.random() < CASE_VARIANT_CHANCE: + case_transform = random.choice(["lower", "upper", "title", "original"]) + if case_transform == "lower": + result = result.lower() + transformations.append("lowercase") + elif case_transform == "upper": + result = result.upper() + transformations.append("uppercase") + elif case_transform == "title": + result = result.title() + transformations.append("titlecase") + + result = self._normalize_commas_and_whitespace(result) + + return { + 'base_phrase': phrase, + 'surface_form': result, + 'speech_to_detect': speech_to_detect, + 'transformations': "|".join(transformations) if transformations else "none", + 'repeat_count': repeat_count_used + } + + def _normalize_commas_and_whitespace(self, text: str) -> str: + """Normalize commas and whitespace in the given text.""" + # Collapse consecutive commas into a single comma + text = re.sub(r'(?:\s*,\s*){2,}', ',', text) + # Remove leading/trailing commas and normalize whitespace + text = re.sub(r'^\s*,\s*|\s*,\s*$', '', text).strip() + text = re.sub(r'\s+', ' ', text).strip() + text = re.sub(r'\s,', ',', text).strip() + return text + + def _tokenize(self, text: str) -> List[str]: + """Tokenize text into alphanumeric lowercase tokens.""" + return re.findall(r"[a-z0-9']+", text.lower()) + + def sanity_check(self, variations: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], List[str]]: + """Perform sanity checks on generated variations.""" + valid = [] + issues = [] + + for var in variations: + surface = var['surface_form'] + canonical = var['canonical_label'] + + # Check 1: Not empty + if not surface or not surface.strip(): + issues.append(f"Empty surface form for {canonical}") + continue + + # Check 2: 
Reasonable length (5-150 characters) + if len(surface) < 2 or len(surface) > 150: + issues.append(f"Unusual length ({len(surface)}) for: {surface}") + continue + + # Check 3: Contains at least one letter + if not any(c.isalpha() for c in surface): + issues.append(f"No letters in: {surface}") + continue + + # Check 4: Base phrase recognizable via token overlap (with punctuation stripped) + base_tokens_list = self._tokenize(var['base_phrase']) + surface_token_set = set(self._tokenize(surface)) + + # Fallback to canonical label tokens if base tokens are empty + if not base_tokens_list: + base_tokens_list = self._tokenize(canonical) + + augmented_base_tokens = set(base_tokens_list) + + transformations_str = var['transformations'] + if transformations_str and transformations_str != "none": + for t in transformations_str.split('|'): + if t.startswith("spelling_variant:"): + try: + mapping = t.split(":", 1)[1] + if "->" in mapping: + base_word, variant_word = mapping.split("->", 1) + augmented_base_tokens.update(self._tokenize(base_word)) + augmented_base_tokens.update(self._tokenize(variant_word)) + except ValueError: + # Ignore malformed transformation metadata + pass + + if augmented_base_tokens and surface_token_set: + if augmented_base_tokens.isdisjoint(surface_token_set): + issues.append(f"Base phrase '{var['base_phrase']}' not recognizable in '{surface}'") + continue + + valid.append(var) + + return valid, issues + + +def main(): + """Main execution function.""" + print("Training Data Generator for LLM Intent Classification") + print("=" * 60) + + # Step 1: Extract commands from CSV + print("\n1. Extracting commands from " + INPUT_FILE.as_posix() + "...") + csv_path = INPUT_FILE + commands = [] + with open(csv_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + commands.append({ + 'phrase': row['phrase'], + 'canonical': row['command'] + }) + print(f" Found {len(commands)} command phrases") + + # Step 2: Generate variations + print(f"\n2. Generating {TARGET_SAMPLES} variations...") + generator = VariationGenerator(TARGET_SAMPLES) + variations = generator.generate_variations(commands) + print(f" Generated {len(variations)} initial variations") + + # Step 3: Sanity check + print("\n3. Running sanity checks...") + valid_variations, issues = generator.sanity_check(variations) + print(f" Valid: {len(valid_variations)}") + print(f" Issues: {len(issues)}") + if issues: + print("\n Sample issues:") + for issue in issues[:5]: + print(f" - {issue}") + + # Step 4: Write to CSV + print(f"\n4. Writing to {OUTPUT_FILE}...") + output_path = Path(OUTPUT_FILE) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', newline='', encoding='utf-8') as f: + fieldnames = [ + 'surface_form', + 'base_phrase', + 'speech_to_detect', + 'canonical_label', + 'repeat_count', + 'transformations' + ] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in valid_variations: + writer.writerow({ + 'surface_form': row['surface_form'], + 'base_phrase': row['base_phrase'], + 'speech_to_detect': row['speech_to_detect'], + 'canonical_label': row['canonical_label'], + 'repeat_count': row.get('repeat_count', ''), + 'transformations': row['transformations'] + }) + + print(f"\n✓ Successfully generated {len(valid_variations)} training samples") + print(f"✓ Output saved to: {OUTPUT_FILE}") + + # Step 5: Show sample outputs + print("\n5. 
Sample outputs:") + print("-" * 60) + samples = random.sample(valid_variations, min(10, len(valid_variations))) + for i, sample in enumerate(samples, 1): + print(f"\n{i}. Surface form: {sample['surface_form']}") + print(f" Canonical: {sample['canonical_label']}") + print(f" Base: {sample['base_phrase']}") + print(f" Repeat count: {sample.get('repeat_count', '')}") + print(f" Transforms: {sample['transformations']}") + + print("\n" + "=" * 60) + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/ml/scripts/intent_prediction/01_input_phrases.csv b/ml/scripts/intent_prediction/01_input_phrases.csv new file mode 100644 index 0000000..9394dfd --- /dev/null +++ b/ml/scripts/intent_prediction/01_input_phrases.csv @@ -0,0 +1,52 @@ +phrase,command +Back,Back +Go back,Back +Page down,ChannelDown +Channel down,ChannelDown +Page up,ChannelUp +Channel up,ChannelUp +Go down,Down +Down,Down +Quit,Exit +Exit,Exit +Guide,Guide +Go to Guide,Guide +Left,Left +Go left,Left +Mute,Mute +Netflix,Netflix +Go to Netflix,Netflix +Pause,Pause +Play,Play +Turn Off,PowerOff +Power Off,PowerOff +Off,PowerOff +Turn Off the TV,PowerOff +On,PowerOn +Power On,PowerOn +Turn On,PowerOn +Turn On the TV,PowerOn +Record,Record +Replay,Replay +Skip back,Replay +Right,Right +Go right,Right +Select,Select +OK,Select +Advance,Skip +Skip forward,Skip +Skip,Skip +Go to TiVo,TiVo +TiVo,TiVo +Up,Up +Go up,Up +Turn it down,VolumeDown +Quieter,VolumeDown +Softer,VolumeDown +Volume down,VolumeDown +Turn down the volume,VolumeDown +Louder,VolumeUp +Volume up,VolumeUp +Turn it up,VolumeUp +Crank it up,VolumeUp +Turn up the volume,VolumeUp diff --git a/ml/scripts/requirements.txt b/ml/scripts/requirements.txt new file mode 100644 index 0000000..2f94239 --- /dev/null +++ b/ml/scripts/requirements.txt @@ -0,0 +1,10 @@ +edge-tts==6.1.9 +pandas==2.2.0 +tqdm==4.66.1 +pydub==0.25.1 +soundfile==0.12.1 +librosa==0.10.1 +tensorflow==2.15.0 +onnx==1.17.0 +tf2onnx==1.16.1 +jiwer==3.0.3 \ No newline at end of file diff --git a/ml/scripts/speech_to_text/01_generate_speech_samples.py b/ml/scripts/speech_to_text/01_generate_speech_samples.py new file mode 100644 index 0000000..6251113 --- /dev/null +++ b/ml/scripts/speech_to_text/01_generate_speech_samples.py @@ -0,0 +1,55 @@ +import argparse +from pathlib import Path +import pandas as pd +import os +import asyncio +import edge_tts +from tqdm import tqdm + +# Settings +# Set subsampling rate (e.g., 1 = all, 2 = every other, 3 = every third) +subsample_rate = 1 # Change this value as needed + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Generate speech samples from variations CSV.") +parser.add_argument('--input-file', type=Path, required=True, help='Path to the input CSV file (variations)') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output speech samples') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Load the phrases from CSV +phrases_df = pd.read_csv(paths.input_file, encoding='utf-8') + +subsampled_indices = list(range(0, len(phrases_df), subsample_rate)) + +# Functions +async def generate_samples(): + count = 0 + for idx in tqdm(subsampled_indices, desc="Generating speech samples"): + output_path = paths.output_dir / phrases_df.iloc[idx]['sample_file_name'] + if output_path.exists(): + continue # Skip if samples already exist for this phrase index + phrase = phrases_df.iloc[idx]['phrase_to_speak'] + voice = phrases_df.iloc[idx]['voice'] + speech_rate_str = 
phrases_df.iloc[idx]['speech_rate']
+        try:
+            communicate = edge_tts.Communicate(
+                text=phrase,
+                voice=voice,
+                rate=speech_rate_str,
+            )
+            await communicate.save(str(output_path))
+            count += 1
+        except Exception as e:
+            print(f"Error generating sample for index {idx} with voice {voice}: {e}")
+            continue
+    return count
+
+async def main():
+    print("Starting sample generation...")
+    total_generated = await generate_samples()
+    print(f"Sample generation completed. Total samples generated: {total_generated}")
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py b/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py new file mode 100644 index 0000000..5b79f50 --- /dev/null +++ b/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py @@ -0,0 +1,98 @@
+import argparse
+from pathlib import Path
+import pandas as pd
+import random
+import os
+import asyncio
+import edge_tts
+from tqdm import tqdm
+
+# Settings
+# Set the speech rate range (in percentage)
+min_speech_rate = -50 # Minimum speech rate
+max_speech_rate = 80 # Maximum speech rate
+
+# Read file and directory paths from command line arguments
+parser = argparse.ArgumentParser(description="Generate speech sample variations.")
+parser.add_argument('--input-file', type=Path, required=True, help='Path to the input CSV file')
+parser.add_argument('--output-file', type=Path, required=True, help='Path to the output CSV file')
+parser.add_argument('--samples-dir', type=Path, required=True, help='Directory for speech samples')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_file.parent, exist_ok=True)
+
+phrases_df = pd.read_csv(paths.input_file, encoding='utf-8')
+phrases = phrases_df['surface_form'].tolist()
+labels = phrases_df['canonical_label'].tolist()
+speech_to_detect = phrases_df['speech_to_detect'].tolist()
+
+# Load the existing records if the file exists
+try:
+    existing_df = pd.read_csv(paths.output_file, encoding='utf-8')
+    existing_records = existing_df.to_dict(orient='records')
+    existing_phrases = existing_df['phrase_to_speak'].tolist()
+    print(f"Loaded {len(existing_records)} existing variation records from {paths.output_file}...")
+except FileNotFoundError:
+    print(f"Did not find existing {paths.output_file}.")
+    existing_records = []
+    existing_phrases = []
+
+# Functions
+async def get_voices():
+    voices = await edge_tts.list_voices()
+    # Filter for female voices and exclude problematic ones
+    female_voices = [
+        v for v in voices
+        if v['Gender'] == 'Female'
+        and v['Locale'] == 'en-US'
+        and ':' not in v['ShortName']
+        and 'DragonHD' not in v['ShortName']
+        and 'Turbo' not in v['ShortName']
+    ]
+    print(f"Sample female voices: {[v['ShortName'] for v in female_voices[:5]]}")
+    print(f"Total female voices found: {len(female_voices)}")
+    return female_voices
+
+async def generate_variation_records(voices: list):
+    records = []
+    if not voices:
+        print("No voices available.")
+        return records
+    for idx, (phrase, label, speech) in enumerate(tqdm(zip(phrases, labels, speech_to_detect), desc = "Generating variations", total=len(phrases))):
+        try:
+            existing_idx = existing_phrases.index(phrase)
+            # Pop from both lists so the phrase list stays index-aligned with the records list
+            records.append(existing_records.pop(existing_idx))
+            existing_phrases.pop(existing_idx)
+            continue
+        except ValueError:
+            pass # Phrase not found in existing records, proceed to create new variations
+        voice = random.choice(voices)['ShortName']
+        speech_rate = random.randint(min_speech_rate, max_speech_rate)
+        speech_rate_str = f"+{speech_rate}%" if speech_rate >= 0 else f"{speech_rate}%"
+        records.append({
+            'phrase_to_speak': phrase,
+            'phrase_to_detect': speech,
+            'voice': voice,
+            'speech_rate': speech_rate_str,
+            'sample_file_name': f"{label}_{idx}_{voice}_r{speech_rate + 100}.wav",
+        })
+    return records
+
+async def main():
+    print("Fetching available voices...")
+    voices = await get_voices()
+    print("Generating variation records...")
+    variation_records = await generate_variation_records(voices)
+    print(f"Saving variation records to {paths.output_file}...")
+    variations_df = pd.DataFrame(variation_records)
+    variations_df.to_csv(paths.output_file, index=False, encoding='utf-8')
+    print("Deleting existing speech samples.")
+    # Remove files from the samples dir that are not in the new variations_df
+    if paths.samples_dir.exists():
+        for file in paths.samples_dir.iterdir():
+            # Exact match on the file name; str.contains would treat the name as a regex pattern
+            if not (variations_df["sample_file_name"] == file.name).any():
+                print(f"Deleting obsolete sample file: {file.name}")
+                file.unlink()
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/ml/scripts/speech_to_text/02_add_delays.py b/ml/scripts/speech_to_text/02_add_delays.py new file mode 100644 index 0000000..0ca1bb9 --- /dev/null +++ b/ml/scripts/speech_to_text/02_add_delays.py @@ -0,0 +1,46 @@
+import argparse
+import os
+from pathlib import Path
+import pandas as pd
+import soundfile as sf
+import numpy as np
+from tqdm import tqdm
+
+# Settings
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Add random delays to audio samples.")
+parser.add_argument('--input-file', type=Path, required=True, help='Input wav file list (CSV) with delay variations')
+parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_dir, exist_ok=True)
+
+# Read the CSV file with delay variations
+df = pd.read_csv(paths.input_file)
+
+for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing audio files", unit="file"):
+    file_path = Path(row['input_file_path'])
+
+    data, samplerate = sf.read(file_path)
+    new_data = data
+    prefix_delay = 0.0
+    suffix_delay = 0.0
+
+    # Add prefix silence
+    if row['prefix_delay_seconds'] > 0.0:
+        prefix_delay = row['prefix_delay_seconds']
+        num_prefix_samples = int(prefix_delay * samplerate)
+        silence_prefix = np.zeros((num_prefix_samples, data.shape[1]) if data.ndim > 1 else num_prefix_samples, dtype=data.dtype)
+        new_data = np.concatenate([silence_prefix, new_data], axis=0)
+
+    # Add suffix silence
+    if row['suffix_delay_seconds'] > 0.0:
+        suffix_delay = row['suffix_delay_seconds']
+        num_suffix_samples = int(suffix_delay * samplerate)
+        silence_suffix = np.zeros((num_suffix_samples, data.shape[1]) if data.ndim > 1 else num_suffix_samples, dtype=data.dtype)
+        new_data = np.concatenate([new_data, silence_suffix], axis=0)
+
+    out_path = paths.output_dir / row['new_file_name']
+
+    sf.write(out_path, new_data, samplerate)
+
diff --git a/ml/scripts/speech_to_text/02a_randomize_delay_variations.py b/ml/scripts/speech_to_text/02a_randomize_delay_variations.py new file mode 100644 index 0000000..f1cff0d --- /dev/null +++ b/ml/scripts/speech_to_text/02a_randomize_delay_variations.py @@ -0,0 +1,50 @@
+import argparse
+from pathlib import Path
+import pandas as pd
+import random
+import os
+
+# Settings
+# Delay frequency
+prefix_delay_frequency = 5 # Add a delay every N samples
+suffix_delay_frequency = 5 # Add a delay 
every N samples + +# Delay duration +max_delay_duration = 1.5 # Maximum delay duration in seconds +min_delay_duration = 0.5 # Minimum delay duration in seconds + +# Read file and directory paths from command line arguments +parser = argparse.ArgumentParser(description="Generate speech sample variations.") +parser.add_argument('--input-dir', type=Path, required=True, help='Path to the input directory containing speech samples') +parser.add_argument('--output-file', type=Path, required=True, help='Path to the output CSV file containing random delay values') +paths = parser.parse_args() + +os.makedirs(paths.output_file.parent, exist_ok=True) + +records = [] +for file_path in paths.input_dir.glob('*.wav'): + stem = file_path.stem + + if random.randint(1, prefix_delay_frequency) == 1: + prefix_delay = random.uniform(min_delay_duration, max_delay_duration) + stem = f"{stem}_pre{int(prefix_delay * 1000):04d}" + else: + prefix_delay = 0.0 + + if random.randint(1, suffix_delay_frequency) == 1: + suffix_delay = random.uniform(min_delay_duration, max_delay_duration) + stem = f"{stem}_suf{int(suffix_delay * 1000):04d}" + else: + suffix_delay = 0.0 + + records.append({ + 'input_file_path': str(file_path), + 'prefix_delay_seconds': prefix_delay, + 'suffix_delay_seconds': suffix_delay, + 'new_file_name': f"{stem}{file_path.suffix}" + }) + +# Save to CSV +df = pd.DataFrame.from_records(records) +df.to_csv(paths.output_file, index=False, encoding='utf-8') +print(f"Delay variations saved to {paths.output_file}") \ No newline at end of file diff --git a/ml/scripts/speech_to_text/03_add_background_noise.py b/ml/scripts/speech_to_text/03_add_background_noise.py new file mode 100644 index 0000000..10f794e --- /dev/null +++ b/ml/scripts/speech_to_text/03_add_background_noise.py @@ -0,0 +1,77 @@ +import argparse +import os +from pathlib import Path +import random +import numpy as np +import soundfile as sf +from tqdm import tqdm + +# Settings +# Noise frequency +background_noise_frequency = 2 # Add noise every N samples + +# Noise volume +background_noise_volume_min = 0.01 # Minimum noise volume +background_noise_volume_max = 0.3 # Maximum noise volume + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Add background noise to audio samples.") +parser.add_argument('--input-dir', type=Path, required=True, help='Directory containing input wav files') +parser.add_argument('--noise-dir', type=Path, required=True, help='Directory containing noise wav files') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +def get_random_noise(noise_files, length, sr): + noise_file = random.choice(noise_files) + noise, noise_sr = sf.read(noise_file) + if len(noise.shape) > 1: + noise = noise[:,0] # Use first channel if stereo + if noise_sr != sr: + # Resample noise to match target sample rate + num_samples = int(len(noise) * sr / noise_sr) + indices = np.linspace(0, len(noise) - 1, num_samples).astype(int) + noise = noise[indices] + if len(noise) < length: + # Loop noise if too short + repeats = int(np.ceil(length / len(noise))) + noise = np.tile(noise, repeats) + max_start = max(0, len(noise) - length) + start = random.randint(0, max_start) + return noise[start:start+length] + +def add_noise_to_audio(audio, noise, volume): + return audio + noise * volume + +def main(): + noise_files = list(paths.noise_dir.glob("*.wav")) + if not noise_files: + print(f"No noise samples found 
in {paths.noise_dir}") + return + input_files = list(paths.input_dir.glob("*.wav")) + for input_file in tqdm(input_files, desc="Processing audio files", unit="file", total=len(input_files)): + stem = input_file.stem + audio, sr = sf.read(input_file) + if len(audio.shape) > 1: + audio = audio[:,0] # Use first channel if stereo + add_noise = (random.randint(1, background_noise_frequency) == 1) + if add_noise: + noise = get_random_noise(noise_files, len(audio), sr) + volume = random.uniform(background_noise_volume_min, background_noise_volume_max) + audio_noisy = add_noise_to_audio(audio, noise, volume) + # Clip to [-1,1] to avoid overflow + audio_noisy = np.clip(audio_noisy, -1.0, 1.0) + out_audio = audio_noisy + # Modify filename to include _bg{volume} without leading '0.' + volume_str = f"{int(volume * 1000):03d}" + out_filename = f"{stem}_bg{volume_str}.wav" + else: + out_audio = audio + out_filename = input_file.name + out_path = paths.output_dir / out_filename + sf.write(out_path, out_audio, sr) + +if __name__ == "__main__": + main() + diff --git a/ml/scripts/speech_to_text/03a_download_background_noise.py b/ml/scripts/speech_to_text/03a_download_background_noise.py new file mode 100644 index 0000000..90c4e81 --- /dev/null +++ b/ml/scripts/speech_to_text/03a_download_background_noise.py @@ -0,0 +1,35 @@ +import argparse +import os +from pathlib import Path +import requests + +# Settings +# Noise samples +noise_samples = { + "creative-background-short-ver.wav": "https://cdn.freesound.org/sounds/721/721949-a0b57121-2a03-4dac-97c0-ee15fc5db207?filename=721949__audiocoffee__creative-background-short-ver.wav", + "trailer.wav": "https://cdn.freesound.org/sounds/785/785516-53995c18-2299-49bc-b042-357c8cb919fd?filename=785516__litesaturation__trailer.wav", + "tv-chatter.wav": "https://cdn.freesound.org/sounds/765/765157-8a98bb7d-6d3d-4869-af6c-4ba18aaddf27?filename=765157__mieckevanhoek__tv-chatter.wav", + "tv-news-loop.wav": "https://cdn.freesound.org/sounds/468/468539-e433c8eb-7f21-467d-9910-a37f4738c868?filename=468539__sergequadrado__tv-news-loop.wav", + "tv-recording-of-a-handball-match-3.wav": "https://cdn.freesound.org/sounds/786/786263-6ef16c1d-183a-4143-beca-6b9528e9cdb5?filename=786263__king_anna__tv-recording-of-a-handball-match-3.wav", +} + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Download background noise samples.") +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for downloaded noise wav files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Download noise samples +for filename, url in noise_samples.items(): + output_path = paths.output_dir / filename + if output_path.exists(): + print(f"File {output_path} already exists. 
+# Download noise samples
+for filename, url in noise_samples.items():
+    output_path = paths.output_dir / filename
+    if output_path.exists():
+        print(f"File {output_path} already exists. Skipping download.")
+        continue
+    print(f"Downloading {filename} from {url}...")
+    response = requests.get(url, timeout=60)  # Fail rather than hang on a dead connection
+    response.raise_for_status()
+    with open(output_path, "wb") as f:
+        f.write(response.content)
+    print(f"Saved to {output_path}")
+
diff --git a/ml/scripts/speech_to_text/04_add_microphone_noise.py b/ml/scripts/speech_to_text/04_add_microphone_noise.py
new file mode 100644
index 0000000..6fe3e70
--- /dev/null
+++ b/ml/scripts/speech_to_text/04_add_microphone_noise.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+from pathlib import Path
+import random
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+
+# Settings
+# Noise frequency
+microphone_noise_frequency = 2  # Add noise to one in every N samples on average
+
+# Noise volume
+microphone_noise_volume_min = 0.01  # Minimum noise volume
+microphone_noise_volume_max = 0.05  # Maximum noise volume
+microphone_noise_type = 'white'  # Type of noise (future extension; only white noise is implemented)
+
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Add microphone noise to audio samples.")
+parser.add_argument('--input-dir', type=Path, required=True, help='Directory containing input wav files')
+parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_dir, exist_ok=True)
+
+# Process each audio file in the input directory
+input_files = list(paths.input_dir.glob('*.wav'))
+for file_path in tqdm(input_files, desc="Processing audio files", unit="file", total=len(input_files)):
+    stem = file_path.stem
+    # Decide randomly whether to add noise
+    add_noise = random.randint(1, microphone_noise_frequency) == 1
+    data, samplerate = sf.read(file_path)
+    noise_volume = 0.0
+    if add_noise:
+        # Random noise volume
+        noise_volume = random.uniform(microphone_noise_volume_min, microphone_noise_volume_max)
+        # Generate white noise
+        noise = np.random.normal(0, 1, data.shape) * noise_volume
+        data_noisy = data + noise
+        # Clip to valid range
+        data_noisy = np.clip(data_noisy, -1.0, 1.0)
+        # Encode the noise volume in the file name as _mic{volume * 1000}
+        noise_str = f"_mic{int(noise_volume * 1000):03d}"
+        out_name = stem + noise_str + file_path.suffix
+        out_path = paths.output_dir / out_name
+        sf.write(out_path, data_noisy, samplerate)
+    else:
+        # Save original file without noise
+        out_path = paths.output_dir / file_path.name
+        sf.write(out_path, data, samplerate)
+
diff --git a/ml/scripts/speech_to_text/05_create_set_manifests.py b/ml/scripts/speech_to_text/05_create_set_manifests.py
new file mode 100644
index 0000000..8b0a4f0
--- /dev/null
+++ b/ml/scripts/speech_to_text/05_create_set_manifests.py
@@ -0,0 +1,71 @@
+import argparse
+import os
+from pathlib import Path
+import random
+import csv
+import re
+
+# Settings
+training_set_percentage = 80  # Percentage of data for training set
+validation_set_percentage = 10  # Percentage of data for validation set
+test_set_percentage = 10  # Percentage of data for test set
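+# The three percentages are assumed to sum to 100; a cheap guard (sketch) would be:
+#   assert training_set_percentage + validation_set_percentage + test_set_percentage == 100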
+
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Create set manifests for training, validation, and test sets.")
+parser.add_argument('--input-manifest', type=Path, required=True, help='Path to input manifest CSV (training_data.csv)')
+parser.add_argument('--clean-dir', type=Path, required=True, help='Directory with clean speech samples')
+parser.add_argument('--noisy-dir', type=Path, required=True, help='Directory with noisy speech samples')
+parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output manifest files')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_dir, exist_ok=True)
+
+# Read surface forms (expected transcripts) from the input manifest
+surface_forms = []
+with open(paths.input_manifest, newline='', encoding='utf-8') as csvfile:
+    reader = csv.DictReader(csvfile)
+    for row in reader:
+        surface_forms.append(row['speech_to_detect'])
+
+# Collect all files from clean and noisy dirs
+all_files = []
+for input_dir in [paths.clean_dir, paths.noisy_dir]:
+    for root, _, files in os.walk(input_dir):
+        for file in files:
+            file_path = Path(root) / file
+            all_files.append(str(file_path.resolve()))
+
+# Shuffle the list (seeded for reproducibility)
+random.seed(42)
+random.shuffle(all_files)
+
+total_files = len(all_files)
+train_count = int(total_files * training_set_percentage / 100)
+val_count = int(total_files * validation_set_percentage / 100)
+test_count = total_files - train_count - val_count  # Remainder goes to the test set
+
+train_files = all_files[:train_count]
+val_files = all_files[train_count:train_count + val_count]
+test_files = all_files[train_count + val_count:]
+
+# Helper to extract the sample index embedded between underscores in a file name;
+# this index is assumed to match the row order of training_data.csv
+def extract_number_from_filename(filename):
+    match = re.search(r'_(\d+)_', filename)
+    if match:
+        return int(match.group(1))
+    return None
+
+def write_manifest(file_list, manifest_path):
+    with open(manifest_path, 'w', newline='', encoding='utf-8') as csvfile:
+        writer = csv.writer(csvfile)
+        writer.writerow(['filepath', 'speech_to_detect'])
+        for f in file_list:
+            num = extract_number_from_filename(os.path.basename(f))
+            surface_form = surface_forms[num] if num is not None and num < len(surface_forms) else ''
+            writer.writerow([f, surface_form])
+
+write_manifest(train_files, paths.output_dir / 'train_manifest.csv')
+write_manifest(val_files, paths.output_dir / 'val_manifest.csv')
+write_manifest(test_files, paths.output_dir / 'test_manifest.csv')
diff --git a/ml/scripts/speech_to_text/06_create_vocab_list.py b/ml/scripts/speech_to_text/06_create_vocab_list.py
new file mode 100644
index 0000000..1d5f1ef
--- /dev/null
+++ b/ml/scripts/speech_to_text/06_create_vocab_list.py
@@ -0,0 +1,30 @@
+import argparse
+import csv
+from pathlib import Path
+import os
+
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Create vocabulary list from training data CSV.")
+parser.add_argument('--input-file', type=Path, required=True, help='Path to input CSV file (training_data.csv)')
+parser.add_argument('--output-file', type=Path, required=True, help='Path for output vocab list')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_file.parent, exist_ok=True)
+vocab_file = paths.output_file
+
+# Load phrases from CSV and build the set of unique words
+words = set()
+with open(paths.input_file, newline='', encoding='utf-8') as csvfile:
+    reader = csv.DictReader(csvfile)
+    for row in reader:
+        phrase = row.get('speech_to_detect', '').lower()
+        for word in [w.strip() for w in phrase.replace(',', ' ').split()]:
+            if word:
+                words.add(word)
+
+vocab_list = sorted(words)
+
+# Write one word per line; token index = line number, blank/pad index = vocab size
+with open(vocab_file, 'w', encoding='utf-8') as f:
+    for word in vocab_list:
+        f.write(word + '\n')
+
diff --git a/ml/scripts/speech_to_text/07_compute_spectrograms.py b/ml/scripts/speech_to_text/07_compute_spectrograms.py
new file mode 100644
index 0000000..b7e1012
--- /dev/null
+++ b/ml/scripts/speech_to_text/07_compute_spectrograms.py
@@ -0,0 +1,78 @@
+import argparse
+from pathlib import Path
+import os
+import numpy as np
+import librosa
+import pandas as pd
+import soundfile as sf
+from tqdm import tqdm
+import re
+
+# Settings
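+# With librosa's default hop length of 512 samples, 360 frames cover roughly
+# 360 * 512 / sr seconds of audio (about 11.5 s assuming 16 kHz mono input);
+# longer clips are truncated by compute_melspectrogram() below.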
+time_steps = 360  # number of time steps in spectrogram output
+input_token_length = 20  # max length of input token sequences
+
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Compute log-mel spectrograms for audio files.")
+parser.add_argument('--train-manifest', type=Path, required=True, help='Path to train_manifest.csv')
+parser.add_argument('--eval-manifest', type=Path, required=True, help='Path to eval_manifest.csv')
+parser.add_argument('--test-manifest', type=Path, required=True, help='Path to test_manifest.csv')
+parser.add_argument('--vocab', type=Path, required=True, help='Path to vocab_list.txt')
+parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output spectrogram npy files')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_dir, exist_ok=True)
+
+# Read the vocabulary list from vocab file
+with open(paths.vocab, 'r', encoding='utf-8') as vocabfile:
+    vocab_list = [line.strip() for line in vocabfile if line.strip()]
+print(f"Loaded vocabulary list with {len(vocab_list)} entries.")
+pad_value = len(vocab_list)  # Padding token index
+
+def compute_melspectrogram(time_steps, wav_path):
+    y, sr = sf.read(str(wav_path))
+    # If stereo, convert to mono (average channels)
+    if y.ndim > 1:
+        y = np.mean(y, axis=1)
+    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
+    log_S = librosa.power_to_db(S, ref=np.max)
+    # Pad or truncate to a fixed number of time steps
+    if log_S.shape[1] < time_steps:
+        pad_width = time_steps - log_S.shape[1]
+        log_S = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
+    else:
+        print(f'Warning: Truncating spectrogram for {wav_path}, has {log_S.shape[1]} > {time_steps} time steps.')
+        log_S = log_S[:, :time_steps]
+    return log_S
+
+def compute_tokens(vocab_list, transcription):
+    transcription = transcription.lower().replace(',', ' ')
+    tokens = [vocab_list.index(word) for word in transcription.split() if word in vocab_list]
+    # Pad with pad_value to input_token_length, or truncate if longer
+    tokens = tokens + [pad_value] * (input_token_length - len(tokens)) if len(tokens) < input_token_length else tokens[:input_token_length]

[…diff truncated in the source: the remainder of 07_compute_spectrograms.py, the whole of 08_train_model.py, and the opening of 09_evaluate_model.py are missing; the text resumes inside 09's trim_at_blank() helper…]

+def trim_at_blank(seq, blank):
+    # Cut the reference sequence at the first padding/blank token (index >= vocab size)
+    idxs = np.where(seq >= len(vocab_list))[0]
+    if len(idxs) > 0:
+        return seq[:idxs[0]]
+    else:
+        return seq
+
+refs = [' '.join(indices_to_words(trim_at_blank(seq, ctc_blank_idx))) for seq in y_eval]
+hyps = [' '.join(indices_to_words(seq)) for seq in all_preds]
+wer_score = wer(refs, hyps)
+print(f'WER: {wer_score:.3f}')
+
+# Show a few example predictions
+for i in range(min(5, len(refs))):
+    print('REF:', refs[i])
+    print('HYP:', hyps[i])
+    print()
+
+# Save all predictions to a file in output dir
+output_predictions_file = paths.output_dir / "evaluation_predictions.txt"
+with open(output_predictions_file, 'w', encoding='utf-8') as f:
+    f.write(f'WER: {wer_score:.3f}\n\n')
+    for ref, hyp in zip(refs, hyps):
+        f.write(f'REF: {ref}\n')
+        f.write(f'HYP: {hyp}\n\n')
+print(f"Saved evaluation predictions to {output_predictions_file}")
diff --git a/ml/scripts/speech_to_text/10_evaluate_test_samples.py b/ml/scripts/speech_to_text/10_evaluate_test_samples.py
new file mode 100644
index 0000000..1531144
--- /dev/null
+++ b/ml/scripts/speech_to_text/10_evaluate_test_samples.py
@@ -0,0 +1,98 @@
+import argparse
+from pathlib import Path
+import os
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from zipfile import ZipFile
+
+print("Initializing TensorFlow...")
+import tensorflow as tf
+
+# Settings
+input_token_length = 20  # max length of input token sequences
+n_mels = 80  # number of mel frequency bins
+time_steps = 360  # number of time steps in spectrogram input
+batch_size = 32  # evaluation batch size
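+# These values must match 07_compute_spectrograms.py, since the .npy arrays
+# written there are loaded below without reshaping or re-padding.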
+
+# Parse command-line arguments
+parser = argparse.ArgumentParser(description="Evaluate test samples and create ZIP of successfully recognized files.")
+parser.add_argument('--manifest', type=Path, required=True, help='Path to test_manifest.csv')
+parser.add_argument('--model', type=Path, required=True, help='Path to model file (speech_to_text_model.keras)')
+parser.add_argument('--vocab', type=Path, required=True, help='Path to vocab_list.txt')
+parser.add_argument('--spectrogram-dir', type=Path, required=True, help='Directory with spectrogram npy files')
+parser.add_argument('--output-zip', type=Path, required=True, help='Path for output zip file')
+paths = parser.parse_args()
+
+os.makedirs(paths.output_zip.parent, exist_ok=True)
+
+# Read the sample file names from manifest
+eval_set = pd.read_csv(paths.manifest, encoding='utf-8')
+print(f"Loaded {len(eval_set)} evaluation samples from manifest.")
+
+# Prepare input/output pairs for evaluation
+x_eval = []
+y_eval = []
+for _, row in tqdm(eval_set.iterrows(), total=len(eval_set), desc="Loading evaluation data"):
+    wav_path = row['filepath']
+    # Get the corresponding spectrogram/tokens NPY file paths
+    wav_filename = Path(wav_path).stem
+    spectrogram_file = paths.spectrogram_dir / f"{wav_filename}.npy"
+    tokens_file = paths.spectrogram_dir / f"{wav_filename}_tokens.npy"
+    # Load the numpy arrays from the npy files
+    x_eval.append(np.load(spectrogram_file))
+    y_eval.append(np.load(tokens_file))
+
+# Load the trained model
+print("Loading speech-to-text model...")
+model = tf.keras.models.load_model(paths.model)
+print(f"Loaded {paths.model}")
+
+# Load the vocabulary list from vocab file
+with open(paths.vocab, 'r', encoding='utf-8') as vocabfile:
+    vocab_list = [line.strip() for line in vocabfile if line.strip()]
+    ctc_blank_idx = len(vocab_list)  # CTC blank token is conventionally at the last index
+    print(f"Vocabulary size: {len(vocab_list)}, Number of classes (with CTC blank): {len(vocab_list) + 1}")
+
+def ctc_greedy_decode(pred, blank=ctc_blank_idx):
+    # Collapse repeats and drop blanks (standard greedy CTC decoding)
+    pred_ids = np.argmax(pred, axis=-1)
+    decoded = []
+    for seq in pred_ids:
+        prev = blank
+        out = []
+        for idx in seq:
+            if idx != prev and idx != blank:
+                out.append(idx)
+            prev = idx
+        decoded.append(out)
+    return decoded
+
+eval_dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))\
+    .batch(batch_size).prefetch(tf.data.AUTOTUNE)
+
+# Evaluate on eval set
+all_preds = []
+for batch, _ in eval_dataset:
+    pred = model.predict(batch)
+    all_preds.extend(ctc_greedy_decode(pred))
+
+# Convert predicted token indices to text
+success_files = []
+def tokens_to_text(tokens):
+    # Join with spaces so word boundaries are preserved in the comparison
+    return ' '.join([vocab_list[idx] for idx in tokens if idx < len(vocab_list)])
+
+for i, (pred_tokens, true_tokens) in enumerate(zip(all_preds, y_eval)):
+    pred_text = tokens_to_text(pred_tokens)
+    true_text = tokens_to_text(true_tokens)
+    if pred_text == true_text:
+        wav_path = eval_set.iloc[i]['filepath']
+        success_files.append(wav_path)
+
+# Add successfully matched files to ZIP
+with ZipFile(paths.output_zip, 'w') as zipf:
+    for file_path in success_files:
+        zipf.write(file_path, arcname=Path(file_path).name)
+
+print(f"Successfully matched and added {len(success_files)} files to {paths.output_zip}")
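+
+# Note: arcname=Path(file_path).name flattens the directory structure inside the
+# ZIP, so a clean file and its augmented counterpart can share a basename when no
+# augmentation step renamed them, producing duplicate entries; the _pre/_suf/_bg/
+# _mic suffixes keep most names distinct.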