#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

"""
This program reads a "Labeled FASTA" file and writes one GMTK "observation"
file for each sequence it contains.  The optional 2nd command-line parameter
sets the maximum number of labels on either side of each segment boundary
that are removed (reset to the unknown label '.'); it defaults to 5, and 0
keeps every label.
Copyright 2008 University of Washington
"""

__author__ = "Sheila M. Reynolds (sheila@ee.washington.edu)"

#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

def isBlank ( aString ):
    """Return 1 if aString contains only spaces, tabs and newlines, else 0.

    The empty string counts as blank.  No other character -- not even other
    whitespace such as a carriage return -- is treated as blank.
    """
    # trimming the three blank characters from both ends consumes the whole
    # string exactly when it contains nothing else
    if ( aString.strip ( ' \t\n' ) ):
        return ( 0 )
    return ( 1 )

#------------------------------------------------------------------------------

def stripBlanks ( inString ):
    """Return a copy of inString with every plain space (' ') removed.

    Only the space character itself is dropped; tabs, newlines and all other
    characters are kept in their original order.
    """
    return ( inString.replace ( ' ', '' ) )

#------------------------------------------------------------------------------

def addPrimaryLabelString ( labelString, allLabels ):
    """Append one line of primary labels to allLabels['#'].

    Spaces are removed from labelString (via stripBlanks) before it is
    appended; the '#' entry starts out as an empty string if it is not yet
    present.  allLabels is modified in place and nothing is returned.
    """
    existing = allLabels.get ( '#', '' )
    allLabels['#'] = existing + stripBlanks ( labelString )

#------------------------------------------------------------------------------

def addOtherLabelString ( labelString, allLabels ):
    """Append one line of secondary labels to the matching allLabels entry.

    labelString is expected to look like "<labelID> <labels>": everything
    before the first space is the label id, and everything after it (with
    spaces removed by stripBlanks) is appended to allLabels[labelID], which
    starts out as an empty string if it is not yet present.  A string with
    no space at all is treated as a bare label id carrying no labels.
    allLabels is modified in place and nothing is returned.
    """
    ii = labelString.find ( ' ' )
    if ( ii < 0 ):
        # no separator: the whole string is the id.  find() returned -1,
        # which must not be used as a slice bound (it would chop the last
        # character off the id and append it as a label).
        labelID = labelString
        labelChars = ''
    else:
        labelID = labelString[:ii]
        labelChars = stripBlanks ( labelString[ii:] )

    if ( labelID not in allLabels ):
        allLabels[labelID] = ''

    allLabels[labelID] += labelChars

    return

#------------------------------------------------------------------------------

def checkSequence ( seqString, allLabels ):
    """Check that every label string is exactly as long as seqString.

    Returns 1 when every value in allLabels has len(seqString) characters
    (trivially so when allLabels is empty).  At the first mismatch it prints
    a diagnostic -- both lengths, the residue string, the offending label
    string and the label ids -- and returns 0.
    """
    expected = len ( seqString )

    for labelId in allLabels.keys():
        labelStr = allLabels[labelId]
        if ( len ( labelStr ) == expected ):
            continue

        # message, then the two lengths, each separated by a single space
        print ( ' '.join ( [ ' ERROR in checkSequence ',
                             str ( expected ), str ( len ( labelStr ) ) ] ) )
        print ( seqString )
        print ( labelStr )
        print ( allLabels.keys() )
        return ( 0 )

    return ( 1 )

#------------------------------------------------------------------------------

def _appendCheckedSequence ( allSeqs, tokenList, seqString, labels ):
    # append one parsed sequence to allSeqs as a (tokenList, seqString,
    # labels) tuple, but only if checkSequence accepts its label lengths
    if ( checkSequence ( seqString, labels ) ):
        allSeqs.append ( ( tokenList, seqString, labels ) )


