#!/usr/bin/python
#
# sabertooth.py
# punting from svm to nn and back again using thresholds
#
# NOTE(review): Python 2 code (print statements, dict.has_key). This copy of
# the file had been flattened onto a few physical lines; the indentation/
# nesting below is reconstructed from token order and from the sibling
# functions -- confirm against the original before relying on it.
#
# Two predictors score the same examples: "a" is the SVM score matrix
# (svm_matrix_fname, loaded into mcodes) and "b" is the nn score matrix
# (nn_matrix_fname, loaded into mnn). Each errate_method_* rule decides, per
# example, whether to predict with a, with b, or to punt (make no
# prediction) by comparing a predictor's top score with a per-column
# threshold. Thresholds are quantiles of training-set scores (get_thresh).
#
# outputs, for each rule name in mlist (ap, bp, apb_bp, bpa_ap):
# <name>.errcov          one tab-separated row per threshold pair swept
# <name>.minerr_maxcov   results on the A, B and A+B test sets for the pair
#                        selected as "min err at maximum coverage", followed
#                        by one per-example result per line
#                        (1 = error, 0 = correct, -1 = punted)

# Threshold strategy used by find_thresholds(): "multi" = one threshold per
# predictor column, "single" = one threshold pooled over all columns.
thresh_method="multi" # or "single"
# Example-index files for the A (covered) and B (uncovered) test/train
# splits, and the file holding the true label of every example.
covtstidx_fname ="a_tst.indices.txt"
uncovtstidx_fname ="b_tst.indices.txt"
covtrnidx_fname = "a_trn.indices.txt"
uncovtrnidx_fname = "b_trn.indices.txt"
targetlabels_fname = "scop.1.69.labels.txt"
#for profiles kernel and PSI-BLAST experiment..
# (for each predictor: a score matrix and the label of each of its columns)
svm_matrix_fname = "svm-profile.predict.mtx"
svm_matrix_fname_cols = "svm-profile.labels.txt"
nn_matrix_fname = "psiblast.predict.mtx"
nn_matrix_fname_cols = "psiblast.labels.txt"
#for SVM-MAMMOTH and nn-MAMMOTH experiment..
#svm_matrix_fname = "svm-mammoth.mtx"
#nn_matrix_fname = "mammoth.predict.mtx"
#svm_matrix_fname_cols = "svm-mammoth.labels.txt"
#nn_matrix_fname_cols = "mammoth.labels.txt"
import sys
from numpy import *
from load import *
# When True, find_thresholds() uses the *_xv variants of the SVM threshold
# functions.
xvsvm = False

# NOTE(review): SOURCE is damaged here. Everything between "while f" and
# ">= ap[a_argmax[i]]" below was lost when the file was flattened (the lost
# span looks like a "<...>" run that was stripped as markup). It presumably
# held the rest of frange (the script calls frange(0,10,0.5) as a float
# range generator), possibly error_rate (called by the script but not
# defined anywhere in the visible text -- it may instead come from "load"),
# and the header plus flag initialisation of errate_method_ap (referenced in
# mlist). The tokens are kept exactly as found; restore this part from
# version control before running.
def frange(f,t,s):
    while f= ap[a_argmax[i]]:
        # (tail of errate_method_ap) a's top score reached the threshold of
        # its column, so a predicts; wrong if its class is not the label.
        if a_cols[a_argmax[i]] != labels[i]:
            ierror=True
    else:
        # presumably: below the threshold -> punt completely
        ipunt=True
        ipunta=True
        ipuntf=True
    # NOTE(review): the second value should presumably be ipunta, as in the
    # sibling rules; ipunt and ipunta are always set together here, so the
    # returned tuple is the same either way.
    return ierror,ipunt,ipunt,ipuntf

