# When True, the script also runs the self-checks at the end of the
# __main__ section.
DEBUG = False

# Backtrack codes stored in the b matrix. The values must stay 0, 1, 2:
# LCS stores the index returned by posmax over [up, left, diagonal]
# candidates directly as the arrow.
UP_ARROW = 0
LEFT_ARROW = 1
DIAG_ARROW = 2

def LCS(v, w, sigma=None, mu=None):
	"""Fill the score and backtrack matrices for aligning v against w.

	Matching characters extend the diagonal score by 1.  On a mismatch
	the best of "up - sigma", "left - sigma" and "diagonal - mu" is
	taken; ties go to the earliest candidate in that order.  With both
	penalties at 0 the final score is the length of a longest common
	subsequence.

	Row 0 and column 0 of both matrices are left at 0.

	Args:
		v: First string (length m - 1); indexes the matrix rows.
		w: Second string (length n - 1); indexes the matrix columns.
		sigma: Indel penalty.  When omitted, the module-level SIGMA is used.
		mu: Mismatch penalty.  When omitted, the module-level MU is used.

	Returns:
		A two-element list [score, b]: score is the bottom-right cell of
		the score matrix and b is the backtrack matrix holding UP_ARROW,
		LEFT_ARROW or DIAG_ARROW codes.
	"""
	# m - Number of rows in matrix
	# n - Number of columns in matrix
	m, n = len(v) + 1, len(w) + 1
	# Every row must be a distinct list so cells can be set independently.
	s = [[0] * n for _ in range(m)]
	b = [[0] * n for _ in range(m)]
	for i in range(1, m):
		for j in range(1, n):
			if v[i-1] == w[j-1]:
				s[i][j] = s[i-1][j-1] + 1
				b[i][j] = DIAG_ARROW
			else:
				# Omitted penalties are looked up in the module globals
				# only once a mismatch is actually seen, so callers that
				# pass them explicitly (or never hit a mismatch) do not
				# depend on SIGMA/MU being defined.
				if sigma is None:
					sigma = SIGMA
				if mu is None:
					mu = MU
				# Candidate order must match UP_ARROW, LEFT_ARROW,
				# DIAG_ARROW because the winning index is stored as the arrow.
				candidates = [
					s[i-1][j] - sigma,
					s[i][j-1] - sigma,
					s[i-1][j-1] - mu,
				]
				b[i][j], s[i][j] = posmax(candidates)
	return [s[m-1][n-1], b]

def PrintLCS(b,v,w):
	"""Trace the backtrack matrix, print the alignment and return the LCS.

	Starting from cell (len(v), len(w)) the arrows in b are followed back
	to (0, 0).  UP_ARROW consumes a character of v against a gap,
	LEFT_ARROW consumes a character of w against a gap, and any other
	value consumes one character of each.  Once a border of the matrix is
	reached the remaining characters of the other string are paired with
	gaps.  Diagonal steps whose characters are equal contribute to the
	returned subsequence.

	Args:
		b: Backtrack matrix as produced by LCS for the same v and w.
		v: First string (rows of b).
		w: Second string (columns of b).

	Returns:
		The common subsequence collected along the traced path.  The
		subsequence and the two gapped alignment rows are also printed.
	"""
	i, j = len(v), len(w)
	# Columns are collected from the end of the alignment to its start and
	# reversed once at the end; this avoids rebuilding the strings on
	# every step.
	row_v, row_w, common = [], [], []
	while i > 0 or j > 0:
		if i > 0 and j > 0:
			arrow = b[i][j]
			use_v = arrow != LEFT_ARROW
			use_w = arrow != UP_ARROW
		else:
			# On a border only the string that still has characters advances.
			use_v, use_w = i > 0, j > 0
		v_char = v[i-1] if use_v else "-"
		w_char = w[j-1] if use_w else "-"
		if use_v and use_w and v_char == w_char:
			common.append(w_char)
		row_v.append(v_char)
		row_w.append(w_char)
		if use_v:
			i -= 1
		if use_w:
			j -= 1
	lcs = "".join(reversed(common))
	# Single-argument parenthesized print behaves the same on Python 2 and 3.
	print("The longest common subsequence is %s" % (lcs))
	print("Alignment: ")
	print("\t%s" % ("".join(reversed(row_v))))
	print("\t%s" % ("".join(reversed(row_w))))
	return lcs
				

def posmax(seq, key=lambda x:x):
	"""Return (index, value) for the largest element of seq.

	Elements are ranked by key(value).  When several elements rank
	equally the one with the smallest index is returned, because max
	keeps the first maximal item it encounters.
	"""
	def rank(indexed):
		# indexed is an (index, value) pair from enumerate.
		return key(indexed[1])

	return max(enumerate(seq), key=rank)
 
def Case(n, v, w):
	"""Run one numbered alignment case and print its report.

	Prints the case number with the module-level MU and SIGMA penalties,
	the score returned by LCS for v and w, then the subsequence and
	alignment printed by PrintLCS, followed by an empty line.

	Args:
		n: Case number shown in the header line.
		v: First string.
		w: Second string.
	"""
	# Single-argument parenthesized print behaves the same on Python 2 and 3.
	print("Case %d with MU = %d and SIGMA = %d" % (n, MU, SIGMA))
	score, b = LCS(v,w)
	print("The similarity score for %s and %s is: %d" % (v, w, score))
	PrintLCS(b,v,w)
	print("")

def Cases():
	"""Run the five fixed sample cases, numbered 1 to 5, through Case."""
	samples = [
		("ATCTGAT", "TGCATA"),
		("GTAGGCTTAAGGTTA", "TAGATA"),
		("AGTCCAATGACTCCAGT", "TCGCCTGTTAAGC"),
		("GGACGTACG", "TACGGGTAT"),
		("GGGGGGGGGGGGG", "AAAAAAAA"),
	]
	for number, (first, second) in enumerate(samples, 1):
		Case(number, first, second)
	

if __name__ == '__main__':
	# MU - Mismatch Penalty
	# SIGMA - Indel Penalty
	# LCS and Case read these as module-level globals, so each run rebinds
	# them (the loop targets below are globals) before calling Cases().
	penalty_settings = [
		(0, 0),  # Normal cases
		(1, 1),  # Experimental: SIGMA == MU
		(0, 1),  # Experimental: SIGMA < MU
		(1, 0),  # Experimental: SIGMA > MU
	]
	for SIGMA, MU in penalty_settings:
		Cases()
	if DEBUG:
		# Self-checks: each line prints True when the check passes.
		# Single-argument parenthesized print behaves the same on
		# Python 2 and 3.
		SIGMA, MU = 0,0
		a = [500,400,300]
		print(posmax(a) == (0,500))
		a = [0,0,0]
		print(posmax(a) == (0,0))
		s1, s2 = "ATCTGAT","TGCATA"
		distance, b = LCS(s1,s2)
		print(distance == 4)
		print(PrintLCS(b,s1,s2) == "TCTA")