Skip to content

Commit

Permalink
#84 Simple test case now passes
Browse files Browse the repository at this point in the history
  • Loading branch information
weka511 committed Feb 26, 2023
1 parent 8031b53 commit 5add434
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,23 @@ def ConstructProfileHMM(theta,Alphabet,Alignment,sigma=0):
BA10E Construct a Profile HMM
Parameters:
theta Threshold. This isn't the same as the theta in the textbook
See David Eccles and fanta's comments - http://rosalind.info/problems/ba10e/questions/
Is this true?
theta Threshold.
Alphabet
Alignment
sigma
'''

def is_space(ch):
    '''
    Check whether a character is a space (represented by a hyphen)

    Parameters:
        ch     Character to be tested

    Returns:
        True if ch is the hyphen '-', otherwise False
    '''
    return ch=='-'

def create_mask():
'''
Construct a mask to exclude columns from an alignment
if the fraction of spaces exceeds theta
'''
def get_count(i):
return sum([is_space(c) for s in Alignment for c in s[i]])

Expand Down Expand Up @@ -235,6 +241,21 @@ def verify_constraints(matrix,eps=1e-15):
if row_total>eps and row_total<1-eps:
raise Exception(f'Constraint violated on row {i}, eps = {eps}')

def get_successors(state,m):
    '''
    Generate candidate successor indices for a state

    States are treated in blocks of three. The block number is (state+1)//3,
    and the candidates are the three consecutive indices starting at 3*block.
    Candidates that are not below m are dropped.

    Parameters:
        state   Index of the current state
        m       Exclusive upper bound on the indices that are generated

    Yields:
        Each candidate index below m, in increasing order

    NOTE(review): for state 0 this yields 0, 1, 2, so state 0 is listed as its
    own successor - confirm against the state ordering used by the caller.
    '''
    first = 3*((state+1)//3)
    yield from (candidate for candidate in range(first,first+3) if candidate<m)

def normalize_rows(product):
    '''
    Scale the rows of a matrix so that each non-zero row sums to 1

    Rows whose total is zero are left as all zeros, and the last row is
    never rescaled (its divisor is forced to 1).

    Parameters:
        product   Two dimensional numpy array

    Returns:
        A new array with the same shape as product; product is not modified
    '''
    row_totals = product.sum(axis=1)
    if row_totals.size>0:
        row_totals[-1] = 1            # last row is passed through unscaled
    row_totals[row_totals==0] = 1     # avoid dividing by zero for all-zero rows
    # The row count is inferred from product itself (reshape(-1,1)) rather than
    # read from a variable `m` that has to exist in an enclosing scope.
    return product/row_totals.reshape(-1,1)

def create_transition(m,Paths):
def create_census():
States = {}
Expand All @@ -246,6 +267,10 @@ def create_census():
return States

product = np.zeros((m,m))
for i in range(m):
for j in get_successors(i,m):
product[i,j] = sigma

States = create_census()
for key1,successors in States.items():
state1,index1 = split_key(key1)
Expand All @@ -261,8 +286,8 @@ def create_census():
state_index2 = get_state_index((state2,index2))
product[state_index1,state_index2] = fraction

verify_constraints(product)
return product
return normalize_rows(product)


def create_emission(m,n,Paths):
def create_census():
Expand All @@ -275,20 +300,22 @@ def create_census():
return States

product = np.zeros((m,n))

States = create_census()
for key,chars in States.items():
state,index = split_key(key)
if state in ['M','I']:
counts = {ch:0 for ch in Alphabet}
for ch in chars:
counts[ch] += 1
fractions = {ch:count/len(chars) for ch,count in counts.items()}
fractions = {ch:max(count/len(chars),sigma) for ch,count in counts.items()}
state_index = get_state_index((state,index))
for j in range(n):
product[state_index,j] = fractions[Alphabet[j]]
return normalize_rows(product)



verify_constraints(product)
return product

mask = create_mask()
Paths = [create_path(s,mask) for s in Alignment]
Expand Down Expand Up @@ -480,6 +507,6 @@ def test_ba10f1(self):
self.assertEqual(12,m)
self.assertEqual(m1,m)
self.assertEqual(m2,m)
self.assertEqual(0.01, Transition[0,1])
self.assertEqual(0.01, Emission[2,1])
self.assertAlmostEqual(0.01, Transition[0,1],places=3)
self.assertAlmostEqual(0.01, Emission[2,1],places=3)
main()

0 comments on commit 5add434

Please sign in to comment.