
import os,sys
import MLab,Numeric
import math

import misc

# Project directory layout: a root 'sppi' directory under $HOME with
# 'data' and 'results' subdirectories beneath it.
homeDir = os.path.join(os.environ['HOME'], 'sppi')
dataDir, resultsDir = (os.path.join(homeDir, sub) for sub in ('data', 'results'))


def mcc(E, Equery = None, E2 = None, **args) :
    """compute the mutual clustering coefficient of a set of query edges
    relative to a given network

    For each query pair (p1, p2), with neighbour sets N1 = set(E2[p1]) and
    N2 = set(E2[p2]), three overlap scores are computed:
      jaccard   : |N1 & N2| / |N1 | N2|
      meet/min  : |N1 & N2| / min(|N1|, |N2|)
      geometric : |N1 & N2| / sqrt(|N1| * |N2|)
    A pair scores 0 on all three when either protein is absent from E2,
    when p1 == p2, or when the two neighbourhoods share no protein.

    E - interaction network (only used to default Equery / build E2)
    Equery - query edges [optional]; defaults to E.  Each element must
             unpack into exactly two protein identifiers.
    E2 - converted graph (via convertGraph) [optional]; E2[p] is iterated
         to obtain p's neighbours
    **args - accepted for interface compatibility; unused

    Returns (jac, mm, geom).  When Equery is a list these are lists aligned
    with Equery; otherwise they are dicts keyed by query pair.
    """

    if Equery is None :
        Equery = E
    if E2 is None :
        # imported lazily so callers that pass E2 need not load this module
        from sppi.src import analyzeInteractions
        E2 = analyzeInteractions.convertGraph(E)

    # counts query PAIRS with at least one protein missing from E2
    numNo = 0
    jac = {}; mm = {}; geom = {}
    for pair in Equery :
        p1, p2 = pair
        if p1 not in E2 or p2 not in E2 :
            numNo += 1
            jac[pair] = mm[pair] = geom[pair] = 0
            continue
        n1 = set(E2[p1])
        n2 = set(E2[p2])
        shared = n1 & n2
        if p1 == p2 or not shared :
            # self-pairs and non-overlapping neighbourhoods score zero
            jac[pair] = mm[pair] = geom[pair] = 0
            continue
        intersect = float(len(shared))
        # Jaccard index: overlap relative to the UNION of the neighbourhoods
        jac[pair] = intersect / len(n1 | n2)
        mm[pair] = intersect / min(len(n1), len(n2))
        geom[pair] = intersect / math.sqrt(len(n1) * len(n2))

    print('number of proteins with no interaction: %s' % numNo)
    if isinstance(Equery, list) :
        jac = [jac[pair] for pair in Equery]
        mm = [mm[pair] for pair in Equery]
        geom = [geom[pair] for pair in Equery]

    return jac,mm,geom



def mccKernel(fileName, E, E2 = None, orfs = None) :
    """write a geometric mutual-clustering kernel matrix to a text file

    fileName - output path.  The file is tab-delimited: the header row is
               fileName followed by the ORF names; each following row is an
               ORF name followed by its kernel values against every ORF.
    E - interaction network; converted via analyzeInteractions.convertGraph
        when E2 is not given
    E2 - converted graph [optional]; E2[orf] is iterated to obtain neighbours
    orfs - ORF names indexing the matrix [optional]; defaults to the keys of
           orfHandlers.YeastOrfs().orfs.  Always written in sorted order.

    Kernel value for (orf1, orf2):
      1 on the diagonal;
      0 if either ORF is absent from E2;
      |N1 & N2| / sqrt(|N1| * |N2|) otherwise (0.0 if either has no neighbours).
    """

    delim = '\t'

    if E2 is None :
        from sppi.src import analyzeInteractions
        E2 = analyzeInteractions.convertGraph(E)

    if orfs is None :
        from sppi.src import orfHandlers
        orfs = orfHandlers.YeastOrfs().orfs.keys()
    orfs = sorted(orfs)

    # neighbour sets built once per ORF instead of once per matrix cell
    neighbours = dict((orf, set(E2[orf])) for orf in orfs if orf in E2)

    with open(fileName, 'w') as fileHandle :
        fileHandle.write(fileName + delim + delim.join(orfs) + '\n')
        for orf1 in orfs :
            print(orf1)   # progress indicator
            tokens = []
            for orf2 in orfs :
                if orf1 == orf2 :
                    kvalue = 1
                elif orf1 not in neighbours or orf2 not in neighbours :
                    kvalue = 0
                else :
                    n1 = neighbours[orf1]
                    n2 = neighbours[orf2]
                    intersect = float(len(n1 & n2))
                    denom = math.sqrt(len(n1) * len(n2))
                    # guard ORFs present in E2 but with no neighbours
                    kvalue = intersect / denom if denom else 0.0
                tokens.append(str(kvalue))

            fileHandle.write(orf1 + delim + delim.join(tokens) + '\n')


def addMCC(testingData, trainingData, **args) :
    """Add mutual-clustering-coefficient features to testingData in place.

    The reference network is built from the patterns of trainingData whose
    label Y equals 1; every pattern ID is split on '_' into a protein tuple.
    Each pattern of testingData is split the same way into a query pair, and
    the three scores from mcc() are added as the features 'mcc_jaccard_t',
    'mcc_meetMin_t' and 'mcc_geom_t'.

    If trainingData has a 'pydatas' attribute, both datasets are replaced by
    their pydatas[1] entry first.  Nothing happens if testingData already has
    'mcc_jaccard_t'; the older 'mcc_jacard'/'mcc_meetMin'/'mcc_geom' features
    are removed before the new ones are added.  **args is unused.
    """

    if hasattr(trainingData, 'pydatas') :
        # NOTE(review): both datasets are unwrapped based on trainingData only
        trainingData, testingData = trainingData.pydatas[1], testingData.pydatas[1]

    if 'mcc_jaccard_t' in testingData.featureID :
        return
    if 'mcc_jacard' in testingData.featureID :
        testingData.eliminateFeatures(['mcc_jacard', 'mcc_meetMin', 'mcc_geom'])

    trainLabels = trainingData.labels
    positivePairs = [tuple(trainLabels.patternID[idx].split('_'))
                     for idx in range(len(trainingData))
                     if trainLabels.Y[idx] == 1]
    network = misc.list2dict(positivePairs)

    testLabels = testingData.labels
    queryPairs = [tuple(testLabels.patternID[idx].split('_'))
                  for idx in range(len(testingData))]

    scores = mcc(network, queryPairs)
    featureNames = ('mcc_jaccard_t', 'mcc_meetMin_t', 'mcc_geom_t')
    for featureName, values in zip(featureNames, scores) :
        testingData.addFeature(featureName, values)

def addDataTrain(data, **args) :
    """Prepare a training dataset in place.

    When a 'bdata' keyword is supplied, interactionDataset.addReliability is
    applied first, with all keyword arguments forwarded to it.  MCC features
    are then added via addMCC, using the dataset as both the reference
    network and the query set.
    """
    from sppi.src import interactionDataset as datasetModule

    wantReliability = 'bdata' in args
    if wantReliability :
        datasetModule.addReliability(data, **args)

    # the training set is scored against its own positive pairs
    addMCC(trainingData = data, testingData = data, **args)

def addDataTest(testingData, trainingData, **args) :
    """Add MCC features to a test set in place, using the positive pairs of
    trainingData as the reference network (delegates to addMCC)."""
    return addMCC(testingData, trainingData, **args)
