session.py
Go to the documentation of this file.
00001 # -*- coding: utf-8 -*-
00002 
00003 # python libraries
00004 import abc
00005 import threading
00006 from Queue import Queue
00007 from Queue import Empty
00008 
00009 # local libraries
00010 from rospeex_core import logging_util
00011 from rospeex_core import exceptions as ext
00012 
00013 
# names exported by a star-import of this module
__all__ = ['ISession', 'Session', 'IState', 'SessionState', 'PacketType']


# module-level logger, named after this module
logger = logging_util.get_logger(__name__)
00025 
00026 
class ISession(object):
    """ Session Interface class

    Abstract description of a packet-driven session.  The signatures
    mirror the concrete Session class of this module so that an
    implementation and the interface can be called the same way.
    """
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def set_next_state(self, state):
        """ set the state the session should switch to
        @param state: next state
        @return: None
        """
        pass    # pragma: no cover

    @abc.abstractmethod
    def add_packet(self, packet_type, data):
        """ add a packet to the session
        @param packet_type: packet type (see PacketType)
        @param data: input data
        @return: None
        """
        pass    # pragma: no cover

    @abc.abstractmethod
    def check_completion(self):
        """ check completion of the session tasks
        @return: True for finish tasks / False for executing tasks
        """
        pass    # pragma: no cover

    @abc.abstractmethod
    def result(self):
        """ get result of the session
        @return: result
        """
        pass    # pragma: no cover

    @abc.abstractmethod
    def register_result_cb(self, cb):
        """ register result callback function
        @param cb: callback function
        @return: None
        """
        pass    # pragma: no cover

    @abc.abstractmethod
    def unregister_result_cb(self, cb):
        """ unregister result callback function
        @param cb: callback function
        @return: None
        """
        pass    # pragma: no cover
00055 
00056 
class IState(object):
    """ State Interface class

    A session drives its state object in this order for every packet:
    next() selects the state for the packet type, run() feeds the packet
    data to it, and state() / result() are read afterwards.
    """
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def run(self, packet_data):
        """ process one packet
        @param packet_data: data of the packet
        """

    @abc.abstractmethod
    def next(self, packet_type):
        """ get the state that handles the given packet type
        @param packet_type: packet type
        @return: state object used for the packet
        """

    @abc.abstractmethod
    def result(self):
        """ get result of this state
        @return: result
        """

    @abc.abstractmethod
    def state(self):
        """ get state identifier
        @return: state identifier (compared with SessionState values)
        """
00075 
00076 
class PacketType(object):
    """ Packet Type

    Integer constants naming the kinds of packet, plus helpers to
    convert a constant to its name and to validate a value.
    """
    START = 0
    DATA = 1
    END = 2
    CANCEL = 3

    __TYPE_TO_STR = {
        START: 'START',
        DATA: 'DATA',
        END: 'END',
        CANCEL: 'CANCEL',
    }

    @classmethod
    def to_str(cls, packet_type):
        """ convert packet type to its name
        @param packet_type: packet type
        @return: name of the packet type, 'UNKNOWN TYPE' for any other value
        """
        try:
            return cls.__TYPE_TO_STR[packet_type]
        except (KeyError, TypeError):
            # TypeError: an unhashable value can never be a packet type
            return 'UNKNOWN TYPE'

    @classmethod
    def check_packet_type(cls, packet_type):
        """ validate packet type
        @param packet_type: packet type
        @raises InvalidPacketTypeException: packet_type is not a known type
        @return: None
        """
        try:
            known = packet_type in cls.__TYPE_TO_STR
        except TypeError:
            # unhashable values are reported like any other invalid type
            known = False

        if not known:
            raise ext.InvalidPacketTypeException()
00102 
00103 
class SessionState(object):
    """ Session State

    Integer constants naming the states of a session, plus helpers to
    convert a constant to its name and to validate a value.
    """
    ERROR = -1
    INIT = 0
    START = 1
    DATA = 2
    END = 3

    __STATE_TO_STR = {
        ERROR: 'ERROR',
        INIT: 'INIT',
        START: 'START',
        DATA: 'DATA',
        END: 'END',
    }

    @classmethod
    def to_str(cls, state):
        """ convert session state to its name
        @param state: session state
        @return: name of the state, 'UNKNOWN STATE' for any other value
        """
        try:
            return cls.__STATE_TO_STR[state]
        except (KeyError, TypeError):
            # TypeError: an unhashable value can never be a session state
            return 'UNKNOWN STATE'

    @classmethod
    def check_state(cls, state):
        """ validate session state
        @param state: session state
        @raises InvalidSessionStateException: state is not a known state
        @return: None
        """
        try:
            known = state in cls.__STATE_TO_STR
        except TypeError:
            # unhashable values are reported like any other invalid state
            known = False

        if not known:
            raise ext.InvalidSessionStateException()