def readLabeledFASTAfile ( fastaFilename ):
    """Read a "Labeled FASTA" file and return the sequences it contains.

    Blank lines and lines starting with '%' (comments) are skipped.  Each
    remaining line is interpreted by its first character:
        '>'  header: starts a new sequence; the rest of the line is split on
             whitespace into the sequence's name tokens
        ' '  residues: appended, with spaces removed, to the current sequence
        '#'  primary labels: passed to addPrimaryLabelString
        '?'  other labels ("<labelID> <labels>"): passed to
             addOtherLabelString
    Lines before the first header, and lines starting with any other
    character, are ignored.  Trailing newline / carriage-return characters
    are removed from each line first, so the last line of the file does not
    need to end in a newline.

    Returns a list of (tokenList, seqString, labels) tuples, where labels is
    a dict mapping each label id ('#' for the primary labels) to its label
    string.  Every sequence -- including the last one in the file -- is
    validated with checkSequence and dropped if its label strings do not
    all match the residue string's length.  Returns [] if the file cannot be
    opened or contains no sequences.
    """
    print ( ' ' )

    # open the file ...
    try:
        fh = open ( fastaFilename )
    except IOError:
        print ( '     in readLabeledFASTAfile : failed to open <%s> ' %
                fastaFilename )
        return ( [] )

    allSeqs = []

    # state of the sequence currently being read; tokenList stays None until
    # the first '>' header has been seen
    tokenList = None
    seqString = ''
    labels = {}

    try:
        for aLine in fh:

            # drop the line terminator (if there is one) ...
            aLine = aLine.rstrip ( '\r\n' )

            # if this line is either blank or a comment, skip it ...
            if ( isBlank ( aLine ) ): continue
            if ( aLine[0] == '%' ): continue

            if ( aLine[0] == '>' ):
                # a new sequence starts here, so finish off the one we have
                # been reading (if any) before starting the new one ...
                if ( tokenList is not None ):
                    _appendCheckedSequence ( allSeqs, tokenList, seqString, labels )

                tokenList = aLine[1:].split()
                seqString = ''
                labels = {}

            elif ( tokenList is not None ):

                if ( aLine[0] == ' ' ):
                    seqString += stripBlanks ( aLine )

                elif ( aLine[0] == '#' ):
                    addPrimaryLabelString ( aLine[1:], labels )

                elif ( aLine[0] == '?' ):
                    addOtherLabelString ( aLine[1:], labels )
    finally:
        # close the file even if parsing raises ...
        fh.close()

    # and then we need to add the last sequence (if there was one) ...
    if ( tokenList is not None ):
        _appendCheckedSequence ( allSeqs, tokenList, seqString, labels )

    return ( allSeqs )

#------------------------------------------------------------------------------

# The 20 canonical amino acids, as one-letter codes in alphabetical order.
CanonicalAminos = list ( 'ACDEFGHIKLMNPQRSTVWY' )

# The residue alphabet actually used has 24 entries; a residue's observation
# value is its index in this list:
#      0      'X'  unknown residue ("special" position)
#      1..20       the canonical amino acids above, e.g. 'D' = 3, 'E' = 4,
#                  'N' = 12, 'Q' = 14
#     21      'U'  selenocysteine, the 21st amino acid
#     22      'B'  ambiguous: either 'D'(3) or 'N'(12)
#     23      'Z'  ambiguous: either 'E'(4) or 'Q'(14)
aaList = [ 'X' ] + CanonicalAminos + list ( 'UBZ' )

# Transmembrane-protein labels; as with aaList, index 0 means 'unknown'.
# IMPORTANT : these indices must be consistent with the definitions used in
# defining the GMTK structures / parameters / etc !!!
#      0  '.'  unknown state/label
#      1  'i'  cytoplasmic loop ('inside')
#      2  'M'  membrane helix
#      3  'o'  short non-cytoplasmic loop ('outside')
#      4  'O'  long non-cytoplasmic loop ('OUTSIDE')
#      5  'n'  N-terminus of signal peptide
#      6  'h'  hydrophobic region of signal peptide
#      7  'c'  signal peptide : before cleavage point (multiple residues)
#      8  'C'  signal peptide : cleavage point (single residue)
#      9  's'  generic label for signal peptide
tmList = list ( '.iMoOnhcCs' )

#------------------------------------------------------------------------------

def aaId(aaInitial):
    """Return the index of the one-letter residue code aaInitial in aaList.

    Any code that is not in aaList maps to 0, the 'X' (unknown) position.
    """
    try:
        return ( aaList.index(aaInitial) )
    except ValueError:
        # list.index signals a missing item with ValueError; only that case
        # means "unknown residue" -- any other error should propagate
        return ( 0 )

#------------------------------------------------------------------------------

def tmId(tmInitial):
    """Return the index of the label character tmInitial in tmList.

    Any character that is not in tmList maps to 0, the '.' (unknown) label.
    """
    try:
        return ( tmList.index(tmInitial) )
    except ValueError:
        # list.index signals a missing item with ValueError; only that case
        # means "unknown label" -- any other error should propagate
        return ( 0 )

#------------------------------------------------------------------------------

