learn_codebook.py
Go to the documentation of this file.
00001 import sys
00002 import os
00003 
00004 import pylab
00005 import numpy as np
00006 from load_ndesc import load_ndesc_pc
00007 from scipy import cluster
00008 import random
00009 from config import select_config
00010 import cPickle
00011 from yael import yael
00012 from shogun.Classifier import LibLinear, LibSVM
00013 from shogun.Features import RealFeatures, Labels
00014 from shogun.Kernel import Chi2Kernel
00015 import cv2
00016 
def load_descs(config, i, norm=1, old_desc=False, invert_mask=False, use_masks2=0):
    """Load the descriptors of image *i* and keep those selected by its mask(s).

    Descriptors come from load_ndesc_pc(config.desc_filename(i), ...) and are
    filtered by looking up each descriptor's (v, u) position in mask images
    read with pylab.imread.

    use_masks2 selects the filter:
      0 -- keep descriptors where the primary mask (config.mask_filename(i))
           is set, or where it is NOT set when invert_mask is true;
      1 -- keep descriptors where mask2 (config.mask2_filename(i)) is set;
      2 -- keep descriptors inside the primary mask but outside mask2
           (invert_mask is ignored).

    norm, old_desc -- forwarded unchanged to load_ndesc_pc.

    Returns the list of kept descriptor objects.
    Raises ValueError for any other use_masks2 value.
    """
    if use_masks2 not in (0, 1, 2):
        raise ValueError("unknown use_masks2 mode: %r" % (use_masks2,))

    descs = load_ndesc_pc(config.desc_filename(i), norm=norm, old_desc=old_desc)

    def inside(img, d):
        # NOTE(review): assumes d.v / d.u are integer row / column pixel
        # coordinates within img and that img is single-channel - confirm.
        return bool(img[d.v, d.u])

    if use_masks2 == 1:
        # Only mask2 is needed here; the primary mask is not read at all.
        mask2 = pylab.imread(config.mask2_filename(i))
        return [d for d in descs if inside(mask2, d)]

    mask = pylab.imread(config.mask_filename(i))
    if use_masks2 == 0:
        return [d for d in descs if inside(mask, d) != bool(invert_mask)]

    # use_masks2 == 2: inside the primary mask, outside mask2.
    mask2 = pylab.imread(config.mask2_filename(i))
    return [d for d in descs if inside(mask, d) and not inside(mask2, d)]
00031     
def sel_descs(config, imnos, ratio_keep, norm=1, old_desc=False, invert_mask=False, use_masks2=0):
    """Randomly subsample the masked descriptors of several images.

    For every image number in *imnos*, descriptors are loaded with
    load_descs (norm, old_desc, invert_mask and use_masks2 are forwarded)
    and each one is kept independently with probability *ratio_keep*.
    Progress is printed every 10th image number, and the final count is
    printed at the end.

    Returns a tuple (vectors, kept), where vectors is np.array of the kept
    descriptors' .desc attributes and kept is the list of kept descriptor
    objects, in the same order.
    """
    kept = []
    for i in imnos:
        descs = load_descs(config, i, norm=norm, old_desc=old_desc,
                           invert_mask=invert_mask, use_masks2=use_masks2)
        # rand() lies in [0, 1); a strict '<' keeps each descriptor with
        # probability exactly ratio_keep (so ratio_keep == 0 keeps nothing).
        kept.extend(d for d in descs if np.random.rand() < ratio_keep)
        if i % 10 == 0:
            print("%s %s" % (i, len(kept)))
    print("selected %s descriptors" % len(kept))
    return np.array([d.desc for d in kept]), kept
00041 
def norm_descs(descs, l=1):
    """Normalise each descriptor (row) to unit l-norm, then take the square root.

    Every row is normalised independently, so a descriptor's result does not
    depend on which other descriptors share the batch.  This matters because
    the training set and the per-image test sets passed here differ in size.
    With the default l=1 this is the Hellinger ("RootSIFT") mapping: for
    non-negative inputs, dot products of the outputs equal the Hellinger
    kernel of the inputs.

    descs -- 2-D array, one descriptor per row; a 1-D array is treated as a
             single descriptor and a 1-D array is returned.
    l     -- norm order, forwarded to np.linalg.norm.

    Returns a new float64 array of the same shape; *descs* is not modified.
    Rows whose norm is 0 become all-zero rows.
    NOTE(review): assumes non-negative (histogram-like) descriptors; negative
    entries would produce NaN from the square root - confirm.
    Prints the array shape and the number of zero-norm rows as diagnostics.
    """
    descs = np.asarray(descs, dtype='float64')
    single = descs.ndim == 1
    if single:
        descs = descs[np.newaxis, :]
    print(descs.shape)
    if descs.shape[0] == 0:
        # apply_along_axis cannot iterate over zero rows; nothing to normalise.
        return np.sqrt(descs)

    norms = np.apply_along_axis(np.linalg.norm, 1, descs, l)[:, np.newaxis]
    print(int(np.sum(norms == 0)))
    with np.errstate(divide='ignore', invalid='ignore'):
        normed = descs / norms
    # A zero-norm row is all zeros, so 0/0 gives NaN there; map it back to 0.
    normed[np.isnan(normed)] = 0
    normed = np.sqrt(normed)
    return normed[0] if single else normed