# Decision rule "bp": use predictor b only.
# Parameters shared by all errate_method_* rules (errate_method passes them
# positionally, so the names of the 7th/8th parameters differ per rule):
#   ma, mb              score matrices of a and b, indexed [example][column]
#   a_argmax, b_argmax  per-example index of the top-scoring column
#   a_cols, b_cols      label of each column
#   7th, 8th            per-column threshold lists (errate_method's a2b, b2p)
#   labels              true label of each example
#   i                   index of the example being decided
# Returns (ierror, ipunt, ipunta, ipuntf):
#   ierror  a prediction was made and it was wrong
#   ipunt   no prediction was made (errate_method counts the example as punted)
#   ipunta  predictor a punted (unpacked but not used by errate_method)
#   ipuntf  the first predictor in the chain punted (counted in npuntf)
# Here b predicts when its top score is >= bp for that column; otherwise the
# example is punted. anul is ignored.
def errate_method_bp( ma, a_argmax, a_cols, mb, b_argmax, b_cols, anul, bp, labels, i ):
    ierror=False; ipunt=False; ipunta=False; ipuntf=False;
    if mb[i][b_argmax[i]] >= bp[b_argmax[i]]:
        if b_cols[b_argmax[i]] != labels[i]:
            ierror=True
    else:
        ipunt=True
        ipunta=False
        ipuntf=True
    return ierror,ipunt,ipunta,ipuntf

# Decision rule "apb_bp": try a first; if a's top score is below apb for its
# column, punt to b, which predicts if its top score is >= bp and otherwise
# punts completely. Parameters/returns as for errate_method_bp.
def errate_method_apb_bp(ma, a_argmax, a_cols, mb, b_argmax, b_cols, apb, bp, labels, i ):
    ierror=False; ipunt=False; ipunta=False; ipuntf=False;
    if ma[i][a_argmax[i]] >= apb[a_argmax[i]]:
        if a_cols[a_argmax[i]] != labels[i]:
            ierror=True
    else:
        ipunta = True
        ipuntf = True
        if mb[i][b_argmax[i]] >= bp[b_argmax[i]]:
            if b_cols[b_argmax[i]] != labels[i]:
                ierror=True
        else:
            ipunt=True
    return ierror,ipunt,ipunta,ipuntf

# Decision rule "bpa_ap": try b first (threshold bpa); if b punts, fall back
# to a (threshold ap), which either predicts or punts completely.
# Parameters/returns as for errate_method_bp.
def errate_method_bpa_ap(ma, a_argmax, a_cols, mb, b_argmax, b_cols, ap, bpa, labels, i ):
    ierror=False;
    ipunt=False; ipunta=False; ipuntf=False;
    if mb[i][b_argmax[i]] >= bpa[b_argmax[i]]:
        if b_cols[b_argmax[i]] != labels[i]:
            ierror=True
    else:
        ipuntf=True
        if ma[i][a_argmax[i]] >= ap[a_argmax[i]]:
            if a_cols[a_argmax[i]] != labels[i]:
                ierror=True
        else:
            ipunt=True
            ipunta=True
    return ierror,ipunt,ipunta,ipuntf

# Decision rule "bca_bp_ap" (currently commented out of mlist): if b's top
# class is one of a's column labels, a decides (predict or punt against ap);
# otherwise b decides against bp.
# NOTE(review): the nesting of the two "else" branches is reconstructed from
# the flattened source and is the least certain part of this file.
def errate_method_bca_bp_ap(ma, a_argmax, a_cols, mb, b_argmax, b_cols, ap, bp, labels, i ):
    ierror=False; ipunt=False; ipunta=False; ipuntf=False;
    if b_cols[b_argmax[i]] in a_cols: #punt to a ( predicted class is in a )
        if ma[i][a_argmax[i]] >= ap[a_argmax[i]]:
            if a_cols[a_argmax[i]] != labels[i]:
                ierror=True
        else:
            ipunta=True
            ipunt=True
    else:
        ipuntf=True
        #may still want to punt from b
        if mb[i][b_argmax[i]] >= bp[b_argmax[i]]:
            if b_cols[b_argmax[i]] != labels[i]:
                ierror=True
        else:
            ipunt=True
    return ierror,ipunt,ipunta,ipuntf