00129 
00130 
class Session(ISession, threading.Thread):
    """ Session worker thread

    Packets queued with add_packet() are consumed one at a time by run().
    Each packet advances the current state object; when a state reports
    SessionState.END its result is stored and passed to every registered
    result callback.
    """

    def __init__(self, state):
        """ init session

        @param state: initial state object (used through its next(), run(),
                      state() and result() methods)
        """
        threading.Thread.__init__(self)

        # data settings
        self._data_que = Queue()

        # state settings
        self._state = state
        self._next_state = None
        self._next_stata_lock = threading.Lock()

        # result settings
        # defined here so that result() works before the first session ends
        self._result_text = None

        # callback settings
        self._result_cb_lock = threading.Lock()
        self._result_cb_list = []

        # thread settings
        self._stop_request = threading.Event()

    def set_next_state(self, state):
        """ set the state that replaces the current one once the running
        session ends (END) or fails (ERROR)

        @param state: next state
        @return: None
        """
        with self._next_stata_lock:
            self._next_state = state

    def _get_next_state(self):
        """ get next state

        The pending state is consumed: it is returned once and then
        cleared.

        @return: pending next state, or the current state if none is pending
        """
        with self._next_stata_lock:
            pending = self._next_state
            self._next_state = None

        # compare with None so that a falsy state object is not ignored
        if pending is None:
            return self._state
        return pending

    def add_packet(self, packet_type, data):
        """ add send packet
        @param packet_type: packet type
        @type  packet_type: PacketType (0, 1, 2, 3)
        @param data: input data
        @raises InvalidPacketTypeException:
        @return: None
        """
        # check packet type
        PacketType.check_packet_type(packet_type)

        # hand the packet over to the worker thread
        self._data_que.put_nowait([packet_type, data])

    def run(self):
        """ run thread
        @return: None
        """
        while not self._stop_request.is_set():
            # the short timeout lets the loop notice a stop request
            try:
                data_type, data = self._data_que.get(timeout=0.05)
            except Empty:
                continue

            try:
                self._process(data_type, data)
            finally:
                # always balance get(); otherwise the queue join() used by
                # join() and wait_completion() would never return
                self._data_que.task_done()

    def _process(self, data_type, data):
        """ process input data
        @param data_type: packet data type
        @param data: data
        """
        # move to the state for this packet, then let it consume the data
        self._state = self._state.next(data_type)
        self._state.run(data)

        current = self._state.state()

        if current == SessionState.ERROR:
            logger.debug('Error Occured. [%s]', self._state.result())
            self._state = self._get_next_state()

        elif current == SessionState.END:
            self._result_text = self._state.result()
            logger.info('End Session.')

            # copy the callbacks under the lock but call them outside of
            # it, so a callback may (un)register callbacks without
            # deadlocking on the non-reentrant lock
            with self._result_cb_lock:
                callbacks = list(self._result_cb_list)
            for callback in callbacks:
                callback(self._result_text)

            # renew state
            self._state = self._get_next_state()

    def join(self, timeout=None):
        """ end thread

        Queued packets are processed before the thread is asked to stop.

        @param timeout: time out time [s] for the thread join
        @return: None
        """
        # drain first: requesting the stop before the queue is empty lets
        # the worker exit with packets left, and the queue join would then
        # block forever.  Only a running worker can drain the queue.
        if self.is_alive():
            self._data_que.join()
        self._stop_request.set()
        super(Session, self).join(timeout)

    def check_completion(self):
        """ check completion tasks

        A packet counts as executing until it has been fully processed,
        not only until it has been taken from the queue.

        @return: True for finish tasks / False for executing tasks
        """
        return self._data_que.unfinished_tasks == 0

    def wait_completion(self, timeout=None):
        """ wait completion of the session tasks
        @param timeout: maximum time to wait [s], None waits without limit
        @return: None
        """
        if timeout is None:
            self._data_que.join()
            return

        # Queue.join() accepts no timeout, so wait on the condition it
        # uses internally until all packets are done or time runs out
        import time

        deadline = time.time() + timeout
        all_done = self._data_que.all_tasks_done
        with all_done:
            while self._data_que.unfinished_tasks:
                remaining = deadline - time.time()
                if remaining <= 0.0:
                    break
                all_done.wait(remaining)

    def result(self):
        """ get result text
        @return: result text of the last ended session, None if no session
                 has ended yet
        """
        return self._result_text

    def state(self):
        """ get current session state
        @return: session state
        """
        return self._state.state()

    def register_result_cb(self, cb):
        """ register result callback function

        Registering the same callback again has no effect.

        @param cb: callback function
        @return: None
        """
        with self._result_cb_lock:
            # keep a list (ordered, supports append) and skip duplicates
            if cb not in self._result_cb_list:
                self._result_cb_list.append(cb)

    def unregister_result_cb(self, cb):
        """ unregister result callback function
        @param cb: callback function
        @return: None
        """
        with self._result_cb_lock:
            # compare by equality: every access to a bound method creates
            # a new object, so an identity test would never match it
            self._result_cb_list = [
                c for c in self._result_cb_list if c != cb
            ]


rospeex_core
Author(s): Komei Sugiura
autogenerated on Thu Jun 6 2019 18:53:10