00050 
def visualize_classes(svm, config, imnos):
    """Classify every masked descriptor of each image and show a score heatmap.

    For each image number in *imnos*, all descriptors passing load_descs'
    default mask filter are loaded (ratio_keep=1 keeps every one), normalised
    with norm_descs, scored with svm.apply, and rendered by
    convert_to_heatmap over the image from config.img_filename.

    svm    -- trained shogun classifier exposing apply(RealFeatures).
    config -- dataset configuration (filename helpers).
    imnos  -- iterable of image numbers to visualise.
    Images without any descriptor are reported and skipped.
    """
    for n in imnos:
        descs, origdescs = sel_descs(config, [n], 1)
        if not origdescs:
            print("no descriptors in image %s" % n)
            continue
        descs = norm_descs(descs)
        print(descs.shape)
        # shogun expects one feature vector per column, hence the transpose.
        labels = svm.apply(RealFeatures(descs.T)).get_labels()
        print(labels)
        ima = pylab.imread(config.img_filename(n))
        convert_to_heatmap(ima, origdescs, labels)
00067 
def convert_to_heatmap(ima, descs, sco, scale=6):
    """Display per-descriptor scores as a heatmap and mark the best-scoring point.

    The heatmap grid is the image downsampled by *scale* in each direction.
    Each descriptor's score is written into the cell containing its (u, v)
    pixel position; when several descriptors share a cell, the last one
    wins.  A cell that receives a descriptor but lies on a Canny edge of the
    downsampled grayscale image is masked and drawn black.  Cells without
    any descriptor keep the minimum score.  A second figure shows the
    full-resolution image with a yellow star at the highest-scoring
    descriptor.

    ima   -- image as returned by pylab.imread.
             NOTE(review): assumes a 3-channel RGB float image in [0, 1]
             (PNG via pylab.imread); uint8 or RGBA input would break the
             grayscale conversion below - confirm.
    descs -- descriptor objects with .u (column) / .v (row) pixel coordinates.
    sco   -- one score per descriptor, in the same order as descs.
    scale -- downsampling factor of the heatmap grid (default 6).

    Raises ValueError if sco is empty.
    """
    import copy  # local: used only to copy the shared colormap below

    sco = list(sco)
    if not sco:
        raise ValueError("convert_to_heatmap: no scores to display")

    small_h, small_w = ima.shape[0] // scale, ima.shape[1] // scale

    gray = cv2.cvtColor(ima, cv2.COLOR_RGB2GRAY)
    gray = (255 * gray).astype('uint8')
    # cv2.resize takes dsize as (width, height).
    gray_small = cv2.resize(src=gray, dsize=(small_w, small_h),
                            interpolation=cv2.INTER_LINEAR)
    edges = cv2.Canny(gray_small, 50, 60)

    scores = min(sco) * np.ones((small_h, small_w))
    maxp = None
    maxpsco = None
    for s, desc in zip(sco, descs):
        # Clamp: pixels in a trailing partial cell (size not a multiple of
        # scale) would otherwise index one past the end of the grid.
        u = min(int(desc.u) // scale, small_w - 1)
        v = min(int(desc.v) // scale, small_h - 1)
        if maxp is None or maxpsco < s:
            maxp = (desc.u, desc.v)
            maxpsco = s
        if edges[v, u] != 0:
            scores[v, u] = np.nan
        else:
            scores[v, u] = s

    # Copy the colormap so set_bad does not alter the global 'jet' instance.
    cmap = copy.copy(pylab.cm.jet)
    cmap.set_bad('black', 1.)
    masked_array = np.ma.masked_where(np.isnan(scores), scores)

    pylab.imshow(masked_array, cmap=cmap)
    pylab.colorbar()
    pylab.figure()
    pylab.imshow(ima)
    if maxp is not None:
        print("%s %s WTRGAR" % (maxp[0], maxp[1]))
        # pylab.hold no longer exists in recent matplotlib (overlay is the default there).
        if hasattr(pylab, 'hold'):
            pylab.hold(True)
        pylab.plot(maxp[0], maxp[1], 'y*', markersize=15)
    pylab.show()
00102 
00103 # TODO delete?
def visualize_descs(ima, descs):
    """Show *ima* with a red star at every descriptor's (u, v) position."""
    points = [(d.u, d.v) for d in descs]
    pylab.imshow(ima)
    pylab.hold(1)
    for x, y in points:
        pylab.plot(x, y, 'r*')
    pylab.show()
00112     #raw_input('sss')
00113 
00114 
00115 # TODO delete?
def visualize_box_in_image(desc, nspa, pixw):
    """Draw a descriptor's support box and its spatial grid on the current axes.

    The box extends *pixw* pixels on each side of (desc.u, desc.v) and is
    drawn in red.  nspa - 1 green lines in each direction split it into an
    nspa x nspa grid, and a red star marks the centre.
    """
    cx, cy = desc.u, desc.v
    left, right = cx - pixw, cx + pixw
    top, bottom = cy - pixw, cy + pixw
    # Each column of these 2x4 arrays is one edge of the box:
    # top, left, bottom, right.
    xs = np.array([[left, left, left, right],
                   [right, left, right, right]])
    ys = np.array([[top, top, bottom, top],
                   [top, bottom, bottom, bottom]])
    pylab.plot(xs, ys, 'r')

    span = 2 * pixw + 1
    for k in range(1, nspa):
        y_line = top + k * span / float(nspa)
        x_line = left + k * span / float(nspa)
        pylab.plot(np.array([left, right]), np.array([y_line, y_line]), 'g')
        pylab.plot(np.array([x_line, x_line]), np.array([top, bottom]), 'g')
    pylab.plot(cx, cy, '*r')
00124 
00125 
00126 # TODO delete?
def visualize_desc(descs, npdescs, ima, nori, nspa, pixw=15):
    """Interactively browse descriptors in random order.

    For each descriptor, figure 1 shows *ima* with the descriptor's support
    box and grid (visualize_box_in_image), and figure 2 shows the descriptor
    vector via visualize_one_desc.  After each descriptor the user is
    prompted; entering 'q' stops, anything else continues.

    descs   -- descriptor objects exposing .u / .v pixel coordinates.
    npdescs -- descriptor vectors, indexed in parallel with descs.
    ima     -- image the descriptors belong to.
    nori    -- forwarded to visualize_one_desc (presumably the number of
               orientation bins - confirm).
    nspa    -- grid cells per side, forwarded to both plotting helpers.
    pixw    -- half-width in pixels of the drawn support box.
    """
    pylab.ion()
    # list(...) so shuffle also works where range() is lazy (Python 3).
    order = list(range(len(descs)))
    random.shuffle(order)
    for i in order:
        print("%s %s %s" % (descs[i].u, descs[i].v, i))
        pylab.figure(1)
        pylab.cla()
        pylab.imshow(ima)
        # pylab.hold no longer exists in recent matplotlib (overlay is the default there).
        if hasattr(pylab, 'hold'):
            pylab.hold('on')
        visualize_box_in_image(descs[i], nspa, pixw)

        pylab.figure(2)
        # NOTE(review): visualize_one_desc is neither defined nor imported in
        # this file, so this call raises NameError unless it is provided
        # elsewhere - confirm.
        visualize_one_desc(npdescs[i], nori, nspa)
        pylab.show()

        if raw_input(">") == 'q':
            break
00165 
00166 #def visualize_desc2(desc, nori, nspa):
00167 
def run_on_images(config, centroids, sel_i, imnos, select_mode=None, nspa=4, nori=4, pixw=15):
    """Show, image by image, the descriptors quantised to codeword *sel_i*.

    Each image's descriptors (load_descs with old_desc=True) are assigned to
    their nearest codeword with scipy.cluster.vq.vq against centroids[0].

    select_mode:
      None      -- plot every descriptor assigned to sel_i as a red star;
      'closest' -- draw the support box (visualize_box_in_image) of the
                   assigned descriptor with the smallest distance to the
                   codeword, or print "NONE FOUND" if none is assigned;
      'highest' -- not implemented: only the image is shown.

    nspa, pixw -- forwarded to visualize_box_in_image.
    nori       -- accepted for interface compatibility; not used.
    After each image the user is prompted; entering 'q' stops.
    """
    marker = '*r'
    for imno in imnos:
        pylab.figure(1)
        ima = pylab.imread(config.img_filename(imno))

        descs = load_descs(config, imno, old_desc=True)
        descs_feats = np.array([x.desc for x in descs])
        # NOTE(review): robust_whiten is neither defined nor imported in this
        # file, so this raises NameError unless it is provided elsewhere.
        codes, dists = cluster.vq.vq(robust_whiten(descs_feats), centroids[0])

        pylab.cla()
        pylab.imshow(ima)
        # pylab.hold no longer exists in recent matplotlib (overlay is the default there).
        if hasattr(pylab, 'hold'):
            pylab.hold('on')
        if select_mode is None:
            for d, code in zip(descs, codes):
                if code == sel_i:
                    pylab.plot(d.u, d.v, marker)
        elif select_mode == 'closest':
            sel = (codes == sel_i).nonzero()
            if len(sel[0]) != 0:
                print("se %s %s" % (sel, dists[sel[0]]))
                best = np.argmin(dists[sel[0]])
                print(best)
                visualize_box_in_image(descs[sel[0][best]], nspa, pixw)
            else:
                print("NONE FOUND")
        elif select_mode == 'highest':
            pass
        pylab.show()
        if raw_input('t') == 'q':
            break
00205 
00206 
if __name__ == "__main__":
    args = sys.argv[1:]

    # Defaults.  Several options (-v, -nt, -max_desc, -k, -outfile, -infile,
    # -hand_pick, -save) are parsed but not used by the code below.
    todo = []
    dbname = "none"
    im_begin = 0
    im_end = -1
    n_thread = 1
    imnos = None
    verb = False
    max_desc = 200 * 1000
    k = 25
    ratio_keep = 0.2
    outfile = "tmp.txt"
    infile = None
    hand_pick = False
    save_center = False
    # Positive training descriptors are those inside mask2 (load_descs mode 1).
    use_masks2 = 1
    while args:
        a = args.pop(0)
        if a in ['-h', '--help']:
            # NOTE(review): usage() is neither defined nor imported in this
            # file, so this raises NameError unless it is provided elsewhere.
            usage()
            sys.exit(0)
        elif a == '-v':                 verb = True
        elif a == '-db':                dbname = args.pop(0)
        elif a == '-begin':             im_begin = int(args.pop(0))
        elif a == '-end':               im_end = int(args.pop(0))
        elif a == '-imnos':             imnos = [int(x) for x in args.pop(0).split(',')]
        # NOTE(review): parse_nt is neither defined nor imported in this file.
        elif a == '-nt':                n_thread = parse_nt(args.pop(0))
        elif a == '-max_desc':          max_desc = int(args.pop(0))
        elif a == '-k':                 k = int(args.pop(0))
        elif a == '-keep':              ratio_keep = float(args.pop(0))
        elif a == '-outfile':           outfile = args.pop(0)
        elif a == '-infile':            infile = args.pop(0)
        elif a == '-hand_pick':         hand_pick = True
        elif a == '-save':              save_center = True
        else:
            sys.stderr.write("unknown arg %s\n" % a)
            usage()
            sys.exit(-1)

    config = select_config(dbname)

    # Select the training images.
    if im_end == -1:
        im_end = config.nimg
    if im_end > config.nimg:
        sys.stderr.write("warn: forcing -end to %i\n" % config.nimg)
        im_end = config.nimg
    if imnos is None:
        imnos = list(range(im_begin, im_end))

    # Positives: descriptors inside mask2.  Negatives: descriptors inside the
    # primary mask but outside mask2 (load_descs mode 2).
    descs, origdescs = sel_descs(config, imnos, ratio_keep, use_masks2=use_masks2)
    descs_neg, origdescs_neg = sel_descs(config, imnos, ratio_keep, use_masks2=2)

    # Balance the classes: keep a random subset of at most npos negatives.
    npos = descs.shape[0]
    sel = list(range(descs_neg.shape[0]))  # list(...) so shuffle works on Python 3
    random.shuffle(sel)
    sel = sel[:npos]

    feats = np.vstack((descs.astype('float64'), descs_neg[sel].astype('float64')))
    feats = norm_descs(feats)
    # shogun expects one feature vector per column, hence the transpose.
    feats = RealFeatures(feats.T)
    labels = Labels(np.concatenate((np.ones(npos), -np.ones(len(sel)))))

    svm = LibLinear(1, feats, labels)
    svm.train()
    # NOTE(review): max(imnos) is itself a training image, so the visualised
    # range starts with one image seen during training - confirm whether
    # max(imnos) + 1 was intended.
    visualize_classes(svm, config, range(max(imnos), config.nimg))

    print("Writting SVM")
    with open('last_classifier.pkl', 'wb') as pf:
        cPickle.dump(svm, pf)
00305 


iri_bow_object_detector
Author(s): dmartinez
autogenerated on Fri Dec 6 2013 22:45:46