# Apply a decision rule ("method", one of the errate_method_* functions) to
# every example index in tstidx and summarise the outcome.
# a2b / b2p are passed to the rule as its 7th and 8th arguments;
# tst_class_sizes maps label -> number of test examples with that label.
# Returns, in order:
#   err       error rate over non-punted examples (1.0 if everything punted)
#   cov       fraction of tstidx that was not punted
#   err_bal   sum of 1/class-size over wrong predictions, / number of classes
#   cov_bal   sum of 1/class-size over non-punted examples, / number of classes
#   testerrs  per example, in tstidx order: 1.0 error, 0.0 correct, -1.0 punt
#   nacorrect, nbcorrect, nacorrect_bwrong, nbcorrect_awrong
#             top-1 correctness counts of a and b (independent of thresholds)
#   len(tstidx), npunt, npuntf
# NOTE(review): err_bal/cov_bal are normalised by tst_class_sizes even when
# tstidx is only part of the examples it was built from (doit always passes
# alltst_class_sizes), so subset figures are not per-subset balanced rates --
# confirm this is intended. Also, under Python 2 the multi-argument print
# below prints a tuple.
def errate_method( ma, a_argmax, a_cols, mb, b_argmax, b_cols, a2b, b2p, labels, tst_class_sizes, tstidx, method ):
    err = 0.0
    cov = 0.0
    err_bal = 0.0
    cov_bal = 0.0
    npuntf = 0
    npunt = 0
    nacorrect =0
    nbcorrect =0
    nacorrect_bwrong = 0
    nbcorrect_awrong = 0
    testerrs=[]
    for i in tstidx:
        # threshold-free comparison of the two predictors' top-1 answers
        acorrect=False
        bcorrect=False
        if a_cols[a_argmax[i]] == labels[i]:
            acorrect=True
            nacorrect+=1
        if b_cols[b_argmax[i]] == labels[i]:
            bcorrect=True
            nbcorrect+=1
        if acorrect and not bcorrect:
            nacorrect_bwrong+=1
        if bcorrect and not acorrect:
            nbcorrect_awrong+=1
        ierror,ipunt,ipunta,ipuntf = method( ma, a_argmax, a_cols, mb, b_argmax, b_cols, a2b, b2p, labels, i )
        if ipuntf==True:
            npuntf+=1
        if ipunt==False:
            cov+=1.0
            cov_bal+=1.0/float(tst_class_sizes[labels[i]])
            if ierror==True:
                err+=1.0
                err_bal+=1.0/float(tst_class_sizes[labels[i]])
                testerrs.append(1.0) #error
            else:
                testerrs.append(0.0)
        else:
            npunt+=1
            testerrs.append(-1.0) #punt
    if cov > 0:
        err = err/cov
    else:
        err = 1.0
    cov /= float(len(tstidx))
    err_bal /= float(len(tst_class_sizes))
    cov_bal /= float(len(tst_class_sizes))
    print("tst:",len(tstidx),"punt:",npunt,"npuntf:",npuntf)
    return err,cov,err_bal,cov_bal,testerrs,nacorrect,nbcorrect,nacorrect_bwrong,nbcorrect_awrong,len(tstidx),npunt,npuntf

# Return the value at position target*len(col) of the sorted values of col,
# linearly interpolating between the two neighbouring entries; if that
# position is at or past the last pair, return the largest value.
# Callers pass target = 1 - 3**(-k), i.e. a fraction in [0, 1).
# NOTE: sorts col in place; an empty col raises IndexError.
def get_thresh(col,target):
    col.sort()
    tj = target*float(len(col))
    tlower = int(tj)
    tup = int(tj)+1
    if tlower>=len(col) or tup>=len(col):
        thresh = col[-1]
    else:
        thresh = col[tlower]+ (col[tup]-col[tlower])*(tj-float(tlower))
    return thresh

