00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 #include "cogrob/cloud/speech/google_speech.h"
00028
00029 #include <grpc++/grpc++.h>
00030 #include <grpc/impl/codegen/connectivity_state.h>
00031 #include <chrono>
00032 #include <memory>
00033 #include <string>
00034 #include <utility>
00035 #include <vector>
00036
00037 #include "google/cloud/speech/v1/cloud_speech.grpc.pb.h"
00038 #include "third_party/gflags.h"
00039 #include "third_party/glog.h"
00040
00041 DEFINE_int32(grpc_speech_connect_timeout_secs, 5,
00042 "Timeout (seconds) to connect to gRPC service.");
00043
00044 DEFINE_int32(gspeech_wait_input_timeout_msecs, 100,
00045 "Timeout (seconds) to wait for input samples.");
00046
00047 namespace gspeech = ::google::cloud::speech::v1;
00048
00049 using gspeech::RecognitionConfig;
00050 using gspeech::StreamingRecognizeRequest;
00051 using gspeech::StreamingRecognizeResponse;
00052
00053 using std::string;
00054 using std::unique_ptr;
00055 using util::Status;
00056 using util::StatusOr;
00057
00058 namespace cogrob {
00059 namespace cloud {
00060 namespace speech {
00061
00062 GoogleSpeechRecognizer::GoogleSpeechRecognizer() {
00063 std::lock_guard<std::mutex> lock(general_mutex_);
00064
00065 channel_ = grpc::CreateChannel(
00066 "speech.googleapis.com", grpc::GoogleDefaultCredentials());
00067 gspeech_stub_ = std::move(gspeech::Speech::NewStub(channel_));
00068
00069
00070 latest_result_ = Status(
00071 util::error::FAILED_PRECONDITION, "Recognizer not yet started.");
00072 }
00073
00074 GoogleSpeechRecognizer::~GoogleSpeechRecognizer() {
00075 Stop();
00076 }
00077
00078 Status GoogleSpeechRecognizer::StartRecognize(AudioQueue* audio_queue,
00079 util::SimpleThreadSafeQueue<RecognitionResult>* result_queue,
00080 const std::vector<string>& hints, int max_audio_seconds,
00081 int max_wait_seconds, int max_alternatives) {
00082
00083 std::lock_guard<std::mutex> lock(general_mutex_);
00084
00085 if (thread_) {
00086 return Status(
00087 util::error::ALREADY_EXISTS, "Recognizer is already running.");
00088 }
00089
00090
00091 stop_flag_.store(false);
00092 done_flag_.store(false);
00093 latest_result_ = Status(util::error::UNAVAILABLE,
00094 "Recognizer just started, nothing received yet.");
00095
00096 thread_.reset(new std::thread([
00097 this, audio_queue, result_queue, hints, max_audio_seconds,
00098 max_wait_seconds, max_alternatives] {
00099 RecognitionThread(audio_queue, result_queue, hints, max_audio_seconds,
00100 max_wait_seconds, max_alternatives);
00101 }));
00102 return Status::OK;
00103 }
00104
00105 StatusOr<RecognitionResult> GoogleSpeechRecognizer::GetLastResult() {
00106
00107 std::lock_guard<std::mutex> lock(general_mutex_);
00108 return latest_result_;
00109 }
00110
00111
00112 Status GoogleSpeechRecognizer::Wait() {
00113
00114 std::lock_guard<std::mutex> lock(general_mutex_);
00115 if (thread_) {
00116 thread_->join();
00117 thread_.reset(nullptr);
00118 }
00119 return Status::OK;
00120 }
00121
00122
00123 Status GoogleSpeechRecognizer::Stop() {
00124
00125 std::lock_guard<std::mutex> lock(general_mutex_);
00126
00127
00128 if (thread_) {
00129 stop_flag_.store(true);
00130 thread_->join();
00131 thread_.reset(nullptr);
00132 }
00133 return Status::OK;
00134 }
00135
00136
00137 bool GoogleSpeechRecognizer::IsRunning() {
00138
00139 std::lock_guard<std::mutex> lock(general_mutex_);
00140
00141
00142 if (done_flag_.load() && thread_) {
00143 thread_->join();
00144 thread_.reset(nullptr);
00145 }
00146
00147 if (thread_) {
00148 return true;
00149 }
00150 return false;
00151 }
00152
00153
00154 void GoogleSpeechRecognizer::RecognitionThread(AudioQueue* audio_queue,
00155 util::SimpleThreadSafeQueue<RecognitionResult>* result_queue,
00156 const std::vector<string>& hints, int max_audio_seconds,
00157 int max_wait_seconds, int max_alternatives) {
00158
00159
00160
00161 bool fail_flag = false;
00162
00163 if (max_audio_seconds < 1) {
00164 latest_result_ = Status(
00165 util::error::FAILED_PRECONDITION,
00166 "max_audio_seconds must be greater than 0");
00167 LOG(ERROR) << latest_result_.status();
00168 fail_flag = true;
00169 }
00170
00171 if (max_audio_seconds >= 65) {
00172 latest_result_ = Status(
00173 util::error::FAILED_PRECONDITION,
00174 "max_audio_seconds must be less than 65");
00175 LOG(ERROR) << latest_result_.status();
00176 fail_flag = true;
00177 }
00178
00179 if (max_wait_seconds <= max_audio_seconds) {
00180 latest_result_ = Status(
00181 util::error::FAILED_PRECONDITION,
00182 "max_wait_seconds must be greater than max_audio_seconds.");
00183 LOG(ERROR) << latest_result_.status();
00184 fail_flag = true;
00185 }
00186
00187 if (fail_flag) {
00188 done_flag_.store(true);
00189 LOG(ERROR) << "There are some errors on preconditions, "
00190 << "finishing RecognitionThread.";
00191 return;
00192 }
00193
00194 LOG(INFO) << "RecognitionThread started, will listen for "
00195 << max_audio_seconds << " seconds, and will return in "
00196 << max_wait_seconds << " seconds.";
00197
00198
00199 std::chrono::system_clock::time_point deadline
00200 = std::chrono::system_clock::now() + std::chrono::seconds(max_wait_seconds);
00201 double left_audio_time = max_audio_seconds;
00202
00203
00204 grpc::ClientContext context;
00205 grpc::CompletionQueue completion_queue;
00206
00207 uintptr_t cq_seq = 0;
00208
00209
00210
00211
00212
00213
00214 std::chrono::system_clock::time_point stream_connect_deadline =
00215 std::chrono::system_clock::now()
00216 + std::chrono::seconds(FLAGS_grpc_speech_connect_timeout_secs);
00217
00218
00219
00220 channel_->GetState(true);
00221 if (!channel_->WaitForConnected(stream_connect_deadline)) {
00222 latest_result_ = Status(util::error::ABORTED,
00223 "gRPC error: Channel connection took too long.");
00224 LOG(ERROR) << latest_result_.status();
00225 fail_flag = true;
00226 }
00227
00228
00229 string channel_state = "";
00230 grpc_connectivity_state channel_state_enum = channel_->GetState(false);
00231 switch (channel_state_enum) {
00232 case GRPC_CHANNEL_IDLE:
00233 channel_state = "GRPC_CHANNEL_IDLE";
00234 break;
00235 case GRPC_CHANNEL_CONNECTING:
00236 channel_state = "GRPC_CHANNEL_CONNECTING";
00237 break;
00238 case GRPC_CHANNEL_READY:
00239 channel_state = "GRPC_CHANNEL_READY";
00240 break;
00241 case GRPC_CHANNEL_TRANSIENT_FAILURE:
00242 channel_state = "GRPC_CHANNEL_TRANSIENT_FAILURE";
00243 break;
00244 case GRPC_CHANNEL_SHUTDOWN:
00245 channel_state = "GRPC_CHANNEL_SHUTDOWN";
00246 break;
00247 default:
00248 CHECK(false) << "Unknown channel state: " << channel_state_enum;
00249 break;
00250 }
00251
00252 if (channel_state_enum == GRPC_CHANNEL_READY) {
00253 LOG(INFO) << "Channel state is " << channel_state;
00254 } else {
00255 LOG(ERROR) << "Channel state is " << channel_state;
00256 }
00257
00258 if (fail_flag) {
00259 done_flag_.store(true);
00260 LOG(ERROR) << "There are some errors on gRPC channels, "
00261 << "finishing RecognitionThread.";
00262 return;
00263 }
00264
00265
00266 auto streamer
00267 = gspeech_stub_->AsyncStreamingRecognize(
00268 &context, &completion_queue, reinterpret_cast<void*>(++cq_seq));
00269 uintptr_t stream_cq_seq = cq_seq;
00270 LOG(INFO) << "Start a call to Google Speech gRPC server, cq_seq: "
00271 << stream_cq_seq;
00272
00273
00274
00275
00276
00277 void* stream_cq_tag = nullptr;
00278 bool stream_cq_ok = false;
00279
00280 grpc::CompletionQueue::NextStatus stream_cq_state
00281 = completion_queue.AsyncNext(&stream_cq_tag, &stream_cq_ok,
00282 stream_connect_deadline);
00283 if (stream_cq_state == grpc::CompletionQueue::GOT_EVENT) {
00284 if (!stream_cq_ok) {
00285 latest_result_ = Status(util::error::ABORTED,
00286 "gRPC error: Stream failed to create.");
00287 LOG(ERROR) << latest_result_.status();
00288 fail_flag = true;
00289 }
00290
00291
00292 if (stream_cq_tag == reinterpret_cast<void*>(stream_cq_seq)) {
00293
00294 LOG(INFO) << "gRPC created stream, tag is "
00295 << reinterpret_cast<uintptr_t>(stream_cq_tag);
00296 } else {
00297 latest_result_ = Status(util::error::ABORTED,
00298 "gRPC fatal error: wrong stream creation tag.");
00299 LOG(ERROR) << latest_result_.status();
00300 fail_flag = true;
00301 }
00302 } else if (stream_cq_state == grpc::CompletionQueue::TIMEOUT) {
00303 latest_result_ = Status(util::error::ABORTED,
00304 "gRPC error: Stream creation timed out.");
00305 LOG(ERROR) << latest_result_.status();
00306 fail_flag = true;
00307 } else if (stream_cq_state == grpc::CompletionQueue::SHUTDOWN) {
00308 latest_result_ = Status(util::error::ABORTED,
00309 "gRPC gRPC server shuted down the connection.");
00310 LOG(ERROR) << latest_result_.status();
00311 fail_flag = true;
00312 }
00313
00314
00315 gspeech::StreamingRecognizeRequest config_request;
00316 auto* streaming_config = config_request.mutable_streaming_config();
00317 streaming_config->set_single_utterance(false);
00318 streaming_config->set_interim_results(true);
00319 streaming_config->mutable_config()->set_encoding(RecognitionConfig::LINEAR16);
00320 streaming_config->mutable_config()->set_sample_rate_hertz(SAMPLE_RATE);
00321 streaming_config->mutable_config()->set_language_code("en-US");
00322 streaming_config->mutable_config()->set_max_alternatives(max_alternatives);
00323
00324 gspeech::SpeechContext* speech_context =
00325 streaming_config->mutable_config()->add_speech_contexts();
00326 for (const string& hint : hints) {
00327 speech_context->add_phrases(hint);
00328 }
00329
00330
00331 uintptr_t last_write_seq = 0;
00332 uintptr_t write_done_seq = 0;
00333 bool allow_write = true;
00334 bool write_done_issued = false;
00335 bool write_done_finished = false;
00336 bool final_read_returned = false;
00337
00338
00339 uintptr_t last_read_seq = 0;
00340 gspeech::StreamingRecognizeResponse response;
00341 bool allow_read = true;
00342
00343
00344
00345 if (!fail_flag) {
00346
00347
00348 CHECK(allow_write);
00349 streamer->Write(
00350 config_request, reinterpret_cast<void*>(last_write_seq = ++cq_seq));
00351 allow_write = false;
00352 LOG(INFO) << "Wrote speech config to Google Speech gRPC server with tag "
00353 << last_write_seq;
00354
00355
00356
00357
00358 CHECK(allow_read);
00359 streamer->Read(&response,
00360 reinterpret_cast<void*>(last_read_seq = ++cq_seq));
00361 allow_read = false;
00362 LOG(INFO) << "Issued first Read request with tag " << cq_seq;
00363 }
00364
00365
00366
00367
00368
00369 while (!fail_flag && std::chrono::system_clock::now() < deadline) {
00370
00371 while (true) {
00372
00373
00374
00375 std::chrono::system_clock::time_point cq_deadline
00376 = std::chrono::system_clock::now() + std::chrono::milliseconds(10);
00377
00378
00379 void* cq_tag = nullptr;
00380 bool cq_ok = false;
00381
00382
00383 grpc::CompletionQueue::NextStatus cq_state = completion_queue.AsyncNext(
00384 &cq_tag, &cq_ok, cq_deadline);
00385
00386 if (cq_state == grpc::CompletionQueue::GOT_EVENT) {
00387 if (cq_tag == reinterpret_cast<void*>(last_read_seq)) {
00388 LOG(INFO) << "Got Read result from completion queue, cq_tag is "
00389 << reinterpret_cast<uintptr_t>(cq_tag)
00390 << ", cq_ok is " << cq_ok;
00391 if (!cq_ok) {
00392 if (write_done_finished) {
00393 final_read_returned = true;
00394 break;
00395 } else {
00396 fail_flag = true;
00397 latest_result_ = Status(
00398 util::error::INTERNAL, "gRPC Read failed.");
00399 LOG(ERROR) << latest_result_.status();
00400 break;
00401 }
00402 }
00403
00404
00405 allow_read = true;
00406
00407
00408 LOG(INFO) << "Got recognition response from server "
00409 << response.DebugString();
00410
00411
00412 if (response.has_error()) {
00413 latest_result_ = Status(
00414 util::error::INTERNAL, response.error().message());
00415 LOG(ERROR) << "Response has error: " << response.error().message();
00416 fail_flag = true;
00417 break;
00418 }
00419
00420 for (auto& result_record : response.results()) {
00421
00422 RecognitionResult recog_result;
00423 recog_result.set_is_final(result_record.is_final());
00424 recog_result.set_stability(result_record.stability());
00425 for (auto& alternative : result_record.alternatives()) {
00426 RecognitionResult::Candidate* can = recog_result.add_candidates();
00427 can->set_transcript(alternative.transcript());
00428 can->set_confidence(alternative.confidence());
00429 }
00430 latest_result_ = recog_result;
00431 if (result_queue) {
00432 result_queue->push(recog_result);
00433 }
00434 }
00435
00436
00437 CHECK(allow_read);
00438 streamer->Read(
00439 &response, reinterpret_cast<void*>(last_read_seq = ++cq_seq));
00440 allow_read = false;
00441 LOG(INFO) << "Issued Read request with tag " << last_read_seq;
00442 } else if (cq_tag == reinterpret_cast<void*>(last_write_seq)) {
00443 LOG(INFO) << "Got Write result from completion queue, cq_tag is "
00444 << reinterpret_cast<uintptr_t>(cq_tag)
00445 << ", cq_ok is " << cq_ok;
00446 if (!cq_ok) {
00447 fail_flag = true;
00448 latest_result_ = Status(
00449 util::error::INTERNAL, "gRPC Write failed.");
00450 LOG(ERROR) << latest_result_.status();
00451 break;
00452 }
00453
00454
00455 if (!write_done_issued) {
00456 allow_write = true;
00457 }
00458 } else if (cq_tag == reinterpret_cast<void*>(write_done_seq)) {
00459 LOG(INFO) << "Got WritesDone result from completion queue, cq_tag is "
00460 << reinterpret_cast<uintptr_t>(cq_tag)
00461 << ", cq_ok is " << cq_ok;
00462 if (!cq_ok) {
00463 fail_flag = true;
00464 latest_result_ = Status(
00465 util::error::INTERNAL, "gRPC WritesDone failed.");
00466 LOG(ERROR) << latest_result_.status();
00467 break;
00468 }
00469 write_done_finished = true;
00470 } else {
00471
00472 LOG(ERROR) << "Got unknown result from completion queue, cq_tag is "
00473 << reinterpret_cast<uintptr_t>(cq_tag)
00474 << ", cq_ok is " << cq_ok;
00475 if (!cq_ok) {
00476 fail_flag = true;
00477 latest_result_ = Status(
00478 util::error::INTERNAL,
00479 "Got non-read/write results, and it is a failure.");
00480 LOG(ERROR) << latest_result_.status();
00481 break;
00482 }
00483 }
00484 } else if (cq_state == grpc::CompletionQueue::TIMEOUT) {
00485 LOG(INFO) << "CompletionQueue timeout (expected behavior).";
00486 break;
00487 } else if (cq_state == grpc::CompletionQueue::SHUTDOWN) {
00488 fail_flag = true;
00489 LOG(ERROR) << "gRPC server shuted down the connection.";
00490 break;
00491 }
00492 }
00493
00494 if (fail_flag) {
00495 break;
00496 }
00497
00498 if (final_read_returned) {
00499 break;
00500 }
00501
00502 if (allow_write && !write_done_issued) {
00503
00504 if (left_audio_time > 0 && !stop_flag_.load()) {
00505
00506 gspeech::StreamingRecognizeRequest request;
00507
00508
00509
00510 StatusOr<unique_ptr<AudioSample>> pop_result
00511 = audio_queue->blocking_pop(FLAGS_gspeech_wait_input_timeout_msecs);
00512 if (pop_result.ok()) {
00513 unique_ptr<AudioSample> audio_sample = pop_result.ConsumeValueOrDie();
00514 left_audio_time -= static_cast<double>(audio_sample->size())
00515 / (SAMPLE_RATE * 2);
00516
00517 request.set_audio_content(audio_sample->data(), audio_sample->size());
00518 LOG(INFO) << "Prepare to send " << audio_sample->size()
00519 << " bytes of data to the server.";
00520 CHECK(allow_write);
00521 streamer->Write(
00522 request, reinterpret_cast<void*>(last_write_seq = ++cq_seq));
00523 allow_write = false;
00524 LOG(INFO) << "Issued Write request with tag " << last_write_seq;
00525 } else {
00526 LOG(WARNING) << "Read from audio_queue failed.";
00527 }
00528 } else {
00529
00530 CHECK(allow_write);
00531 CHECK(!write_done_issued);
00532 streamer->WritesDone(
00533 reinterpret_cast<void*>(write_done_seq = ++cq_seq));
00534 LOG(INFO) << "Issued WritesDone request with tag " << write_done_seq;
00535 allow_write = false;
00536 write_done_issued = true;
00537 }
00538 }
00539 }
00540
00541 if (fail_flag) {
00542 LOG(ERROR) << "gRPC failed, continue to finish the requests. "
00543 << "There could be more failures, be prepared.";
00544 }
00545
00546 bool late_issue_write_done = false;
00547 if (!write_done_issued) {
00548 LOG(ERROR) << "WritesDone not issued, but we are not streaming anymore. "
00549 << "Is max_wait_seconds is too close to max_audio_seconds? "
00550 << "Is network connection too slow?";
00551 if (allow_write) {
00552 streamer->WritesDone(reinterpret_cast<void*>(write_done_seq = ++cq_seq));
00553 allow_write = false;
00554 write_done_issued = true;
00555 late_issue_write_done = true;
00556 LOG(INFO) << "Issued WritesDone request with tag " << write_done_seq;
00557 } else {
00558 LOG(ERROR) << "WritesDone is not yet issued, but not able to write.";
00559 }
00560 }
00561
00562 grpc::Status finish_status;
00563 uintptr_t finish_seq = 0;
00564 streamer->Finish(&finish_status,
00565 reinterpret_cast<void*>(finish_seq = ++cq_seq));
00566 LOG(INFO) << "Issued Finish request with tag " << finish_seq;
00567 completion_queue.Shutdown();
00568
00569
00570 bool tried_cancel = false;
00571 while (true) {
00572 void* cq_tag = nullptr;
00573 bool cq_ok = false;
00574
00575 std::chrono::system_clock::time_point cq_deadline
00576 = std::chrono::system_clock::now() + std::chrono::milliseconds(100);
00577
00578
00579 grpc::CompletionQueue::NextStatus cq_state = completion_queue.AsyncNext(
00580 &cq_tag, &cq_ok, cq_deadline);
00581
00582 if (cq_state == grpc::CompletionQueue::GOT_EVENT) {
00583 if (cq_tag == reinterpret_cast<void*>(last_read_seq)) {
00584
00585
00586
00587
00588 LOG(WARNING) << "Got unexpected Read from completion queue, cq_tag is "
00589 << reinterpret_cast<uintptr_t>(cq_tag)
00590 << ", cq_ok is " << cq_ok;
00591 } else if (cq_tag == reinterpret_cast<void*>(last_write_seq)) {
00592 LOG(ERROR) << "Got unexpected Write from completion queue, cq_tag is "
00593 << reinterpret_cast<uintptr_t>(cq_tag)
00594 << ", cq_ok is " << cq_ok;
00595 } else if (cq_tag == reinterpret_cast<void*>(write_done_seq)) {
00596 if (late_issue_write_done) {
00597 LOG(WARNING) << "Got WritesDone (late issued) from completion queue, "
00598 << "cq_tag is " << reinterpret_cast<uintptr_t>(cq_tag)
00599 << ", cq_ok is " << cq_ok;
00600 } else {
00601 LOG(ERROR) << "Got WritesDone from completion queue, cq_tag is "
00602 << reinterpret_cast<uintptr_t>(cq_tag)
00603 << ", cq_ok is " << cq_ok;
00604 }
00605 } else if (cq_tag == reinterpret_cast<void*>(finish_seq)) {
00606 if (!finish_status.ok()) {
00607 latest_result_ = Status(
00608 util::error::INTERNAL, finish_status.error_message());
00609 LOG(ERROR) << "Finish not OK: " << finish_status.error_message();
00610 } else {
00611 LOG(INFO) << "Finish OK.";
00612 }
00613 } else {
00614 LOG(ERROR) << "Got unexpected event from completion queue, cq_tag is "
00615 << reinterpret_cast<uintptr_t>(cq_tag)
00616 << ", cq_ok is " << cq_ok;
00617 }
00618 } else if (cq_state == grpc::CompletionQueue::TIMEOUT) {
00619 LOG(ERROR) << "CompletionQueue AsyncNext timeout";
00620
00621
00622 if (!tried_cancel) {
00623 context.TryCancel();
00624 tried_cancel = true;
00625 }
00626
00627
00628 } else if (cq_state == grpc::CompletionQueue::SHUTDOWN) {
00629
00630 LOG(INFO) << "CompletionQueue Shutdown";
00631 break;
00632 }
00633 }
00634
00635 done_flag_.store(true);
00636 LOG(INFO) << "RecognitionThread finished";
00637 }
00638
00639
00640 }
00641 }
00642 }