def makeGMTKobsFilename ( nameTokens, faaRootname ):
    """Build the GMTK observation filename for one sequence.

    nameTokens are the whitespace-separated tokens of the sequence's '>'
    header line (at least one is required).  The filename starts with the
    first token.  If a second token exists, starts with 'P' or 'Q' and has a
    digit as its second character (an accession number such as "P09348;"),
    then '_', that letter and all of the token's digits are appended.
    Finally '.<faaRootname>.obs' is added, e.g.
        ['MOTA_ECOLI', 'P09348;', '295', 'AA.'], 't0'
            ->  'MOTA_ECOLI_P09348.t0.obs'
    """
    filename = nameTokens[0]

    # if there appears to be some sort of id number in the next token,
    # we'll use that for the file name too ...
    if ( len(nameTokens) > 1 ):
        idToken = nameTokens[1]
        # the length check keeps a one-character token (e.g. "P") from
        # raising IndexError on idToken[1]
        if ( len(idToken) > 1  and  idToken[0] in 'PQ'
             and  '0' <= idToken[1] <= '9' ):
            digits = ''.join ( [ c for c in idToken[1:] if '0' <= c <= '9' ] )
            filename += '_' + idToken[0] + digits

    filename += '.' + faaRootname + '.obs'

    return ( filename )

#------------------------------------------------------------------------------

def getRootName ( inFile ):
    """Return the file name of inFile without its directory or last extension.

    Examples:  '/data/sets/t0.faa' -> 't0',  'a.b.faa' -> 'a.b',
    'data/t0' -> 't0' (a name with no extension is kept whole, even when a
    directory name contains a '.'), and '' -> ''.
    """
    import os.path
    return ( os.path.splitext ( os.path.basename ( inFile ) )[0] )

#------------------------------------------------------------------------------

def _normalizeTOPDBlabels ( labString ):
    """Translate TOPDB-style primary labels into the characters of tmList.

    TOPDB files use 'I' rather than 'i', and they sometimes put one or more
    'L's within a stretch of O's or I's; each 'L' is replaced by the most
    recent non-'L' label (a leading run of 'L's becomes 'X', which is not a
    tmList label).  Every other character is copied unchanged.
    """
    outChars = []
    lastChar = 'X'
    for ch in labString:
        if ( ch == 'I' ):
            ch = 'i'
        elif ( ch == 'L' ):
            outChars.append ( lastChar )
            continue
        outChars.append ( ch )
        lastChar = ch
    return ( ''.join ( outChars ) )


def writeGMTKobsFiles ( allSeqs, faaFilename ):
    """Write one GMTK observation file per sequence in allSeqs.

    allSeqs is a list of (nameTokens, seqString, labels) tuples as returned
    by readLabeledFASTAfile.  Each file is named by makeGMTKobsFilename from
    the sequence's name tokens and the root of faaFilename, so a sequence
    with the header
        >MOTA_ECOLI P09348; 295 AA.
    read from "t0.faa" is written to "MOTA_ECOLI_P09348.t0.obs".

    Each residue becomes one line holding two integers, its aaId and the
    tmId of its (TOPDB-normalized) primary '#' label, and a blank line ends
    the file.  Calls sys.exit(-1) if a sequence has no primary label string,
    if that label string is shorter than the sequence, or if an output file
    cannot be opened.
    """
    # we want the 'root' of the faaFilename to use as part of the GMTK
    # observation filename ...
    faaRootname = getRootName ( faaFilename )

    for nameTokens, seqString, allLabels in allSeqs:

        gmtkFilename = makeGMTKobsFilename ( nameTokens, faaRootname )

        # sanity checks come before the output file is created, so that a
        # bad sequence does not leave an empty .obs file behind ...
        if ( '#' not in allLabels ):
            print ( ' ERROR in writeGMTKobsFiles ??? no primary label string ??? ' )
            print ( seqString )
            print ( allLabels.keys() )
            sys.exit(-1)

        seqLength = len(seqString)
        newLabString = _normalizeTOPDBlabels ( allLabels['#'] )
        if ( len(newLabString) < seqLength ):
            # otherwise the write loop would fail part-way with an IndexError
            print ( ' ERROR in writeGMTKobsFiles : label string shorter than sequence ' )
            print ( seqString )
            print ( newLabString )
            sys.exit(-1)

        try:
            fh = open ( gmtkFilename, 'w' )
        except IOError:
            print ( ' '.join ( [ ' ERROR in writeGMTKobsFiles : failed to open output file ',
                                 gmtkFilename ] ) )
            sys.exit(-1)

        print ( ' ' )
        print ( ' '.join ( [ ' writing sequence to output file ', gmtkFilename ] ) )

        # one " <aaId>  <tmId> " line per residue, then a terminating blank
        # line; zip stops at the end of seqString
        outLines = [ ' %3d  %3d \n' % ( aaId(aa), tmId(lab) )
                     for aa, lab in zip ( seqString, newLabString ) ]
        outLines.append ( '\n' )

        try:
            fh.write ( ''.join ( outLines ) )
        finally:
            fh.close()
	    

