PrefetchingIter.py
Go to the documentation of this file.
00001 # --------------------------------------------------------
00002 # Deformable Convolutional Networks
00003 # Copyright (c) 2016 by Contributors
00004 # Copyright (c) 2017 Microsoft
00005 # Licensed under The Apache-2.0 License [see LICENSE for details]
00006 # Modified by Yuwen Xiong
00007 # --------------------------------------------------------
00008 
00009 
00010 import mxnet as mx
00011 from mxnet.io import DataDesc, DataBatch
00012 import threading
00013 
00014 
class PrefetchingIter(mx.io.DataIter):
    """Prefetching wrapper around a DataIter.

    Takes one or more DataIters (or any class with "reset" and "next"
    methods) and fetches the next batch on a background thread while the
    current batch is being consumed. This implementation currently supports
    exactly one wrapped iterator (checked in ``__init__``).

    Exceptions other than ``StopIteration`` raised by the wrapped iterator
    while prefetching are re-raised from ``iter_next()`` / ``next()`` in the
    consumer thread.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        one or more DataIters (or any class with "reset" and "next" methods)
    rename_data : None or list of dict
        i-th element is a renaming map for i-th iter, in the form of
        {'original_name' : 'new_name'}. Should have one entry for each entry
        in iter[i].provide_data
    rename_label : None or list of dict
        Similar to rename_data

    Examples
    --------
    iter = PrefetchingIter([NDArrayIter({'data': X1})],
                           rename_data=[{'data': 'data1'}])
    """

    def __init__(self, iters, rename_data=None, rename_label=None):
        super(PrefetchingIter, self).__init__()
        if not isinstance(iters, list):
            iters = [iters]
        self.n_iter = len(iters)
        assert self.n_iter == 1, "Our prefetching iter only support 1 DataIter"
        self.iters = iters
        self.rename_data = rename_data
        self.rename_label = rename_label
        self.batch_size = self._infer_batch_size(self.provide_data)
        # Handshake events, one pair per wrapped iterator:
        #   data_ready[i] -- set by the worker once next_batch[i] is filled;
        #   data_taken[i] -- set by the consumer once next_batch[i] may be
        #                    overwritten (initially set so the worker starts
        #                    fetching immediately).
        self.data_ready = [threading.Event() for _ in range(self.n_iter)]
        self.data_taken = [threading.Event() for _ in range(self.n_iter)]
        for e in self.data_taken:
            e.set()
        self.started = True  # cleared by __del__ to stop the workers
        self.current_batch = [None for _ in range(self.n_iter)]
        self.next_batch = [None for _ in range(self.n_iter)]
        # Exception raised by the i-th worker's most recent fetch, or None.
        self.next_error = [None for _ in range(self.n_iter)]

        def prefetch_func(self, i):
            """Thread entry: refill next_batch[i] each time it is taken."""
            # Checking ``started`` before every wait closes the shutdown race
            # where __del__ sets data_taken while a fetch is in flight and the
            # clear() below would otherwise hide that wake-up.
            while self.started:
                self.data_taken[i].wait()
                if not self.started:
                    break
                try:
                    self.next_batch[i] = self.iters[i].next()
                    self.next_error[i] = None
                except StopIteration:
                    self.next_batch[i] = None
                    self.next_error[i] = None
                except Exception as exc:
                    # Hand the failure to the consumer instead of letting this
                    # thread die, which would leave iter_next() blocked on
                    # data_ready forever.
                    self.next_batch[i] = None
                    self.next_error[i] = exc
                self.data_taken[i].clear()
                self.data_ready[i].set()

        self.prefetch_threads = [
            threading.Thread(target=prefetch_func, args=[self, i])
            for i in range(self.n_iter)
        ]
        for thread in self.prefetch_threads:
            thread.daemon = True  # Thread.setDaemon() is deprecated (3.10+)
            thread.start()

    @staticmethod
    def _infer_batch_size(provide_data):
        """Return the total batch size described by ``provide_data``.

        Two layouts are handled:

        * per-device nested lists ``[[(name, shape), ...], ...]``: the batch
          size is ``len(provide_data)`` times the leading dimension of the
          first entry of the first group;
        * a flat list ``[(name, shape), ...]`` (e.g. ``DataDesc`` entries):
          the batch size is the leading dimension of the first entry.

        Assumes the batch axis is axis 0 of each shape.
        """
        first = provide_data[0]
        if isinstance(first[0], (list, tuple)):
            # Nested: first[0] is the first (name, shape) of device 0.
            return len(provide_data) * first[0][1][0]
        # Flat: first[0] is the name, first[1] the shape.
        return first[1][0]

    @staticmethod
    def _rename_descs(descs, rename):
        """Return ``descs`` as DataDesc objects renamed through ``rename``.

        Plain ``(name, shape)`` tuples are converted to DataDesc first, so the
        rename map is applied to them as well as to DataDesc entries. Expects
        a flat list of descriptors; raises KeyError if a name is missing from
        ``rename``.
        """
        renamed = []
        for x in descs:
            desc = x if isinstance(x, DataDesc) else DataDesc(*x)
            renamed.append(DataDesc(rename[desc.name], desc.shape, desc.dtype))
        return renamed

    def __del__(self):
        """Stop the prefetch threads.

        NOTE(review): each worker thread holds ``self`` in its args, so this
        finalizer may not run until interpreter shutdown.
        """
        self.started = False
        # getattr fallbacks: __init__ may have failed before these existed.
        for e in getattr(self, 'data_taken', []):
            e.set()  # wake any worker waiting so it observes started=False
        for thread in getattr(self, 'prefetch_threads', []):
            thread.join()

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        if self.rename_data is None:
            return sum([i.provide_data for i in self.iters], [])
        return sum([self._rename_descs(i.provide_data, r)
                    for r, i in zip(self.rename_data, self.iters)], [])

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        if self.rename_label is None:
            return sum([i.provide_label for i in self.iters], [])
        return sum([self._rename_descs(i.provide_label, r)
                    for r, i in zip(self.rename_label, self.iters)], [])

    def reset(self):
        """Reset the wrapped iterators and restart prefetching.

        Waits for the in-flight prefetch to finish, so the worker is idle
        while the wrapped iterators are reset. That prefetched batch (or
        pending error) is then discarded, and the worker is woken to fetch
        afresh.
        """
        for e in self.data_ready:
            e.wait()
        for i in self.iters:
            i.reset()
        for e in self.data_ready:
            e.clear()
        for e in self.data_taken:
            e.set()

    def iter_next(self):
        """Advance to the prefetched batch.

        Blocks until the worker has finished its current fetch. Returns True
        and stores the batch in ``current_batch`` when one is available, or
        False once the wrapped iterator is exhausted.

        Raises
        ------
        Exception
            Whatever the wrapped iterator raised (other than StopIteration)
            while prefetching. It is reported once; later calls return False
            until ``reset()`` is called.
        """
        for e in self.data_ready:
            e.wait()
        for i, err in enumerate(self.next_error):
            if err is not None:
                self.next_error[i] = None
                raise err
        if self.next_batch[0] is None:
            return False
        self.current_batch = self.next_batch[0]
        # Hand the slot back to the worker so it prefetches the next batch.
        for e in self.data_ready:
            e.clear()
        for e in self.data_taken:
            e.set()
        return True

    def next(self):
        """Return the next batch, or raise StopIteration when exhausted."""
        if self.iter_next():
            return self.current_batch
        else:
            raise StopIteration

    def getdata(self):
        """Return the data of the current batch."""
        return self.current_batch.data

    def getlabel(self):
        """Return the label of the current batch."""
        return self.current_batch.label

    def getindex(self):
        """Return the index of the current batch."""
        return self.current_batch.index

    def getpad(self):
        """Return the padding count of the current batch."""
        return self.current_batch.pad


rail_object_detector
Author(s):
autogenerated on Sat Jun 8 2019 20:26:30