# One pooled nn (b) threshold, replicated once per nn column. Pools the
# scores of every nn column on every training example (A and B) whose
# label differs from that column's label.
# Reads globals mnn, mnn_cols, alltrnidx, targetlabels.
def thresh_single_nn( thresh ):
    #find single threshold for punting from uncov to cov
    #all_bad = all negative examples of predictor
    all_bad = []
    for i in range(0,len(mnn[0])):
        for j in alltrnidx:
            if targetlabels[j] != mnn_cols[i]:
                all_bad.append(mnn[j][i])
    single_th = get_thresh(all_bad, thresh)
    #vectorize
    out=[]
    for i in range(0,len(mnn[0])):
        out.append( single_th )
    return out

# One SVM (a) threshold per column: the `thresh` quantile of that column's
# scores on the B (uncovered) training examples.
# Reads globals mcodes, uncovtrnidx.
def thresh_multi_svm( thresh ):
    #find thresholds for punting from cov to uncov
    out = []
    for i in range(0,len(mcodes[0])):
        col = [] #all predictions made by a particular detector ON THE B TRAIN
        for j in uncovtrnidx:
            col.append(mcodes[j][i])
        out.append(get_thresh(col, thresh))
    return out

# Cross-validation variant of thresh_multi_svm. The CV branch is commented
# out, so it currently returns the same thresholds as thresh_multi_svm.
# NOTE(review): the assert reads the global mxv, which is not defined in the
# visible source; with xvsvm=True this would raise NameError unless mxv is
# provided elsewhere (e.g. by "load").
def thresh_multi_svm_xv( thresh ):
    assert(len(mcodes_cols)==len(mxv[0]))
    out = []
    col = [] #all predictions made by a particular detector ON THE B TRAIN
    #CV
    for i in range(0,len(mcodes_cols)):
        col = [] #all predictions made by a detector
        #CV
        # for j in range(0,len(mxv)):
        #     if mxv_labels[j] != mcodes_cols[i]:
        #         col.append(mxv[j][i])
        #Btrn
        for j in uncovtrnidx:
            col.append(mcodes[j][i])
        out.append(get_thresh(col, thresh))
    return out

# One pooled SVM (a) threshold, replicated once per SVM column: the `thresh`
# quantile of all SVM scores on the B (uncovered) training examples.
def thresh_single_svm(thresh):
    #find single threshold for punting from cov to uncov
    #all_uncov = all uncovered examples from training set (no chance they are from A)
    all_uncov = []
    for i in range(0,len(mcodes[0])):
        for j in uncovtrnidx:
            all_uncov.append(mcodes[j][i])
    single_th = get_thresh(all_uncov, thresh)
    out=[]
    for i in range(0,len(mcodes[0])):
        out.append( single_th )
    return out

# Cross-validation variant of thresh_single_svm. The CV branch is commented
# out; the only difference is that the output length is len(mcodes_cols)
# rather than len(mcodes[0]).
def thresh_single_svm_xv(thresh):
    all_uncov = []
    #CV
    # for i in range(0,len(mcodes_cols)):
    #     for j in range(0,len(mxv)):
    #         all_uncov.append(mxv[j][i])
    #Btrn
    for i in range(0,len(mcodes[0])):
        for j in uncovtrnidx:
            all_uncov.append(mcodes[j][i])
    single_th = get_thresh(all_uncov, thresh)
    out=[]
    for i in range(0,len(mcodes_cols)):
        out.append( single_th )
    return out

# One nn (b) threshold per column: the `thresh` quantile of that column's
# scores on training examples (A and B) whose label is not the column's.
def thresh_multi_nn( thresh ):
    #find thresholds for punting from uncov to cov
    out = []
    for i in range(0,len(mnn[0])):
        col = []
        for j in alltrnidx:
            if targetlabels[j] != mnn_cols[i]:
                col.append(mnn[j][i])
        out.append(get_thresh(col, thresh))
    return out