#------------------------------------------------------------------------------

def findBoundaryLocations ( labelString ):
    """Return the positions at which the labeling changes.

    Position ii is reported when labelString[ii] differs from
    labelString[ii-1] and neither of the two is the unknown label '.'.
    For example "ooooMMMM" gives [4], while "ooo..MMM" gives [] because the
    change happens across unknown labels.
    """
    changes = []
    adjacentPairs = zip ( labelString[:-1], labelString[1:] )
    for ii, (before, after) in enumerate ( adjacentPairs ):
        if ( before == after ):
            continue
        if ( before == '.'  or  after == '.' ):
            continue
        # the pair starting at ii straddles the boundary at ii+1
        changes.append ( ii + 1 )
    return ( changes )

#------------------------------------------------------------------------------

def _unlabelNearBoundaries ( labelString, winSize ):
    """Return labelString with up to winSize labels on each side of every
    boundary (as found by findBoundaryLocations) replaced by '.'."""
    # both ends of the string act as pseudo-boundaries, so that every real
    # boundary has a neighbouring segment on each side
    bounds = [0] + findBoundaryLocations ( labelString ) + [len(labelString)]

    pieces = []
    kk = 0          # first position not yet copied or blanked
    for iB in range ( 1, len(bounds) - 1 ):
        # never blank more than (segmentLength-1)//2 labels of the segment
        # before (aWin) or after (bWin) this boundary, so the windows of two
        # neighbouring boundaries cannot meet
        aLen = bounds[iB] - bounds[iB-1]
        bLen = bounds[iB+1] - bounds[iB]
        aWin = min ( winSize, ( aLen - 1 ) // 2 )
        bWin = min ( winSize, ( bLen - 1 ) // 2 )

        # copy the untouched labels up to the window, then blank the window
        # (the max() calls mirror "advance kk only while it is below ...")
        start = max ( kk, bounds[iB] - aWin )
        stop = max ( start, bounds[iB] + bWin )
        pieces.append ( labelString[kk:start] )
        pieces.append ( '.' * ( stop - start ) )
        kk = stop

    pieces.append ( labelString[kk:] )
    return ( ''.join ( pieces ) )


def removeBoundaryLabels ( allSeqs, winSize ):
    """Replace labels near segment boundaries with the unknown label '.'.

    For every label string of every sequence in allSeqs (a list of
    (nameTokens, seqString, labels) tuples), up to winSize labels on each
    side of every boundary found by findBoundaryLocations are set to '.'.
    The window on either side of a boundary is capped at
    (segmentLength-1)//2, so every segment keeps at least one of its
    original labels.  The label dicts are modified in place, and allSeqs
    itself is returned.
    """
    for seqTuple in allSeqs:
        allLabels = seqTuple[2]
        for aLabelId in list ( allLabels.keys() ):
            allLabels[aLabelId] = _unlabelNearBoundaries ( allLabels[aLabelId], winSize )

    return ( allSeqs )

#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
# main program is here -- expects the name of the .faa file to be read,
# optionally followed by the unlabeled window length (default 5, 0 = none)

import sys

if __name__ == "__main__":

    if ( len(sys.argv) != 2  and  len(sys.argv) != 3 ):
        print ( ' Usage : %s <faa filename> <unlabeled window length (default=5)> ' % sys.argv[0] )
        sys.exit(-1)

    faaFilename = sys.argv[1]

    # optional 2nd argument : the maximum number of labels to "unlabel" on
    # either side of each segment boundary (0 leaves all labels in place)
    winSize = 5
    if ( len(sys.argv) == 3 ):
        try:
            winSize = int ( sys.argv[2] )
        except ValueError:
            print ( ' ERROR in second argument -- should be an integer >= 0 ' )
            sys.exit(-1)

    if ( winSize < 0 ):
        # reject a negative window instead of carrying on with it
        print ( ' ERROR in second argument -- should be an integer >= 0 ' )
        sys.exit(-1)

    allSeqs = readLabeledFASTAfile ( faaFilename )

    numSeq = len(allSeqs)
    print ( '>>> Got %d sequences from %s ' % ( numSeq, faaFilename ) )

    # before we write out these sequences, we may want to "unlabel" the
    # portions near any boundaries ...
    if ( winSize > 0 ):
        removeBoundaryLabels ( allSeqs, winSize )

    writeGMTKobsFiles ( allSeqs, faaFilename )

#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
