00001 import sys
00002 import os
00003
00004 import pylab
00005 import numpy as np
00006 from load_ndesc import load_ndesc_pc
00007 from scipy import cluster
00008 import random
00009 from config import select_config
00010 import cPickle
00011 from yael import yael
00012 from shogun.Classifier import LibLinear, LibSVM
00013 from shogun.Features import RealFeatures, Labels
00014 from shogun.Kernel import Chi2Kernel
00015 import cv2
00016
def load_descs(config, i, norm=1, old_desc=False, invert_mask=False, use_masks2=0):
    """Load the descriptors of image `i` and filter them with the image masks.

    Parameters:
        config: object providing desc_filename(i), mask_filename(i) and
            mask2_filename(i).
        i: image identifier handed to the config filename helpers.
        norm, old_desc: forwarded unchanged to load_ndesc_pc.
        invert_mask: only honoured when use_masks2 == 0; when truthy, the
            descriptors lying on zero pixels of the primary mask are kept
            instead of those on non-zero pixels.
        use_masks2: filtering mode
            0 -- keep descriptors on non-zero pixels of the primary mask
                 (or on zero pixels when invert_mask is set),
            1 -- keep descriptors on non-zero pixels of the secondary mask,
            2 -- keep descriptors on non-zero pixels of the primary mask that
                 are also on zero pixels of the secondary mask,
            any other value -- no filtering.

    Returns:
        The list of descriptor objects that passed the filter. Each
        descriptor is looked up in the masks as mask[desc.v, desc.u].
    """
    descs = load_ndesc_pc(config.desc_filename(i), norm=norm, old_desc=old_desc)

    if use_masks2 == 0:
        mask = pylab.imread(config.mask_filename(i))
        keep_when_set = not invert_mask
        return [d for d in descs if bool(mask[d.v, d.u]) == keep_when_set]

    if use_masks2 == 1:
        # Only the secondary mask matters here, so the primary mask is not
        # read at all.
        mask2 = pylab.imread(config.mask2_filename(i))
        return [d for d in descs if bool(mask2[d.v, d.u])]

    if use_masks2 == 2:
        mask = pylab.imread(config.mask_filename(i))
        mask2 = pylab.imread(config.mask2_filename(i))
        return [d for d in descs
                if not bool(mask2[d.v, d.u]) and bool(mask[d.v, d.u])]

    # Unknown mode: descriptors are returned unfiltered.
    return descs
00031
def sel_descs(config, imnos, ratio_keep, norm=1, old_desc=False, invert_mask=False, use_masks2=0):
    """Randomly sample descriptors from a set of images.

    Every descriptor returned by load_descs for each image in `imnos` is kept
    when a fresh np.random.rand() draw is <= `ratio_keep`. The remaining
    keyword arguments are forwarded to load_descs.

    Returns:
        A pair (features, kept): `kept` is the list of selected descriptor
        objects and `features` is np.array of their `.desc` attributes, in
        the same order.
    """
    kept = []

    for i in imnos:
        candidates = load_descs(config, i, norm=norm, old_desc=old_desc,
                                invert_mask=invert_mask, use_masks2=use_masks2)
        for d in candidates:
            # One random draw per descriptor.
            if np.random.rand() <= ratio_keep:
                kept.append(d)
        # Progress report for every image number that is a multiple of ten.
        if i % 10 == 0:
            print("%s %s" % (i, len(kept)))

    print("selected %s descriptors" % len(kept))
    features = np.array([d.desc for d in kept])
    return features, kept
00041
def norm_descs(descs, l=1):
    """Normalise a descriptor matrix and take its element-wise square root.

    The norm of order `l` is computed along axis 0 (one norm per column),
    every column is divided by its norm, entries that become NaN (0/0 for
    all-zero columns) are reset to 0, and the square root is applied.

    The input is never modified: the computation runs on a float64 copy, so
    integer arrays are handled correctly as well.

    NOTE(review): callers pass one descriptor per row, so axis 0 normalises
    each feature dimension across the batch rather than each descriptor --
    confirm this is the intended normalisation.

    Parameters:
        descs: 2-D array-like of descriptors.
        l: order of the norm, passed as `ord` to np.linalg.norm.

    Returns:
        A new float64 array with the same shape as `descs`.
    """
    # Float copy: keeps the caller's array intact and avoids in-place
    # division on an integer dtype.
    descs = np.array(descs, dtype='float64')
    print(descs.shape)
    norms = np.apply_along_axis(np.linalg.norm, 0, descs, l)
    # Number of columns whose norm is zero.
    print(sum(norms == 0))
    # 0/0 is expected for all-zero columns; the NaNs are cleaned up below.
    with np.errstate(divide='ignore', invalid='ignore'):
        descs /= norms
    descs[np.isnan(descs)] = 0
    return np.sqrt(descs)