# Return (cov2uncov, uncov2cov): per-column threshold lists for the SVM (a)
# and nn (b) predictors, at quantiles thresh_cov and thresh_uncov.
# thresh_method selects "single" or "multi"; anything else exits with status 1.
def find_thresholds( thresh_cov, thresh_uncov, thresh_method ):
    print "finding thresholds..%.4f,%.4f (%s)"%(thresh_cov,thresh_uncov,thresh_method)
    #find threshold by gathering all false results and finding a threshold at which, say 99.6% of those results
    #are below it
    if xvsvm==True:
        print "using svm xv"
    if thresh_method == "single":
        if xvsvm==True:
            cov2uncov = thresh_single_svm_xv( thresh_cov )
        else:
            cov2uncov = thresh_single_svm( thresh_cov )
        uncov2cov = thresh_single_nn ( thresh_uncov )
    elif thresh_method == "multi":
        if xvsvm==True:
            cov2uncov = thresh_multi_svm_xv( thresh_cov )
        else:
            cov2uncov = thresh_multi_svm( thresh_cov )
        uncov2cov = thresh_multi_nn ( thresh_uncov )
    else:
        print("Unknown threshold method"+thresh_method)
        sys.exit(1)
    return cov2uncov, uncov2cov

# Evaluate decision rule `method` on test indices tstidx, with SVM (a)
# thresholds at quantile thresh_cov and nn (b) thresholds at quantile
# thresh_uncov. Returns (er, cov, bal_err, bal_cov, err_a, err_b, nac, nbc,
# nac_bw, nbc_aw, tsterrs, npred1, npred2); see the comments before the return.
# err_a / err_b are -1 when no A / B example received a prediction.
def doit( thresh_cov, thresh_uncov, method, tstidx, thresh_method ):
    cov2uncov, uncov2cov = find_thresholds( thresh_cov, thresh_uncov, thresh_method )
    # average both multi and single methods (performs better, but adds a parameter)
    # c2um, u2cm = find_thresholds( thresh_cov, thresh_uncov, "multi" )
    # c2us, u2cs = find_thresholds( thresh_cov, thresh_uncov, "single" )
    # cov2uncov = []
    # uncov2cov = []
    # for i in range(0,len(c2um)):
    #     cov2uncov.append((c2um[i]+c2us[i])*0.5)
    # for i in range(0,len(u2cm)):
    #     uncov2cov.append((u2cm[i]+u2cs[i])*0.5)
    #
    c_argmax = mcodes_argmax
    n_argmax = mnn_argmax
    #
    #fiddle the bias
    #subtract thresholds from all columns, then find argmax
    # print("don't use this !")
    # tc = numpy.array(cov2uncov)
    # tuc = numpy.array(uncov2cov)
    # csubpunt = mcodes - tc
    # nsubpunt = mnn - tuc
    # c_argmax = csubpunt.argmax(axis=1)
    # n_argmax = nsubpunt.argmax(axis=1)
    er,cov,bal_err,bal_cov,tsterrs,nac,nbc,nac_bw,nbc_aw,nex,nexpunt,nexpuntf=errate_method( mcodes, c_argmax, mcodes_cols, mnn, n_argmax, mnn_cols, cov2uncov, uncov2cov, targetlabels, alltst_class_sizes, tstidx, method )
    #go through tsterrs, determine test error for examples from A and B
    # NOTE(review): "in covtstidx" scans the whole sequence when covtstidx is
    # a list (presumably what load_vector_int returns), i.e. O(len) per example.
    era = 0
    na=0
    erb = 0
    nb=0
    for i in range(0,len(tstidx)):
        if tsterrs[i] != -1:
            if tstidx[i] in covtstidx:
                na += 1
                if tsterrs[i] == 1:
                    era += 1
            else:
                nb += 1
                if tsterrs[i] == 1:
                    erb += 1
    if na != 0:
        err_a = float(era)/float(na)
    else:
        err_a= -1
    if nb != 0:
        err_b = float(erb)/float(nb)
    else:
        err_b = -1
    #find number of predictions made by methods a and b
    # - = prediction * = punt
    #********** 10 nex - number of examples
    #--------** 2 nexpunt - number punted completely
    #-----***** 5 nexpuntf - number punted by 1st method
    #aaaaabbbpp 5,3,2 - predictions/punt
    npred1 = nex - nexpuntf #10-5=5
    npred2 = nexpuntf - nexpunt #5 -2=3
    #na = number from test a with predictions
    #nb = number from test b with predictions
    #npred1 = number of predictions made by 1st method
    #npred2 = number of predictions made by 2nd method
    #testerrs = array of all test examples 1=error 0=correct -1=punted
    #nac = number method a got correct
    #nbc = number method b got correct
    #nac_bw = number method a got correct and b got wrong
    #nbc_aw = number method b got correct and a got wrong
    #err = unbalanced error
    #cov = unbalanced coverage
    #bal_err = balanced error
    #bal_cov = balanced coverage
    return er,cov,bal_err,bal_cov,err_a,err_b,nac,nbc,nac_bw,nbc_aw,tsterrs,npred1,npred2

