Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ grpc_extra_deps()
git_repository(
name = "nvriva_common",
remote = "https://github.com/atomer-nvidia/common.git",
commit = "3085c065085d15a284b37847470fe0182c9a6c67"
commit = "60e67e8ba30eac99d8cfb30275b03b76b6562a29"
)

http_archive(
Expand Down
14 changes: 7 additions & 7 deletions riva/clients/tts/riva_tts_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone
DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt.");
DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation");
DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size");
DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0");
DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0");

static const std::string LC_enUS = "en-US";

Expand Down Expand Up @@ -119,7 +119,7 @@ main(int argc, char** argv)
str_usage << " --custom_dictionary=<filename> " << std::endl;
str_usage << " --timeout_ms=<timeout_ms> " << std::endl;
str_usage << " --max_grpc_message_size=<max_grpc_message_size> " << std::endl;
str_usage << " --speed=<speed> " << std::endl;
str_usage << " --exaggeration_factor=<exaggeration_factor> " << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -188,11 +188,6 @@ main(int argc, char** argv)

request.set_sample_rate_hz(rate);
request.set_voice_name(FLAGS_voice_name);
if (FLAGS_speed < 0.5 || FLAGS_speed > 2.0) {
LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl;
return -1;
}
request.set_speed(FLAGS_speed);
if (not FLAGS_zero_shot_audio_prompt.empty()) {
auto zero_shot_data = request.mutable_zero_shot_data();
std::vector<std::shared_ptr<WaveData>> audio_prompt;
Expand Down Expand Up @@ -225,6 +220,11 @@ main(int argc, char** argv)
if (not FLAGS_online and not FLAGS_zero_shot_transcript.empty()) {
zero_shot_data->set_transcript(FLAGS_zero_shot_transcript);
}
if (FLAGS_exaggeration_factor < 0.0 || FLAGS_exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return -1;
}
zero_shot_data->set_exaggeration_factor(FLAGS_exaggeration_factor);
}

// Send text content using Synthesize().
Expand Down
35 changes: 18 additions & 17 deletions riva/clients/tts/riva_tts_perf_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ DEFINE_string(
DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40.");
DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words");
DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt.");
DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0");
DEFINE_double(
exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0");

static const std::string LC_enUS = "en-US";

Expand Down Expand Up @@ -115,19 +116,14 @@ synthesizeBatch(
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language,
uint32_t rate, std::string voice_name, std::string filepath,
std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary,
std::string zero_shot_transcript, double speed)
std::string zero_shot_transcript, double exaggeration_factor)
{
// Populate the synthesis request from the function arguments and command line flags.
nr_tts::SynthesizeSpeechRequest request;
request.set_text(text);
request.set_language_code(language);
request.set_sample_rate_hz(rate);
request.set_voice_name(voice_name);
if (speed < 0.5 || speed > 2.0) {
LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl;
return -1;
}
request.set_speed(speed);
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
request.set_encoding(nr::LINEAR_PCM);
} else if (FLAGS_audio_encoding == "opus") {
Expand Down Expand Up @@ -168,6 +164,11 @@ synthesizeBatch(
if (not FLAGS_zero_shot_transcript.empty()) {
zero_shot_data->set_transcript(FLAGS_zero_shot_transcript);
}
if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return -1;
}
zero_shot_data->set_exaggeration_factor(exaggeration_factor);
}

// Send text content using Synthesize().
Expand Down Expand Up @@ -211,18 +212,13 @@ synthesizeOnline(
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language,
uint32_t rate, std::string voice_name, double* time_to_first_chunk,
std::vector<double>* time_to_next_chunk, size_t* num_samples, std::string filepath,
std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double speed)
std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double exaggeration_factor)
{
nr_tts::SynthesizeSpeechRequest request;
request.set_text(text);
request.set_language_code(language);
request.set_sample_rate_hz(rate);
request.set_voice_name(voice_name);
if (speed < 0.5 || speed > 2.0) {
LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl;
return;
}
request.set_speed(speed);
auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED;
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
ae = nr::LINEAR_PCM;
Expand Down Expand Up @@ -260,6 +256,11 @@ synthesizeOnline(
}
zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate);
zero_shot_data->set_quality(zero_shot_quality);
if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) {
LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl;
return;
}
zero_shot_data->set_exaggeration_factor(exaggeration_factor);
}


Expand Down Expand Up @@ -366,8 +367,7 @@ main(int argc, char** argv)
str_usage << " --zero_shot_quality=<quality>" << std::endl;
str_usage << " --zero_shot_transcript=<text>" << std::endl;
str_usage << " --custom_dictionary=<filename> " << std::endl;
str_usage << " --speed=<speed> " << std::endl;

str_usage << " --exaggeration_factor=<exaggeration_factor> " << std::endl;
gflags::SetUsageMessage(str_usage.str());
gflags::SetVersionString(::riva::utils::kBuildScmRevision);

Expand Down Expand Up @@ -494,7 +494,7 @@ main(int argc, char** argv)
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
&time_to_first_chunk, time_to_next_chunk, &num_samples,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality, FLAGS_speed);
FLAGS_zero_shot_quality, FLAGS_exaggeration_factor);
latencies_first_chunk[i]->push_back(time_to_first_chunk);
latencies_next_chunks[i]->insert(
latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
Expand Down Expand Up @@ -570,7 +570,8 @@ main(int argc, char** argv)
int32_t num_samples = synthesizeBatch(
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, FLAGS_speed);
FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript,
FLAGS_exaggeration_factor);
results_num_samples[i]->push_back(num_samples);
}
}));
Expand Down