00050
def visualize_classes(svm, config, imnos):
    """Display the classifier scores of every descriptor of the given images.

    For each image number in `imnos`, all of its descriptors are selected
    (keep ratio 1), normalised with norm_descs and passed to `svm.apply`;
    the resulting labels are shown on top of the image through
    convert_to_heatmap.

    Parameters:
        svm: trained classifier exposing apply(features).get_labels().
        config: object providing img_filename(n) and the filename helpers
            needed by sel_descs.
        imnos: iterable of image numbers to display.
    """
    for n in imnos:
        descs, origdescs = sel_descs(config, [n], 1)
        if not origdescs:
            # Nothing to classify or draw for this image.
            continue

        # Same float64 cast as the training path applies before norm_descs.
        descs = norm_descs(descs.astype('float64'))
        print(descs.shape)

        # Features are handed over transposed (one descriptor per column).
        scores = svm.apply(RealFeatures(descs.T)).get_labels()
        print(scores)

        ima = pylab.imread(config.img_filename(n))
        convert_to_heatmap(ima, origdescs, scores)
00067
def convert_to_heatmap(ima, descs, sco):
    """Show descriptor scores as a 1/6-resolution heat map plus the best hit.

    The first figure shows a grid (image size divided by 6) filled with the
    score of the descriptor falling in each cell; cells that lie on a Canny
    edge of the down-scaled image are drawn in black, and cells without a
    descriptor get the minimum score. A second figure shows `ima` with a
    yellow star on the position of the highest-scoring descriptor.

    Parameters:
        ima: RGB image array. It is multiplied by 255 before the conversion
            to uint8, which presumably expects float values in [0, 1] --
            TODO confirm against the image loader.
        descs: sequence of descriptors exposing pixel coordinates .u
            (column) and .v (row).
        sco: sequence of scores, aligned with `descs`.
    """
    import copy  # local import: only needed to copy the colormap below

    if len(sco) == 0 or len(descs) == 0:
        # No score to display.
        return

    rows, cols = ima.shape[0] // 6, ima.shape[1] // 6

    gray = cv2.cvtColor(ima, cv2.COLOR_RGB2GRAY)
    gray = (255 * gray).astype('uint8')
    small = cv2.resize(src=gray, dsize=(cols, rows), interpolation=cv2.INTER_LINEAR)
    edges = cv2.Canny(small, 50, 60)

    maxp = None
    maxpsco = None
    scores = min(sco) * np.ones((rows, cols))
    for s, desc in zip(sco, descs):
        # Clamp to the last cell: when the image size is not a multiple of 6
        # the border pixels would otherwise index one past the grid.
        u = min(int(desc.u // 6), cols - 1)
        v = min(int(desc.v // 6), rows - 1)
        if maxp is None or maxpsco < s:
            maxp = (desc.u, desc.v)
            maxpsco = s
        if edges[v, u] != 0:
            # NaN cells are rendered with the colormap's "bad" colour.
            scores[v, u] = np.nan
        else:
            scores[v, u] = s

    # Work on a copy so the shared pylab.cm.jet colormap is left untouched.
    cmap = copy.copy(pylab.cm.jet)
    cmap.set_bad('black', 1.)
    masked_array = np.ma.masked_where(np.isnan(scores), scores)

    pylab.imshow(masked_array, cmap=cmap)
    pylab.colorbar()
    pylab.figure()
    pylab.imshow(ima)
    print("%s %s WTRGAR" % (maxp[0], maxp[1]))
    pylab.hold(True)
    pylab.plot(maxp[0], maxp[1], 'y*', markersize=15)
    pylab.show()
00102
00103
def visualize_descs(ima, descs):
    """Show `ima` and mark the (u, v) position of each descriptor in red."""
    pylab.imshow(ima)
    pylab.hold(1)
    for d in descs:
        col, row = d.u, d.v
        pylab.plot(col, row, 'r*')
    pylab.show()
00112
00113
00114
00115
def visualize_box_in_image(desc, nspa, pixw):
    """Draw the support window of a descriptor on the current axes.

    A red square of half-width `pixw` is drawn around (desc.u, desc.v), it is
    split by `nspa - 1` green lines in each direction, and the centre is
    marked with a red star.
    """
    u, v = desc.u, desc.v
    left, right = u - pixw, u + pixw
    top, bottom = v - pixw, v + pixw

    # Each column of the two arrays is one side of the square:
    # top, left, bottom, right.
    box_x = np.array([[left, left, left, right],
                      [right, left, right, right]])
    box_y = np.array([[top, top, bottom, top],
                      [top, bottom, bottom, bottom]])
    pylab.plot(box_x, box_y, 'r')

    for ix in range(1, nspa):
        offset = ix * (2 * pixw + 1) / float(nspa)
        # Horizontal grid line.
        pylab.plot(np.array([left, right]), np.array([top + offset, top + offset]), 'g')
        # Vertical grid line.
        pylab.plot(np.array([left + offset, left + offset]), np.array([top, bottom]), 'g')

    pylab.plot(u, v, '*r')
00124
00125
00126
def visualize_desc(descs, npdescs, ima, nori, nspa,pixw=15):
    """Interactively browse descriptors in random order.

    For each descriptor, figure 1 shows `ima` with the descriptor's support
    window and figure 2 shows the descriptor itself; the function then waits
    for console input and stops when the user enters 'q'.

    NOTE(review): visualize_one_desc is not defined or imported in the
    visible part of this file -- confirm it exists elsewhere.

    Parameters:
        descs: descriptor objects exposing .u and .v.
        npdescs: per-descriptor data passed to visualize_one_desc, indexed
            like `descs`.
        ima: image displayed in figure 1.
        nori, nspa: forwarded to visualize_one_desc; `nspa` is also the
            number of grid divisions drawn by visualize_box_in_image.
        pixw: half-width of the drawn window.
    """
    pylab.ion()
    order = list(range(len(descs)))
    random.shuffle(order)
    for idx in order:
        current = descs[idx]
        print("%s %s %s" % (current.u, current.v, idx))

        # Figure 1: the image with the descriptor window.
        pylab.figure(1)
        pylab.cla()
        pylab.imshow(ima)
        pylab.hold('on')
        visualize_box_in_image(current, nspa, pixw)

        # Figure 2: the descriptor content.
        pylab.figure(2)
        visualize_one_desc(npdescs[idx], nori, nspa)

        pylab.show()

        answer = raw_input(">")
        if answer == 'q':
            break
00165
00166
00167
def run_on_images(config,centroids,sel_i,imnos,select_mode=None, nspa=4,nori=4,pixw=15):
    """Show, image by image, the descriptors assigned to cluster `sel_i`.

    The descriptors of each image are whitened with robust_whiten and
    quantised against `centroids[0]` with scipy's vq. After each image the
    function waits for console input and stops when the user enters 'q'.

    NOTE(review): robust_whiten is not defined or imported in the visible
    part of this file -- confirm it exists elsewhere.

    Parameters:
        config: object providing img_filename(i) and the helpers needed by
            load_descs.
        centroids: sequence whose first element is the codebook given to vq.
        sel_i: index of the cluster to display.
        imnos: iterable of image numbers.
        select_mode: None plots every descriptor assigned to `sel_i`;
            'closest' draws the window of the assigned descriptor with the
            smallest distance to the centroid; 'highest' draws nothing.
        nspa, pixw: forwarded to visualize_box_in_image.
        nori: unused here.
    """
    marker = '*r'
    for i in imnos:
        pylab.figure(1)
        pylab.cla()
        ima = pylab.imread(config.img_filename(i))

        descs = load_descs(config, i, old_desc=True)
        descs_feats = np.array([x.desc for x in descs])

        # codes[j]: cluster index of descriptor j; dists[j]: its distance to
        # that centroid.
        codes, dists = cluster.vq.vq(robust_whiten(descs_feats), centroids[0])
        pylab.imshow(ima)
        pylab.hold('on')
        if select_mode is None:
            # Separate loop names: the image index `i` is not overwritten.
            for desc, code in zip(descs, codes):
                if code == sel_i:
                    pylab.plot(desc.u, desc.v, marker)
        elif select_mode == 'closest':
            sel = (codes == sel_i).nonzero()
            members = sel[0]
            if len(members) != 0:
                print("se %s %s" % (sel, dists[members]))
                best = np.argmin(dists[members])
                print(best)
                visualize_box_in_image(descs[members[best]], nspa, pixw)
            else:
                print("NONE FOUND")
        elif select_mode == 'highest':
            pass
        pylab.show()
        answer = raw_input('t')
        if answer == 'q':
            break
00205
00206
if __name__=="__main__":
    # Script entry point: parse the command line, sample positive and
    # negative descriptors, train a linear SVM on them, display the score
    # maps and finally pickle the classifier to 'last_classifier.pkl'.
    #
    # NOTE(review): usage() and parse_nt() are called below but are neither
    # defined nor imported in the visible part of this file; '-h', '-nt' and
    # unknown arguments raise NameError unless they are provided elsewhere.
    args=sys.argv[1:]

    # Defaults. Several of these are only parsed and not used further down.
    todo=[]
    dbname="none"
    im_begin=0
    im_end=-1
    n_thread=1
    imnos=None
    verb=False
    max_desc=200*1000
    k=25
    ratio_keep=0.2
    outfile="tmp.txt"
    infile=None
    hand_pick=False
    save_center=False
    use_masks2=1
    while args:
        a=args.pop(0)
        if a in ['-h','--help']: usage()
        elif a=='-v': verb=True
        elif a=='-db': dbname=args.pop(0)
        elif a=='-begin': im_begin=int(args.pop(0))
        elif a=='-end': im_end=int(args.pop(0))
        elif a=='-imnos': imnos=[int(x) for x in args.pop(0).split(',')]
        elif a=='-nt': n_thread=parse_nt(args.pop(0))
        elif a=='-max_desc': max_desc=int(args.pop(0))
        elif a=='-k': k=int(args.pop(0))
        elif a=='-keep': ratio_keep=float(args.pop(0))
        elif a=='-outfile': outfile=args.pop(0)
        elif a=='-infile': infile=args.pop(0)
        elif a=='-hand_pick': hand_pick=True
        elif a=='-save': save_center=True
        else:
            sys.stderr.write("unknown arg %s\n"%a)
            usage()
            sys.exit(-1)

    config=select_config(dbname)
    keepvecs=[]

    # Resolve the image range: -1 means "up to the last image".
    if im_end==-1: im_end=config.nimg
    if im_end>config.nimg:
        sys.stderr.write("warn: forcing -end to %i\n"%config.nimg)
        im_end=config.nimg
    if imnos is None:
        imnos=range(im_begin, im_end)

    # Positives come from the secondary mask (use_masks2=1); negatives from
    # the primary mask outside the secondary one (use_masks2=2).
    descs, origdescs = sel_descs(config, imnos, ratio_keep, use_masks2=use_masks2)
    descs_neg, origdescs_neg = sel_descs(config, imnos, ratio_keep, use_masks2=2)

    # Balance the classes: keep at most as many random negatives as positives.
    npos = descs.shape[0]
    sel = list(range(descs_neg.shape[0]))
    random.shuffle(sel)
    sel = sel[:npos]

    feats = np.vstack((descs.astype('float64'), descs_neg[sel].astype('float64')))
    feats = norm_descs(feats)
    feats = RealFeatures(feats.T)
    labels = Labels(np.hstack((np.ones(npos), -1*np.ones(len(sel)))))

    svm = LibLinear(1, feats, labels)
    svm.train()

    # NOTE(review): the range starts at max(imnos), i.e. it includes the last
    # training image -- confirm whether max(imnos)+1 was intended.
    visualize_classes(svm, config, range(max(imnos),config.nimg))

    print("Writting SVM")
    # Binary mode plus a context manager: the file is closed even if the
    # dump fails.
    with open('last_classifier.pkl','wb') as pf:
        cPickle.dump(svm, pf)
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305