##########################################################33
## SCRIPT BEGINS
# Test example indices: A (covered), B (uncovered), and A followed by B.
covtstidx =load_vector_int(covtstidx_fname)
uncovtstidx=load_vector_int(uncovtstidx_fname)
alltstidx = []
for i in covtstidx:
    alltstidx.append(i)
for i in uncovtstidx:
    alltstidx.append(i)
print "loading dataset decriptions.."
# --- load training indices, labels and both score matrices -----------------
covtrnidx = load_vector_int(covtrnidx_fname)
uncovtrnidx = load_vector_int(uncovtrnidx_fname)
targetlabels = load_labels(targetlabels_fname)
print("loading matrices..")
mcodes = load_matrix_numpy(svm_matrix_fname)
mcodes_cols = load_labels(svm_matrix_fname_cols)
mnn = load_matrix_numpy(nn_matrix_fname)
mnn_cols = load_labels(nn_matrix_fname_cols)

print("calculating argmax..")
# Index of the top-scoring column for every example (row) of each matrix.
mcodes_argmax = mcodes.argmax(axis=1)
print("codes_argmax"+str(mcodes_argmax))
mnn_argmax = mnn.argmax(axis=1)

allidx = list(range(0, len(mnn)))
# Training indices, A followed by B (read as a global by the nn threshold functions).
alltrnidx = list(covtrnidx) + list(uncovtrnidx)

# Number of A+B test examples per true label; used for balanced error/coverage.
# dict.get keeps this working on both Python 2 and 3 (has_key is Python 2 only).
alltst_class_sizes = {}
for idx in alltstidx:
    label = targetlabels[idx]
    alltst_class_sizes[label] = alltst_class_sizes.get(label, 0) + 1

print("covtstidx ="+str(len(covtstidx)))
print("mcodes_argmax="+str(len(mcodes_argmax)))
print("mcodes_cols ="+str(len(mcodes_cols)))
print("targetlabels ="+str(len(targetlabels)))

# Threshold-free baselines of each predictor on the A, B and A+B test sets
# (error_rate is defined elsewhere; presumably the plain top-1 error rate).
Atst_codes = error_rate(covtstidx, mcodes_argmax, mcodes_cols, targetlabels)
Atst_nn = error_rate(covtstidx, mnn_argmax, mnn_cols, targetlabels)
Btst_codes = error_rate(uncovtstidx, mcodes_argmax, mcodes_cols, targetlabels)
Btst_nn = error_rate(uncovtstidx, mnn_argmax, mnn_cols, targetlabels)
ABtst_codes = error_rate(alltstidx, mcodes_argmax, mcodes_cols, targetlabels)
ABtst_nn = error_rate(alltstidx, mnn_argmax, mnn_cols, targetlabels)
print(" A B A+B ")
print("codes "+str(Atst_codes)+" "+str(Btst_codes)+" "+str(ABtst_codes))
print("nn "+str(Atst_nn)+" "+str(Btst_nn)+" "+str(ABtst_nn))

