29 #include <grpc++/grpc++.h> 30 #include <grpc/impl/codegen/connectivity_state.h> 42 "Timeout (seconds) to connect to gRPC service.");
45 "Timeout (seconds) to wait for input samples.");
47 namespace gspeech = ::google::cloud::speech::v1;
49 using gspeech::RecognitionConfig;
50 using gspeech::StreamingRecognizeRequest;
51 using gspeech::StreamingRecognizeResponse;
54 using std::unique_ptr;
66 "speech.googleapis.com", grpc::GoogleDefaultCredentials());
80 const std::vector<string>& hints,
int max_audio_seconds,
81 int max_wait_seconds,
int max_alternatives) {
94 "Recognizer just started, nothing received yet.");
97 this, audio_queue, result_queue, hints, max_audio_seconds,
98 max_wait_seconds, max_alternatives] {
100 max_wait_seconds, max_alternatives);
156 const std::vector<string>& hints,
int max_audio_seconds,
157 int max_wait_seconds,
int max_alternatives) {
161 bool fail_flag =
false;
163 if (max_audio_seconds < 1) {
166 "max_audio_seconds must be greater than 0");
171 if (max_audio_seconds >= 65) {
174 "max_audio_seconds must be less than 65");
179 if (max_wait_seconds <= max_audio_seconds) {
182 "max_wait_seconds must be greater than max_audio_seconds.");
189 LOG(ERROR) <<
"There are some errors on preconditions, " 190 <<
"finishing RecognitionThread.";
194 LOG(INFO) <<
"RecognitionThread started, will listen for " 195 << max_audio_seconds <<
" seconds, and will return in " 196 << max_wait_seconds <<
" seconds.";
199 std::chrono::system_clock::time_point deadline
200 = std::chrono::system_clock::now() + std::chrono::seconds(max_wait_seconds);
201 double left_audio_time = max_audio_seconds;
204 grpc::ClientContext context;
205 grpc::CompletionQueue completion_queue;
207 uintptr_t cq_seq = 0;
214 std::chrono::system_clock::time_point stream_connect_deadline =
215 std::chrono::system_clock::now()
216 + std::chrono::seconds(FLAGS_grpc_speech_connect_timeout_secs);
221 if (!
channel_->WaitForConnected(stream_connect_deadline)) {
223 "gRPC error: Channel connection took too long.");
229 string channel_state =
"";
230 grpc_connectivity_state channel_state_enum =
channel_->GetState(
false);
231 switch (channel_state_enum) {
232 case GRPC_CHANNEL_IDLE:
233 channel_state =
"GRPC_CHANNEL_IDLE";
235 case GRPC_CHANNEL_CONNECTING:
236 channel_state =
"GRPC_CHANNEL_CONNECTING";
238 case GRPC_CHANNEL_READY:
239 channel_state =
"GRPC_CHANNEL_READY";
241 case GRPC_CHANNEL_TRANSIENT_FAILURE:
242 channel_state =
"GRPC_CHANNEL_TRANSIENT_FAILURE";
244 case GRPC_CHANNEL_SHUTDOWN:
245 channel_state =
"GRPC_CHANNEL_SHUTDOWN";
248 CHECK(
false) <<
"Unknown channel state: " << channel_state_enum;
252 if (channel_state_enum == GRPC_CHANNEL_READY) {
253 LOG(INFO) <<
"Channel state is " << channel_state;
255 LOG(ERROR) <<
"Channel state is " << channel_state;
260 LOG(ERROR) <<
"There are some errors on gRPC channels, " 261 <<
"finishing RecognitionThread.";
268 &context, &completion_queue, reinterpret_cast<void*>(++cq_seq));
269 uintptr_t stream_cq_seq = cq_seq;
270 LOG(INFO) <<
"Start a call to Google Speech gRPC server, cq_seq: " 277 void* stream_cq_tag =
nullptr;
278 bool stream_cq_ok =
false;
280 grpc::CompletionQueue::NextStatus stream_cq_state
281 = completion_queue.AsyncNext(&stream_cq_tag, &stream_cq_ok,
282 stream_connect_deadline);
283 if (stream_cq_state == grpc::CompletionQueue::GOT_EVENT) {
286 "gRPC error: Stream failed to create.");
292 if (stream_cq_tag == reinterpret_cast<void*>(stream_cq_seq)) {
294 LOG(INFO) <<
"gRPC created stream, tag is " 295 <<
reinterpret_cast<uintptr_t
>(stream_cq_tag);
298 "gRPC fatal error: wrong stream creation tag.");
302 }
else if (stream_cq_state == grpc::CompletionQueue::TIMEOUT) {
304 "gRPC error: Stream creation timed out.");
307 }
else if (stream_cq_state == grpc::CompletionQueue::SHUTDOWN) {
309 "gRPC gRPC server shuted down the connection.");
315 gspeech::StreamingRecognizeRequest config_request;
316 auto* streaming_config = config_request.mutable_streaming_config();
317 streaming_config->set_single_utterance(
false);
318 streaming_config->set_interim_results(
true);
319 streaming_config->mutable_config()->set_encoding(RecognitionConfig::LINEAR16);
320 streaming_config->mutable_config()->set_sample_rate_hertz(
SAMPLE_RATE);
321 streaming_config->mutable_config()->set_language_code(
"en-US");
322 streaming_config->mutable_config()->set_max_alternatives(max_alternatives);
324 gspeech::SpeechContext* speech_context =
325 streaming_config->mutable_config()->add_speech_contexts();
326 for (
const string& hint : hints) {
327 speech_context->add_phrases(hint);
331 uintptr_t last_write_seq = 0;
332 uintptr_t write_done_seq = 0;
333 bool allow_write =
true;
334 bool write_done_issued =
false;
335 bool write_done_finished =
false;
336 bool final_read_returned =
false;
339 uintptr_t last_read_seq = 0;
340 gspeech::StreamingRecognizeResponse response;
341 bool allow_read =
true;
350 config_request, reinterpret_cast<void*>(last_write_seq = ++cq_seq));
352 LOG(INFO) <<
"Wrote speech config to Google Speech gRPC server with tag " 359 streamer->Read(&response,
360 reinterpret_cast<void*>(last_read_seq = ++cq_seq));
362 LOG(INFO) <<
"Issued first Read request with tag " << cq_seq;
369 while (!fail_flag && std::chrono::system_clock::now() < deadline) {
375 std::chrono::system_clock::time_point cq_deadline
376 = std::chrono::system_clock::now() + std::chrono::milliseconds(10);
379 void* cq_tag =
nullptr;
383 grpc::CompletionQueue::NextStatus cq_state = completion_queue.AsyncNext(
384 &cq_tag, &cq_ok, cq_deadline);
386 if (cq_state == grpc::CompletionQueue::GOT_EVENT) {
387 if (cq_tag == reinterpret_cast<void*>(last_read_seq)) {
388 LOG(INFO) <<
"Got Read result from completion queue, cq_tag is " 389 <<
reinterpret_cast<uintptr_t
>(cq_tag)
390 <<
", cq_ok is " << cq_ok;
392 if (write_done_finished) {
393 final_read_returned =
true;
408 LOG(INFO) <<
"Got recognition response from server " 409 << response.DebugString();
412 if (response.has_error()) {
415 LOG(ERROR) <<
"Response has error: " << response.error().message();
420 for (
auto& result_record : response.results()) {
422 RecognitionResult recog_result;
423 recog_result.set_is_final(result_record.is_final());
424 recog_result.set_stability(result_record.stability());
425 for (
auto& alternative : result_record.alternatives()) {
426 RecognitionResult::Candidate* can = recog_result.add_candidates();
427 can->set_transcript(alternative.transcript());
428 can->set_confidence(alternative.confidence());
432 result_queue->
push(recog_result);
439 &response, reinterpret_cast<void*>(last_read_seq = ++cq_seq));
441 LOG(INFO) <<
"Issued Read request with tag " << last_read_seq;
442 }
else if (cq_tag == reinterpret_cast<void*>(last_write_seq)) {
443 LOG(INFO) <<
"Got Write result from completion queue, cq_tag is " 444 <<
reinterpret_cast<uintptr_t
>(cq_tag)
445 <<
", cq_ok is " << cq_ok;
455 if (!write_done_issued) {
458 }
else if (cq_tag == reinterpret_cast<void*>(write_done_seq)) {
459 LOG(INFO) <<
"Got WritesDone result from completion queue, cq_tag is " 460 <<
reinterpret_cast<uintptr_t
>(cq_tag)
461 <<
", cq_ok is " << cq_ok;
469 write_done_finished =
true;
472 LOG(ERROR) <<
"Got unknown result from completion queue, cq_tag is " 473 <<
reinterpret_cast<uintptr_t
>(cq_tag)
474 <<
", cq_ok is " << cq_ok;
479 "Got non-read/write results, and it is a failure.");
484 }
else if (cq_state == grpc::CompletionQueue::TIMEOUT) {
485 LOG(INFO) <<
"CompletionQueue timeout (expected behavior).";
487 }
else if (cq_state == grpc::CompletionQueue::SHUTDOWN) {
489 LOG(ERROR) <<
"gRPC server shuted down the connection.";
498 if (final_read_returned) {
502 if (allow_write && !write_done_issued) {
504 if (left_audio_time > 0 && !
stop_flag_.load()) {
506 gspeech::StreamingRecognizeRequest request;
511 = audio_queue->
blocking_pop(FLAGS_gspeech_wait_input_timeout_msecs);
512 if (pop_result.
ok()) {
514 left_audio_time -=
static_cast<double>(audio_sample->size())
517 request.set_audio_content(audio_sample->data(), audio_sample->size());
518 LOG(INFO) <<
"Prepare to send " << audio_sample->size()
519 <<
" bytes of data to the server.";
522 request, reinterpret_cast<void*>(last_write_seq = ++cq_seq));
524 LOG(INFO) <<
"Issued Write request with tag " << last_write_seq;
526 LOG(WARNING) <<
"Read from audio_queue failed.";
531 CHECK(!write_done_issued);
532 streamer->WritesDone(
533 reinterpret_cast<void*>(write_done_seq = ++cq_seq));
534 LOG(INFO) <<
"Issued WritesDone request with tag " << write_done_seq;
536 write_done_issued =
true;
542 LOG(ERROR) <<
"gRPC failed, continue to finish the requests. " 543 <<
"There could be more failures, be prepared.";
546 bool late_issue_write_done =
false;
547 if (!write_done_issued) {
548 LOG(ERROR) <<
"WritesDone not issued, but we are not streaming anymore. " 549 <<
"Is max_wait_seconds is too close to max_audio_seconds? " 550 <<
"Is network connection too slow?";
552 streamer->WritesDone(reinterpret_cast<void*>(write_done_seq = ++cq_seq));
554 write_done_issued =
true;
555 late_issue_write_done =
true;
556 LOG(INFO) <<
"Issued WritesDone request with tag " << write_done_seq;
558 LOG(ERROR) <<
"WritesDone is not yet issued, but not able to write.";
562 grpc::Status finish_status;
563 uintptr_t finish_seq = 0;
564 streamer->Finish(&finish_status,
565 reinterpret_cast<void*>(finish_seq = ++cq_seq));
566 LOG(INFO) <<
"Issued Finish request with tag " << finish_seq;
567 completion_queue.Shutdown();
570 bool tried_cancel =
false;
572 void* cq_tag =
nullptr;
575 std::chrono::system_clock::time_point cq_deadline
576 = std::chrono::system_clock::now() + std::chrono::milliseconds(100);
579 grpc::CompletionQueue::NextStatus cq_state = completion_queue.AsyncNext(
580 &cq_tag, &cq_ok, cq_deadline);
582 if (cq_state == grpc::CompletionQueue::GOT_EVENT) {
583 if (cq_tag == reinterpret_cast<void*>(last_read_seq)) {
588 LOG(WARNING) <<
"Got unexpected Read from completion queue, cq_tag is " 589 <<
reinterpret_cast<uintptr_t
>(cq_tag)
590 <<
", cq_ok is " << cq_ok;
591 }
else if (cq_tag == reinterpret_cast<void*>(last_write_seq)) {
592 LOG(ERROR) <<
"Got unexpected Write from completion queue, cq_tag is " 593 <<
reinterpret_cast<uintptr_t
>(cq_tag)
594 <<
", cq_ok is " << cq_ok;
595 }
else if (cq_tag == reinterpret_cast<void*>(write_done_seq)) {
596 if (late_issue_write_done) {
597 LOG(WARNING) <<
"Got WritesDone (late issued) from completion queue, " 598 <<
"cq_tag is " <<
reinterpret_cast<uintptr_t
>(cq_tag)
599 <<
", cq_ok is " << cq_ok;
601 LOG(ERROR) <<
"Got WritesDone from completion queue, cq_tag is " 602 <<
reinterpret_cast<uintptr_t
>(cq_tag)
603 <<
", cq_ok is " << cq_ok;
605 }
else if (cq_tag == reinterpret_cast<void*>(finish_seq)) {
606 if (!finish_status.ok()) {
609 LOG(ERROR) <<
"Finish not OK: " << finish_status.
error_message();
611 LOG(INFO) <<
"Finish OK.";
614 LOG(ERROR) <<
"Got unexpected event from completion queue, cq_tag is " 615 <<
reinterpret_cast<uintptr_t
>(cq_tag)
616 <<
", cq_ok is " << cq_ok;
618 }
else if (cq_state == grpc::CompletionQueue::TIMEOUT) {
619 LOG(ERROR) <<
"CompletionQueue AsyncNext timeout";
628 }
else if (cq_state == grpc::CompletionQueue::SHUTDOWN) {
630 LOG(INFO) <<
"CompletionQueue Shutdown";
636 LOG(INFO) <<
"RecognitionThread finished";
bool IsRunning() override
void RecognitionThread(AudioQueue *audio_queue, util::SimpleThreadSafeQueue< RecognitionResult > *result_queue, const std::vector< std::string > &hints, int max_audio_seconds, int max_wait_seconds, int max_alternatives)
std::unique_ptr< std::thread > thread_
std::atomic_bool stop_flag_
util::Status StartRecognize(AudioQueue *audio_queue, util::SimpleThreadSafeQueue< RecognitionResult > *result_queue, const std::vector< std::string > &hints, int max_audio_seconds, int max_wait_seconds, int max_alternatives) override
std::string error_message() const
util::StatusOr< RecognitionResult > latest_result_
util::Status Stop() override
std::unique_ptr<::google::cloud::speech::v1::Speech::Stub > gspeech_stub_
constexpr size_t SAMPLE_RATE
void push(Args &&...args)
std::atomic_bool done_flag_
const Status & status() const
util::StatusOr< RecognitionResult > GetLastResult() override
util::Status Wait() override
std::mutex general_mutex_
std::shared_ptr< grpc::ChannelInterface > channel_
~GoogleSpeechRecognizer() override
DEFINE_int32(grpc_speech_connect_timeout_secs, 5,"Timeout (seconds) to connect to gRPC service.")