Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 import mxnet as mx
00011 from mxnet.io import DataDesc, DataBatch
00012 import threading
00013
00014
class PrefetchingIter(mx.io.DataIter):
    """Wrap a DataIter and prefetch its next batch on a background thread.

    Takes a DataIter (or any object with "reset" and "next" methods) and reads
    the following batch from it on a daemon thread while the caller consumes
    the current one. ``iters`` may be passed as a list for API compatibility,
    but this implementation supports exactly one wrapped iterator.

    NOTE(review): ``batch_size`` is computed as
    ``len(provide_data) * provide_data[0][0][1][0]``, which assumes the wrapped
    iterator's ``provide_data`` is nested one level deeper than a flat list of
    (name, shape) entries (presumably one list per device). Confirm against the
    wrapped iterator.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        The iterator to wrap, or a one-element list containing it. Any object
        with "reset" and "next" methods is accepted.
    rename_data : None or list of dict
        i-th element is a renaming map for the i-th iter, in the form of
        {'original_name' : 'new_name'}. Should have one entry for each entry
        in iters[i].provide_data.
    rename_label : None or list of dict
        Similar to rename_data, applied to provide_label.

    Raises
    ------
    AssertionError
        If ``iters`` does not contain exactly one iterator.

    Examples
    --------
    prefetched = PrefetchingIter(train_iter)
    for batch in prefetched:
        ...
    """

    def __init__(self, iters, rename_data=None, rename_label=None):
        super(PrefetchingIter, self).__init__()
        if not isinstance(iters, list):
            iters = [iters]
        self.n_iter = len(iters)
        assert self.n_iter == 1, "Our prefetching iter only support 1 DataIter"
        self.iters = iters
        self.rename_data = rename_data
        self.rename_label = rename_label
        # See the class docstring: assumes provide_data[k] is itself a list of
        # (name, shape) entries -- TODO confirm.
        self.batch_size = len(self.provide_data) * self.provide_data[0][0][1][0]
        # Per-iterator handshake: the prefetch thread sets data_ready[i] once
        # next_batch[i] holds a result; the consumer sets data_taken[i] to ask
        # for the next one. data_taken starts set so the first batch is fetched
        # right away.
        self.data_ready = [threading.Event() for _ in range(self.n_iter)]
        self.data_taken = [threading.Event() for _ in range(self.n_iter)]
        for e in self.data_taken:
            e.set()
        self.started = True
        self.current_batch = [None for _ in range(self.n_iter)]
        self.next_batch = [None for _ in range(self.n_iter)]
        # Exception (other than StopIteration) raised by iters[i].next() inside
        # the prefetch thread, pending re-raise on the consumer side.
        self.prefetch_errors = [None for _ in range(self.n_iter)]
        self.prefetch_threads = [
            threading.Thread(target=self._prefetch_func, args=(i,))
            for i in range(self.n_iter)
        ]
        for thread in self.prefetch_threads:
            # Attribute form; Thread.setDaemon() is deprecated since Python 3.10.
            thread.daemon = True
            thread.start()

    def _prefetch_func(self, i):
        """Thread body: refill next_batch[i] each time the consumer takes a batch.

        Stores None in next_batch[i] on StopIteration. Any other exception is
        stored in prefetch_errors[i] (with next_batch[i] set to None) so that
        the consumer re-raises it instead of waiting on data_ready[i] forever.
        Exits when ``started`` is False after data_taken[i] is signalled.
        """
        while True:
            self.data_taken[i].wait()
            if not self.started:
                break
            try:
                self.next_batch[i] = self.iters[i].next()
            except StopIteration:
                self.next_batch[i] = None
            except Exception as exc:  # pylint: disable=broad-except
                # Hand the failure to the consumer; dying here would leave
                # iter_next()/reset() blocked on data_ready[i].
                self.next_batch[i] = None
                self.prefetch_errors[i] = exc
            self.data_taken[i].clear()
            self.data_ready[i].set()

    def _raise_prefetch_error(self):
        """Re-raise (once) the first exception captured by a prefetch thread."""
        for i, exc in enumerate(self.prefetch_errors):
            if exc is not None:
                self.prefetch_errors[i] = None
                raise exc

    def __del__(self):
        """Signal the prefetch threads to stop and join them.

        Tolerates a partially initialised object (e.g. when the assertion in
        __init__ failed before the events/threads were created).

        NOTE(review): each thread's target is a bound method of this object,
        so the object stays referenced while its threads run; the garbage
        collector is therefore unlikely to invoke this by itself.
        """
        self.started = False
        for e in getattr(self, 'data_taken', []):
            e.set()
        for thread in getattr(self, 'prefetch_threads', []):
            thread.join()

    def _renamed_descs(self, attr, rename):
        """Concatenate ``getattr(it, attr)`` over the wrapped iterators.

        When ``rename`` is given (one dict per iterator), DataDesc entries get
        their name mapped through that dict (KeyError if absent), keeping
        shape, dtype and layout. Non-DataDesc entries are converted with
        ``DataDesc(*x)`` and keep their original name.
        """
        if rename is None:
            return sum([getattr(it, attr) for it in self.iters], [])
        return sum([[
            DataDesc(r[x.name], x.shape, x.dtype, x.layout)
            if isinstance(x, DataDesc) else DataDesc(*x)
            for x in getattr(it, attr)
        ] for r, it in zip(rename, self.iters)], [])

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        return self._renamed_descs('provide_data', self.rename_data)

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return self._renamed_descs('provide_label', self.rename_label)

    def reset(self):
        """Reset the wrapped iterators and start prefetching the first batch.

        Waits for any in-flight prefetch to finish first, so the wrapped
        iterator is never reset while the thread is reading from it. A pending
        prefetch error from the previous pass is discarded.
        """
        for e in self.data_ready:
            e.wait()
        for it in self.iters:
            it.reset()
        self.prefetch_errors[:] = [None for _ in range(self.n_iter)]
        for e in self.data_ready:
            e.clear()
        for e in self.data_taken:
            e.set()

    def iter_next(self):
        """Move to the prefetched batch.

        Returns
        -------
        bool
            True if a batch was available (it is stored in ``current_batch``
            and the next prefetch is requested), False once the wrapped
            iterator raised StopIteration.

        Raises
        ------
        Exception
            Whatever the wrapped iterator's ``next`` raised in the prefetch
            thread, other than StopIteration.
        """
        for e in self.data_ready:
            e.wait()
        self._raise_prefetch_error()
        if self.next_batch[0] is None:
            return False
        self.current_batch = self.next_batch[0]
        for e in self.data_ready:
            e.clear()
        for e in self.data_taken:
            e.set()
        return True

    def next(self):
        """Return the next prefetched batch; raise StopIteration at the end."""
        if self.iter_next():
            return self.current_batch
        raise StopIteration

    def getdata(self):
        """Data of the current batch."""
        return self.current_batch.data

    def getlabel(self):
        """Label of the current batch."""
        return self.current_batch.label

    def getindex(self):
        """Index of the current batch."""
        return self.current_batch.index

    def getpad(self):
        """Padding count of the current batch."""
        return self.current_batch.pad