# Decision rules to evaluate; "name" is the prefix of their output files.
mlist = [
    {"name": "ap", "emethod": errate_method_ap},
    {"name": "bp", "emethod": errate_method_bp},
    {"name": "apb_bp", "emethod": errate_method_apb_bp},
    {"name": "bpa_ap", "emethod": errate_method_bpa_ap},
    # {"name": "bca_bp_ap", "emethod": errate_method_bca_bp_ap},
]

# Row layout of the per-test-set lines in <name>.minerr_maxcov (after the tag).
_MINERR_ROW_FMT = "\t%.4f\t%.4f\t%.4f\t%.4f\t%d\t%d\t%d\t%d\t%d\t%d\n"

for m in mlist:
    print(m)
    fname = m["name"]+".errcov"
    print("writing to:"+fname)
    # Truncate the sweep file; each threshold pair's row is appended (and the
    # file closed again) as soon as it has been computed.
    with open(fname, "w+"):
        pass

    # --- sweep threshold quantiles 1 - 3**(-k) for both predictors ----------
    records = []
    mini = None  # index of the lowest-error record with full coverage
    for yi in frange(0, 10, 0.5):
        fy = 1.0-pow(3, -yi)
        for xi in frange(0, 10, 0.5):
            fx = 1.0-pow(3, -xi)
            (er, cov, bal_er, bal_cov, era, erb, nac, nbc, nac_bw, nbc_aw,
             testerrs, n1, n2) = doit(fx, fy, m["emethod"], alltstidx, thresh_method)
            with open(fname, "a") as fout:
                fout.write("%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%d\t%d\n" %
                           (fx, fy, er, cov, bal_er, bal_cov, era, erb, n1, n2))
            print("error rate:%.4f coverage:%.4f balanced err:%.4f balanced cov:%.4f" % (er, cov, bal_er, bal_cov))
            records.append({"fx": fx, "fy": fy, "er": er, "cov": cov,
                            "bal_er": bal_er, "bal_cov": bal_cov})
            # Only full-coverage records compete (the old code started from
            # record 0 even when it was not at full coverage). Ties go to the
            # later record, as before.
            if cov == 1.0 and (mini is None or er <= records[mini]["er"]):
                print("congrats, a new min er at max coverage")
                mini = len(records)-1

    if mini is None:
        # No pair reached full coverage: take the lowest-error record among
        # those at the highest coverage actually reached (later wins ties).
        best_cov = max(r["cov"] for r in records)
        for idx, r in enumerate(records):
            if r["cov"] == best_cov and (mini is None or r["er"] <= records[mini]["er"]):
                mini = idx

    # --- re-run the selected pair on the A, B and A+B test sets -------------
    fx = records[mini]["fx"]
    fy = records[mini]["fy"]
    with open(m["name"]+".minerr_maxcov", "w+") as fout:
        fout.write("thresholds\t%.4f\t%.4f\n" % (fx, fy))
        for tag, test_idx in (("Atst", covtstidx),
                              ("Btst", uncovtstidx),
                              ("A+B tst", alltstidx)):
            (er, cov, bal_er, bal_cov, era, erb, nac, nbc, nac_bw, nbc_aw,
             testerrs, n1, n2) = doit(fx, fy, m["emethod"], test_idx, thresh_method)
            fout.write(tag + _MINERR_ROW_FMT % (er, cov, bal_er, bal_cov,
                                                nac, nbc, nac_bw, nbc_aw, n1, n2))
        # Per-example results of the last (A+B) run: 1 error, 0 correct, -1 punt.
        for t in testerrs:
            fout.write("%d\n" % (t))