// Patch summary (preamble text placed before the first "diff --git" header):
//
//   * WORKSPACE: moves the pinned commit of the "nvriva_common" git_repository
//     from 3085c065... to 60e67e8b...; the remote URL is unchanged context.
//   * riva/clients/tts/riva_tts_client.cc and riva_tts_perf_client.cc:
//     - replace the gflags option `speed` (default 1.0, documented 0.5-2.0)
//       with `exaggeration_factor` (default 1.0, documented 0.0-2.0), including
//       the corresponding usage-message line;
//     - delete the speed range check and the request.set_speed(...) call;
//     - add a 0.0..2.0 range check followed by
//       zero_shot_data->set_exaggeration_factor(...);
//     - rename the trailing `double speed` parameter of synthesizeBatch() and
//       synthesizeOnline() and update both call sites in main().
//
// NOTE(review): in riva_tts_client.cc `zero_shot_data` is declared inside the
//   `if (not FLAGS_zero_shot_audio_prompt.empty())` branch, so the new range
//   check only runs when a zero-shot prompt is supplied; an out-of-range
//   --exaggeration_factor is otherwise never validated or sent. The perf-client
//   hunks look the same, but the declaring scope is not visible there - confirm.
// NOTE(review): the `speed` flag is removed outright with no alias; presumably
//   existing invocations passing --speed will now be rejected by gflags - verify.
// NOTE(review): on a failed range check synthesizeBatch() returns -1 and the
//   caller shown in the last hunk pushes that value into results_num_samples;
//   synthesizeOnline() returns early and the caller still pushes
//   time_to_first_chunk - confirm both are acceptable for the perf statistics.
// NOTE(review): this copy of the patch appears whitespace-collapsed and some
//   template arguments look stripped (e.g. `std::unique_ptr tts`); regenerate
//   from the repository before attempting to apply it.
diff --git a/WORKSPACE b/WORKSPACE index a91c61f..4ff14ab 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -77,7 +77,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/atomer-nvidia/common.git", - commit = "3085c065085d15a284b37847470fe0182c9a6c67" + commit = "60e67e8ba30eac99d8cfb30275b03b76b6562a29" ) http_archive( diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index 446fcc6..a9ef3e9 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -54,7 +54,7 @@ DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); -DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); +DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -119,7 +119,7 @@ main(int argc, char** argv) str_usage << " --custom_dictionary= " << std::endl; str_usage << " --timeout_ms= " << std::endl; str_usage << " --max_grpc_message_size= " << std::endl; - str_usage << " --speed= " << std::endl; + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -188,11 +188,6 @@ main(int argc, char** argv) request.set_sample_rate_hz(rate); request.set_voice_name(FLAGS_voice_name); - if (FLAGS_speed < 0.5 || FLAGS_speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return -1; - } - request.set_speed(FLAGS_speed); if (not FLAGS_zero_shot_audio_prompt.empty()) { auto zero_shot_data = request.mutable_zero_shot_data(); std::vector> audio_prompt; @@ -225,6 +220,11 @@ main(int argc, 
char** argv) if (not FLAGS_online and not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (FLAGS_exaggeration_factor < 0.0 || FLAGS_exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(FLAGS_exaggeration_factor); } // Send text content using Synthesize(). diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index 8125049..adf867e 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -63,7 +63,8 @@ DEFINE_string( DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40."); DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); -DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); +DEFINE_double( + exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -115,7 +116,7 @@ synthesizeBatch( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, std::string filepath, std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary, - std::string zero_shot_transcript, double speed) + std::string zero_shot_transcript, double exaggeration_factor) { // Parse command line arguments. 
nr_tts::SynthesizeSpeechRequest request; @@ -123,11 +124,6 @@ synthesizeBatch( request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - if (speed < 0.5 || speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return -1; - } - request.set_speed(speed); if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { request.set_encoding(nr::LINEAR_PCM); } else if (FLAGS_audio_encoding == "opus") { @@ -168,6 +164,11 @@ synthesizeBatch( if (not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } // Send text content using Synthesize(). @@ -211,18 +212,13 @@ synthesizeOnline( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, double* time_to_first_chunk, std::vector* time_to_next_chunk, size_t* num_samples, std::string filepath, - std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double speed) + std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double exaggeration_factor) { nr_tts::SynthesizeSpeechRequest request; request.set_text(text); request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - if (speed < 0.5 || speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return; - } - request.set_speed(speed); auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED; if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { ae = nr::LINEAR_PCM; @@ -260,6 +256,11 @@ synthesizeOnline( } zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate); zero_shot_data->set_quality(zero_shot_quality); + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + 
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } @@ -366,8 +367,7 @@ main(int argc, char** argv) str_usage << " --zero_shot_quality=" << std::endl; str_usage << " --zero_shot_transcript=" << std::endl; str_usage << " --custom_dictionary= " << std::endl; - str_usage << " --speed= " << std::endl; - + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -494,7 +494,7 @@ main(int argc, char** argv) std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, &time_to_first_chunk, time_to_next_chunk, &num_samples, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_speed); + FLAGS_zero_shot_quality, FLAGS_exaggeration_factor); latencies_first_chunk[i]->push_back(time_to_first_chunk); latencies_next_chunks[i]->insert( latencies_next_chunks[i]->end(), time_to_next_chunk->begin(), @@ -570,7 +570,8 @@ main(int argc, char** argv) int32_t num_samples = synthesizeBatch( std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, FLAGS_speed); + FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, + FLAGS_exaggeration_factor); results_num_samples[i]->push_back(num_samples); } }));