import os
import sys
import math

def find_in_dic(matrix, key):
    """Return the entry stored under ``key`` in ``matrix``, or 0.0 if absent.

    The float 0.0 acts as the "not found" marker; callers in this file
    compare the result against 0.0 or index into it directly.
    """
    return matrix.get(key, 0.0)
    
# Task: add core H-type pseudoknot to dictionaries according to the first
# and second pseudoknot stems, but only if certain length requirements hold
def add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second):
    """Register pseudoknot ``key`` under both of its stems if the loops qualify.

    The pseudoknot is accepted only when each of ``l1``, ``l2`` and ``l3``
    lies in [0, 400), at least one of ``l1``/``l3`` exceeds 3, and the three
    loop lengths sum to less than 400. On acceptance ``key`` is appended to
    ``pseudoknot_second[key2]`` and ``pseudoknot_first[key1]``; both
    dictionaries are modified in place.

    Returns:
        The tuple ``(pseudoknot_first, pseudoknot_second)`` (the same objects
        that were passed in).
    """
    every_loop_in_range = all(0 <= loop < 400 for loop in (l2, l1, l3))
    if every_loop_in_range and (l1 > 3 or l3 > 3) and (l1 + l2 + l3) < 400:
        for table, stem in ((pseudoknot_second, key2), (pseudoknot_first, key1)):
            # A missing (or empty) entry is replaced by a fresh list first.
            if not table.get(stem):
                table[stem] = []
            table[stem].append(key)
    return pseudoknot_first, pseudoknot_second
    
def resolve_overlap_l1(i, j, stemlength1, k, l, stemlength2, l1, l2, l3):
    """Remove a 1 or 2 bp overlap at loop L1 by trimming stem S1.

    A 1 bp overlap (``l1 == -1``) is resolved when S1 has at least 4 bp and
    ``l2 >= 0``; a 2 bp overlap (``l1 == -2``) when S1 has at least 5 bp and
    ``l2 >= -1``. Trimming S1 by ``cut`` base pairs lengthens both L1 and L2
    by ``cut``.

    Returns:
        ``(key, s1_shortened, l1, l2)`` where ``key`` is the pseudoknot tuple
        ``(i, j, new_stemlength1, k, l, stemlength2)`` and ``s1_shortened`` is
        ``(i, j, new_stemlength1)``. If the overlap cannot be resolved, both
        are ``None`` and ``l1``/``l2`` are returned unchanged.
    """
    if l1 == -1 and stemlength1 >= 4 and l2 >= 0:
        cut = 1
    elif l1 == -2 and stemlength1 >= 5 and l2 >= -1:
        cut = 2
    else:
        return None, None, l1, l2

    trimmed = stemlength1 - cut     # Shortened length of stem S1
    pk_key = (i, j, trimmed, k, l, stemlength2)
    return pk_key, (i, j, trimmed), l1 + cut, l2 + cut

def resolve_overlap_l2_S1(i, j, stemlength1, k, l, stemlength2, l1, l2, l3):
    """Remove a 1 or 2 bp overlap at loop L2 by trimming stem S1.

    A 1 bp overlap (``l2 == -1``) is resolved when S1 has at least 4 bp and
    ``l1 >= 0``; a 2 bp overlap (``l2 == -2``) when S1 has at least 5 bp and
    ``l1 >= -1``. Trimming S1 lengthens both L1 and L2 by the number of
    removed base pairs.

    Returns:
        ``(key, s1_shortened, l1, l2)``; ``key`` and ``s1_shortened`` are
        ``None`` (and the loop lengths unchanged) when nothing was trimmed.
    """
    one_bp = l2 == -1 and stemlength1 >= 4 and l1 >= 0
    two_bp = (not one_bp) and l2 == -2 and stemlength1 >= 5 and l1 >= -1
    if not (one_bp or two_bp):
        return None, None, l1, l2

    removed = 1 if one_bp else 2
    new_length = stemlength1 - removed      # Shortened length of stem S1
    shortened_stem = (i, j, new_length)
    pk_key = (i, j, new_length, k, l, stemlength2)
    return pk_key, shortened_stem, l1 + removed, l2 + removed

def resolve_overlap_l2_S2(i, j, stemlength1, k, l, stemlength2, l1, l2, l3):
    """Remove a 1 or 2 bp overlap at loop L2 by trimming stem S2.

    A 1 bp overlap (``l2 == -1``) is resolved when S2 has at least 4 bp and
    ``l3 >= 0``; a 2 bp overlap (``l2 == -2``) when S2 has at least 5 bp and
    ``l3 >= -1``. Trimming S2 lengthens both L2 and L3 by the number of
    removed base pairs.

    Returns:
        ``(key, s2_shortened, l2, l3)`` with ``key`` equal to
        ``(i, j, stemlength1, k, l, new_stemlength2)`` and ``s2_shortened``
        equal to ``(k, l, new_stemlength2)``; both are ``None`` (and the loop
        lengths unchanged) when nothing was trimmed.
    """
    if l2 == -1 and stemlength2 >= 4 and l3 >= 0:
        cut = 1
    elif l2 == -2 and stemlength2 >= 5 and l3 >= -1:
        cut = 2
    else:
        return None, None, l2, l3

    trimmed = stemlength2 - cut     # Shortened length of stem S2
    return (i, j, stemlength1, k, l, trimmed), (k, l, trimmed), l2 + cut, l3 + cut

def resolve_overlap_l3(i, j, stemlength1, k, l, stemlength2, l1, l2, l3):
    """Remove a 1 or 2 bp overlap at loop L3 by trimming stem S2.

    A 1 bp overlap (``l3 == -1``) is resolved when S2 has at least 4 bp and
    ``l2 >= 0``; a 2 bp overlap (``l3 == -2``) when S2 has at least 5 bp and
    ``l2 >= -1``. Trimming S2 lengthens both L2 and L3 by the number of
    removed base pairs.

    Returns:
        ``(key, s2_shortened, l2, l3)``; ``key`` and ``s2_shortened`` are
        ``None`` (and the loop lengths unchanged) when nothing was trimmed.
    """
    single = l3 == -1 and stemlength2 >= 4 and l2 >= 0
    double = (not single) and l3 == -2 and stemlength2 >= 5 and l2 >= -1
    if not (single or double):
        return None, None, l2, l3

    removed = 1 if single else 2
    new_length = stemlength2 - removed      # Shortened length of stem S2
    shortened_stem = (k, l, new_length)
    pk_key = (i, j, stemlength1, k, l, new_length)
    return pk_key, shortened_stem, l2 + removed, l3 + removed

# Task: Construct core H-type pseudoknots with regular stems. Return the core
# H-type pseudoknot dictionaries and the dictionary of shortened stems.
def build_pseudoknots(matrix_stems):
    """Build core H-type pseudoknot candidates from pairs of stems.

    The stems are visited in sorted order and every stem S1 = (i, j) is
    combined with itself and with every stem S2 = (k, l) that sorts after
    it. For each pair whose span ``l - i + 1`` lies in [16, 400) the loop
    lengths L1, L2 and L3 are computed and the pseudoknot is offered to
    ``add_pk``. Overlaps of 1 or 2 bp between the stems are resolved by
    shortening one stem; each shortened variant is offered to ``add_pk`` as
    well and its stem is recorded in the shortened-stem dictionary.

    Args:
        matrix_stems: dict keyed by ``(i, j)``; element 0 of each value is
            used as the stem length.

    Returns:
        ``(pseudoknot_second, pseudoknot_first, stems_shortened)``:
        pseudoknot keys ``(i, j, stemlength1, k, l, stemlength2)`` grouped by
        their second stem ``(k, l)`` and by their first stem ``(i, j)``, and
        a dict mapping each shortened stem ``(start, end, length)`` to the
        placeholder value 0.0.
    """
    pseudoknot_second = {}  # Store all pseudoknots by their second stem
    pseudoknot_first = {}   # Store all pseudoknots by their first stem
    stems_shortened = {}

    # sorted()/range() instead of items().sort()/xrange(): the latter exist
    # only on Python 2, the former behave the same on Python 2 and 3.
    matrix_stems_list = sorted(matrix_stems.items())
    number_of_stems = len(matrix_stems_list)
    for x in range(number_of_stems):
        for y in range(x, number_of_stems):
            i, j = matrix_stems_list[x][0][0], matrix_stems_list[x][0][1]
            key1 = i, j
            stemlength1 = find_in_dic(matrix_stems, key1)[0]
            k, l = matrix_stems_list[y][0][0], matrix_stems_list[y][0][1]
            key2 = k, l
            stemlength2 = find_in_dic(matrix_stems, key2)[0]
            l1 = k - (i + stemlength1)
            l2 = (j - stemlength1 + 1) - (k + stemlength2)
            l3 = (l - stemlength2) - j
            # Form core H-type pseudoknot only within the allowed span
            if not ((l - i + 1) >= 16 and (l - i + 1) < 400):
                continue

            # Case that no base pair overlap occurs (add_pk rejects negative loops)
            key = i, j, stemlength1, k, l, stemlength2
            pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)

            # First case of overlap at loops L1 and L2: cut stem S1.
            # l1/l2 are overwritten by the resolver results, so the second
            # check below sees the values left by the first one.
            if l3 > 3 and l1 <= 3:
                if (l1 == -1 or l1 == -2) and l2 >= -1:
                    key, s1_shortended, l1, l2 = resolve_overlap_l1(i, j, stemlength1, k, l, stemlength2, l1, l2, l3)
                    if key and s1_shortended:
                        stems_shortened[s1_shortended] = 0.0
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
                if (l2 == -1 or l2 == -2) and l1 >= -1:
                    key, s1_shortended, l1, l2 = resolve_overlap_l2_S1(i, j, stemlength1, k, l, stemlength2, l1, l2, l3)
                    if key and s1_shortended:
                        stems_shortened[s1_shortended] = 0.0
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)

            # Second case of overlap at loops L2 and L3: cut stem S2
            if l1 > 3 and l3 <= 3:
                if (l2 == -1 or l2 == -2) and l3 >= -1:
                    key, s2_shortended, l2, l3 = resolve_overlap_l2_S2(i, j, stemlength1, k, l, stemlength2, l1, l2, l3)
                    if key and s2_shortended:
                        stems_shortened[s2_shortended] = 0.0
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
                if (l3 == -1 or l3 == -2) and l2 >= -1:
                    key, s2_shortended, l2, l3 = resolve_overlap_l3(i, j, stemlength1, k, l, stemlength2, l1, l2, l3)
                    if key and s2_shortended:
                        stems_shortened[s2_shortended] = 0.0
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)

            # Third case of overlap at loop L2: cut the longer stem.
            # NOTE(review): the two branches in each sub-case are separate
            # `if`s and the second one compares against the already reduced
            # stemlength1. With equal-length stems both fire, which also adds
            # a pseudoknot with both stems shortened. Kept as-is; confirm
            # whether `elif` was intended.
            if l1 >= 3 and l3 >= 3:
                # Case that base pair overlap of 1 bp occurs at L2
                if l2 == -1:
                    if stemlength1 >= stemlength2 and stemlength1 >= 4:
                        stemlength1 = stemlength1 - 1       # Cut longer stem S1
                        l1 = l1 + 1
                        l2 = l2 + 1
                        s1_shortended = i, j, stemlength1   # Add to stem dictionary
                        stems_shortened[s1_shortended] = 0.0
                        key = i, j, stemlength1, k, l, stemlength2
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
                    if stemlength2 > stemlength1 and stemlength2 >= 4:
                        stemlength2 = stemlength2 - 1       # Cut longer stem S2
                        l2 = l2 + 1
                        l3 = l3 + 1
                        s2_shortended = k, l, stemlength2   # Add to stem dictionary
                        stems_shortened[s2_shortended] = 0.0
                        key = i, j, stemlength1, k, l, stemlength2
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
                # Case that base pair overlap of 2 bp occurs at L2
                elif l2 == -2:
                    if stemlength1 >= stemlength2 and stemlength1 >= 5:
                        stemlength1 = stemlength1 - 2       # Cut longer stem S1
                        l1 = l1 + 2
                        l2 = l2 + 2
                        s1_shortended = i, j, stemlength1   # Add to stem dictionary
                        stems_shortened[s1_shortended] = 0.0
                        key = i, j, stemlength1, k, l, stemlength2
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
                    if stemlength2 > stemlength1 and stemlength2 >= 5:
                        stemlength2 = stemlength2 - 2       # Cut longer stem S2
                        l2 = l2 + 2
                        l3 = l3 + 2
                        s2_shortended = k, l, stemlength2   # Add to stem dictionary
                        stems_shortened[s2_shortended] = 0.0
                        key = i, j, stemlength1, k, l, stemlength2
                        pseudoknot_first, pseudoknot_second = add_pk(key, key1, key2, l1, l2, l3, pseudoknot_first, pseudoknot_second)
    return pseudoknot_second, pseudoknot_first, stems_shortened

# Task: if a secondary structure element has lower free energy than a plain stem,
# the plain stem does not need to be considered in the recursive elements search.
def filter_kissing_stems_mwis(matrix_stems, bulge_internal_dic, multiloops):
    """Return a copy of ``matrix_stems`` without stems beaten by a larger element.

    A stem ``(i, j)`` is dropped when some entry of ``bulge_internal_dic`` or
    ``multiloops`` has the same first two key items ``(i, j)`` and an
    element 2 of its value that is ``<=`` element 0 of the stem's value.
    ``matrix_stems`` itself is not modified.

    NOTE(review): elsewhere in this file element 0 of a stem value is read
    as the stem length, not as an energy. Confirm that comparing it against
    element 2 of the structure-element value is the intended test.
    """
    def lowest_by_pair(elements):
        # Smallest value[2] per (i, j). "Some entry <= x" is equivalent to
        # "the minimum <= x", so one pass replaces a scan per stem.
        lowest = {}
        for element, element_value in elements.items():
            pair = element[0], element[1]
            if pair not in lowest or element_value[2] < lowest[pair]:
                lowest[pair] = element_value[2]
        return lowest

    lowest_ib = lowest_by_pair(bulge_internal_dic)
    lowest_ml = lowest_by_pair(multiloops)

    matrix_stems_mwis = matrix_stems.copy()
    # items() rather than iteritems() so the function runs on Python 2 and 3.
    for stem, value in matrix_stems.items():
        pair = stem[0], stem[1]
        for lowest in (lowest_ib, lowest_ml):
            if pair in lowest and lowest[pair] <= value[0]:
                del matrix_stems_mwis[stem]
                break
    return matrix_stems_mwis

# Marker stored as stack energy when a shortened stem is missing from the
# shortened-stem dictionary; such kissing hairpin candidates are discarded.
_REJECTED_STACK_ENERGY = 100


def _lookup_kissing_stem(stem, expected_length, matrix_stems, stems_shortened_dic):
    """Look up length, probability and stack energy for one stem.

    ``stem`` is the ``(start, end)`` key into ``matrix_stems`` and
    ``expected_length`` the stem length recorded in the pseudoknot. If the
    regular stem has a different length, the shortened stem
    ``(start, end, expected_length)`` is looked up in ``stems_shortened_dic``
    instead; an entry that is missing (or equal to 0.0) yields the
    rejection marker as stack energy.

    Returns:
        ``(stem_key, stemlength, probability, stack_energy)``. The
        probability always comes from the regular stem.
    """
    regular = find_in_dic(matrix_stems, stem)
    stemlength, prob = regular[0], regular[1]
    if stemlength != expected_length:
        shortened_key = stem[0], stem[1], expected_length
        shortened = find_in_dic(stems_shortened_dic, shortened_key)
        # If not found, the shortened stem was deleted before because of high free energy
        if shortened != 0.0:
            return shortened_key, shortened[0], prob, shortened[2]
        return shortened_key, stemlength, prob, _REJECTED_STACK_ENERGY
    return stem, stemlength, prob, regular[2]


def _recursive_loop_elements(loop_key, lookup_dic, matrix_stems_mwis, bulge_internal_dic, multiloops):
    """Return the structure elements chosen for the loop ``(start, end)``.

    Results of an MWIS run are cached in ``lookup_dic``. Loops with zero or
    one candidate return the candidate list directly and are not cached.
    """
    if loop_key in lookup_dic:
        return lookup_dic[loop_key]
    candidates = candidate_list(loop_key[0], loop_key[1], matrix_stems_mwis, bulge_internal_dic, multiloops)
    if len(candidates) <= 1:
        return candidates
    candidates.sort()
    sorted_endpoint_list = create_sorted_endpointlist(candidates)
    result = MWIS(candidates, sorted_endpoint_list)
    lookup_dic[loop_key] = result
    return result


def _loop_contribution(elements, loop_length):
    """Return ``(effective_length, energy)`` of a loop with nested elements.

    Each element removes its span ``item[1] - item[0] + 1`` from the loop
    length and adds 1 back (one per helix); ``item[5]`` is summed as the
    element's free energy.
    """
    effective_length = loop_length
    energy = 0.0
    if elements:
        for item in elements:
            effective_length = effective_length - (item[1] - item[0] + 1)
            energy = energy + item[5]                   # Add free energies
            effective_length = effective_length + 1     # Plus number of helices
    return effective_length, energy


# Task: assemble kissing hairpins from core pseudoknots. Search for recursive elements.
# Filter kissing hairpin candidates with low probability. Evaluate energies.
def kissing_hairpins(pseudoknot_second, pseudoknot_first, matrix_stems, stems_shortened_dic, matrix_stems_mwis, bulge_internal_dic, multiloops, init, unpaired_nt, unpaired_nt_l3, pk_core_dic):
    """Assemble kissing hairpins from pairs of core pseudoknots sharing a stem.

    A stem that is the second stem of pseudoknot ``pk2`` and the first stem
    of pseudoknot ``pk1`` becomes the middle stem S2; S1 is the first stem of
    ``pk2`` and S3 the second stem of ``pk1``. Candidates spanning 400 nt or
    more, with a summed stem probability of at most 0.001, with a missing
    shortened stem, or without an unpaired base in (L1 or L2) and (L4 or L5)
    are skipped. Loops of at least 9 nt are searched for nested elements.
    The candidate is kept when its estimated energy divided by its length is
    at most -0.25.

    Args:
        pseudoknot_second: dict mapping a stem to the pseudoknots having it
            as second stem.
        pseudoknot_first: dict mapping a stem to the pseudoknots having it
            as first stem.
        matrix_stems: dict keyed by ``(i, j)``; value items 0, 1 and 2 are
            read as stem length, probability and stack energy.
        stems_shortened_dic: dict keyed by ``(i, j, length)``; value items 0
            and 2 are read as stem length and stack energy.
        matrix_stems_mwis, bulge_internal_dic, multiloops: passed through to
            ``candidate_list`` for the recursive element search.
        init: constant added to every estimated energy.
        unpaired_nt: per-nucleotide weight for loops L1, L2, L4 and L5.
        unpaired_nt_l3: per-nucleotide weight for loop L3.
        pk_core_dic: dict of already evaluated core pseudoknots; items 1, 2
            and 3 of a value are reused as loop results.

    Returns:
        dict mapping ``(start, end)`` of a kissing hairpin to the tuple
        ``(kissing_hairpin, estimated_energy, result1, ..., result5)`` with
        the lowest estimated energy for that interval, where
        ``kissing_hairpin`` is ``(stem1, stemlength1, stem2, stemlength2,
        stem3, stemlength3)``.
    """
    lookup_dic_L1, lookup_dic_L2, lookup_dic_L3, lookup_dic_L4, lookup_dic_L5 = {}, {}, {}, {}, {}
    best_khps = {}
    # pseudoknot_second[key2] stores all pks which have key2 as second stem
    # pseudoknot_first[key2] stores all pks which have key2 as first stem
    for key2 in sorted(pseudoknot_second):
        if key2 not in pseudoknot_first:
            continue
        pks_2 = pseudoknot_second[key2]
        pks_1 = pseudoknot_first[key2]
        for pk2 in pks_2:
            for pk1 in pks_1:
                if not pk2[1] < pk1[3]:     # j < m, otherwise it is a triple helix
                    continue
                length = pk1[4] - pk2[0] + 1
                if not length < 400:        # Length restriction
                    continue

                stem1, stemlength1, prob1, stack_energy1 = _lookup_kissing_stem(
                    (pk2[0], pk2[1]), pk2[2], matrix_stems, stems_shortened_dic)
                stem2, stemlength2, prob2, stack_energy2 = _lookup_kissing_stem(
                    (pk2[3], pk2[4]), min(pk2[5], pk1[2]), matrix_stems, stems_shortened_dic)
                stem3, stemlength3, prob3, stack_energy3 = _lookup_kissing_stem(
                    (pk1[3], pk1[4]), pk1[5], matrix_stems, stems_shortened_dic)

                # Inclusive (start, end) positions of the five loops
                L1_key = stem1[0] + stemlength1, stem2[0] - 1
                L2_key = stem2[0] + stemlength2, stem1[1] - stemlength1
                L3_key = stem1[1] + 1, stem3[0] - 1
                L4_key = stem3[0] + stemlength3, stem2[1] - stemlength2
                L5_key = stem2[1] + 1, stem3[1] - stemlength3
                l1 = L1_key[1] - L1_key[0] + 1
                l2 = L2_key[1] - L2_key[0] + 1
                l3 = L3_key[1] - L3_key[0] + 1
                l4 = L4_key[1] - L4_key[0] + 1
                l5 = L5_key[1] - L5_key[0] + 1

                # Loop length requirement
                if not ((l1 > 0 or l2 > 0) and (l4 > 0 or l5 > 0)):
                    continue
                # Set a probability threshold
                if not prob1 + prob2 + prob3 > 0.001:
                    continue
                if (stack_energy1 == _REJECTED_STACK_ENERGY
                        or stack_energy2 == _REJECTED_STACK_ENERGY
                        or stack_energy3 == _REJECTED_STACK_ENERGY):
                    continue

                kissing_hairpin = stem1, stemlength1, stem2, stemlength2, stem3, stemlength3
                # Keys of the two core pseudoknots (S1, S2) and (S2, S3) in pk_core_dic
                pk1_key = stem1[0], stem2[1], stem1[0], stem1[1], stemlength1, stem2[0], stem2[1], stemlength2, 'r'
                pk2_key = stem2[0], stem3[1], stem2[0], stem2[1], stemlength2, stem3[0], stem3[1], stemlength3, 'r'

                result1, result2, result3, result4, result5 = [], [], [], [], []
                # Look for recursive elements in loops L1 and L2; reuse the
                # core pseudoknot dictionary when it already has the answer
                if l1 >= 9 and (l1 + l2 > 9):
                    if pk1_key in pk_core_dic:
                        result1 = pk_core_dic[pk1_key][1]
                    else:
                        result1 = _recursive_loop_elements(L1_key, lookup_dic_L1, matrix_stems_mwis, bulge_internal_dic, multiloops)
                if l2 >= 9 and (l1 + l2 > 9):
                    if pk1_key in pk_core_dic:
                        result2 = pk_core_dic[pk1_key][2]
                    else:
                        result2 = _recursive_loop_elements(L2_key, lookup_dic_L2, matrix_stems_mwis, bulge_internal_dic, multiloops)
                # Look for recursive elements in loop L3
                if l3 >= 9:
                    result3 = _recursive_loop_elements(L3_key, lookup_dic_L3, matrix_stems_mwis, bulge_internal_dic, multiloops)
                # Look for recursive elements in loops L4 and L5
                if l4 >= 9 and (l4 + l5 > 9):
                    if pk2_key in pk_core_dic:
                        result4 = pk_core_dic[pk2_key][2]
                    else:
                        result4 = _recursive_loop_elements(L4_key, lookup_dic_L4, matrix_stems_mwis, bulge_internal_dic, multiloops)
                if l5 >= 9 and (l4 + l5 > 9):
                    if pk2_key in pk_core_dic:
                        result5 = pk_core_dic[pk2_key][3]
                    else:
                        result5 = _recursive_loop_elements(L5_key, lookup_dic_L5, matrix_stems_mwis, bulge_internal_dic, multiloops)

                # Re-estimate free energy including recursive elements. Without
                # any element the effective lengths equal the loop lengths and
                # the loop energies are 0.0.
                effective_l1, energy_l1 = _loop_contribution(result1, l1)
                effective_l2, energy_l2 = _loop_contribution(result2, l2)
                effective_l3, energy_l3 = _loop_contribution(result3, l3)
                effective_l4, energy_l4 = _loop_contribution(result4, l4)
                effective_l5, energy_l5 = _loop_contribution(result5, l5)

                # Estimate kissing hairpin energy
                estimated_energy = stack_energy1 + stack_energy2 + stack_energy3 + energy_l1 + energy_l2 + energy_l3 + energy_l4 + energy_l5
                entropy = unpaired_nt * (effective_l1 + effective_l2 + effective_l4 + effective_l5) + unpaired_nt_l3 * effective_l3
                estimated_energy = estimated_energy + entropy + init

                # Indication of stability
                normalized_energy = estimated_energy / length
                # Leave at least one base unpaired in loops L1 or L2 and loops L4 and L5
                if not ((effective_l1 > 0 or effective_l2 > 0) and (effective_l4 > 0 or effective_l5 > 0)):
                    continue
                if normalized_energy <= -0.25:      # Filtering step
                    interval = stem1[0], stem3[1]
                    # Keep best kissing hairpin for interval
                    if interval not in best_khps or best_khps[interval][1] > estimated_energy:
                        best_khps[interval] = kissing_hairpin, estimated_energy, result1, result2, result3, result4, result5
    return best_khps


def candidate_list(start, end, matrix_stems_mwis, bulge_internal_dic, multiloops):
    """Collect the structure elements lying completely inside [start, end].

    Every returned tuple has seven fields:
        (left, right, values[0], x, weight, raw, tag)
    where ``tag`` is "hp", "ib" or "ml", ``raw`` is the unmodified score read
    from the source dictionary and ``weight`` is ``-1 * round(raw, 2)``
    (the sign is flipped, presumably so that favourable negative energies
    become positive MWIS weights -- confirm against the callers).

    * matrix_stems_mwis: dict keyed by (left, right); score is values[3],
      ``x`` is values[1].  Only entries with a score <= 0.0 are kept.
    * bulge_internal_dic / multiloops: dicts keyed by (left, right); score
      is values[2], ``x`` is fixed to 0.0.  No score filter is applied.

    Returns a new list; the input dictionaries are not modified.
    """
    # Named differently from the function itself to avoid shadowing it.
    candidates = []
    # .items() works on Python 2 and 3 (iteritems() exists on Python 2 only).
    for stem, values in matrix_stems_mwis.items():
        if stem[0] >= start and stem[1] <= end and values[3] <= 0.0:
            candidates.append((stem[0], stem[1], values[0], values[1],
                               -1*round(values[3], 2), values[3], "hp"))
    for tag, source in (("ib", bulge_internal_dic), ("ml", multiloops)):
        for stem, values in source.items():
            if stem[0] >= start and stem[1] <= end:
                candidates.append((stem[0], stem[1], values[0], 0.0,
                                   -1*round(values[2], 2), values[2], tag))
    return candidates

# Task: MWIS calculation for recursive elements in loops
def recursive_pk_mwis(pk,candidate_list):
    """Run the MWIS selection on the candidates of one loop.

    The candidate list is sorted in place (the caller's list is reordered),
    an endpoint list is derived from it and both are handed to MWIS.
    The ``pk`` argument is accepted but not used here.

    Returns whatever MWIS returns for these candidates.
    """
    candidate_list.sort()       # in-place: interval indices must match sorted order
    return MWIS(candidate_list, create_sorted_endpointlist(candidate_list))

# Task: Maximum weight independent set calculation
# Step 1, initialization
# Step 2, scan sorted endpoints list
# Step 3, traceback step
def MWIS(interval_set, sorted_endpointlist):
    """Select a maximum weight set of non-overlapping intervals.

    interval_set        -- sequence of tuples; field 0 is the left end,
                           field 1 the right end and field 4 the weight.
    sorted_endpointlist -- sorted tuples (position, 'l' | 'r', weight,
                           1-based index into interval_set), as produced by
                           create_sorted_endpointlist.

    Returns the chosen intervals as a list, the interval with the best
    score first, followed by the ones found while tracing back towards
    lower indices.  An empty interval_set yields an empty list.
    """
    if not interval_set:
        # Nothing to choose from; indexing interval_set[0] below would fail.
        return []

    # Step 1, initialization: one score slot per endpoint (only the first
    # len(interval_set) slots are ever addressed).
    value = [0.0] * len(sorted_endpointlist)
    temp_max, last_interval = 0.0, 0

    # Step 2, scan sorted endpoints list
    for endpoint in sorted_endpointlist:
        if endpoint[1] == 'l':
            # Left endpoint: best score so far plus this interval's weight
            c = endpoint[3] - 1
            value[c] = temp_max + endpoint[2]
        if endpoint[1] == 'r':
            # Right endpoint: the interval is closed and may become the best
            c = endpoint[3] - 1
            if value[c] > temp_max:
                temp_max = value[c]
                last_interval = c

    # Step 3, traceback: start at the best interval and walk down the
    # indices, picking intervals whose score (rounded to two decimals)
    # equals the remaining score and which end before the last pick starts.
    Smax1 = [interval_set[last_interval]]
    temp_max = temp_max - interval_set[last_interval][4]
    for j in range(last_interval - 1, -1, -1):
        if round(value[j], 2) == round(temp_max, 2):
            if interval_set[j][1] < interval_set[last_interval][0]:
                Smax1.append(interval_set[j])
                temp_max = temp_max - interval_set[j][4]
                last_interval = j
    return Smax1

#--- MWIS ---#

# Function for finding nested intervals, given a right endpoint.
# Outer stem is a hairpin loop.
def find_nested(endpoint,candidate_list):
    """Return the candidates nested inside the stem closed by ``endpoint``.

    endpoint[3] is the 1-based index of the outer interval in
    candidate_list.  The inner region runs from left + stem length to
    right - stem length (stem length is field 2 of the outer interval).

    Returns (nested, only_hp_ib_ml): the list of candidates inside the
    inner region and a flag that stays True as long as no 'pk' or 'khp'
    element was collected.
    """
    outer = candidate_list[endpoint[3] - 1]
    lower = outer[0] + outer[2]
    upper = outer[1] - outer[2]

    nested = []
    plain_only = True
    for candidate in candidate_list:
        # Skip everything that is not inside the inner region
        if candidate[0] < lower or candidate[1] > upper:
            continue
        kind = candidate[5]
        if kind == 'pk':
            # Nested pseudoknots need one spare base on each side:
            # {{{.(((..[[[.)))....]]].}}}
            if not (lower < candidate[0] and candidate[1] < upper):
                continue
            plain_only = False
        elif kind == 'khp':
            plain_only = False
        nested.append(candidate)
    return nested, plain_only

# Task: Create sorted endpoints list for a set of intervals
def create_sorted_endpointlist(intervals):
    """Build the sorted endpoint list for a sequence of intervals.

    For the interval at 0-based position p two tuples are produced:
        (intervals[p][0], 'l', float(intervals[p][4]), p + 1)
        (intervals[p][1], 'r', float(intervals[p][4]), p + 1)
    The combined list is sorted with plain tuple ordering, so at equal
    positions an 'l' endpoint sorts before an 'r' endpoint.

    Returns the sorted list; an empty input gives an empty list.
    """
    sorted_list = []
    # enumerate(..., 1) yields the 1-based interval index directly and,
    # unlike xrange, exists on both Python 2 and Python 3.
    for index, interval in enumerate(intervals, 1):
        weight = float(interval[4])
        sorted_list.append((interval[0], 'l', weight, index))
        sorted_list.append((interval[1], 'r', weight, index))
    sorted_list.sort()
    return sorted_list

# Task: Maximum weight independent set using the set of structure elements and pseudoknots
def method(matrix_stems_mwis, pk_recursive_dic, bulge_internal_dic, multiloops, best_khps):
    """Run the MWIS selection over all structure elements and pseudoknots.

    Candidate tuples are built from the five input dictionaries.  Fields
    0/1 are the interval ends, field 4 is the weight (-1 * the score
    rounded to two decimals) and field 5 the tag:
        "hp"  from matrix_stems_mwis (only entries with values[3] < 0.0)
        "pk"  from pk_recursive_dic  (13 fields, the pseudoknot key parts
              are appended after the tag)
        "ib"  from bulge_internal_dic
        "ml"  from multiloops
        "khp" from best_khps

    For every "hp" candidate that encloses at least one "pk"/"khp"
    element, an inner MWIS is solved, its weight is added to the hairpin
    and the chosen inner elements are stored in mwis_dic under
    (left, right, field 2).  A final MWIS over all candidates gives the
    result, which is split by tag.

    Returns (mwis_dic, crossing_structures, secondary_structures); the two
    latter dicts map the selected element tuple to its weight.  All three
    are empty when there are no candidates.
    """
    crossing_structures, secondary_structures, mwis_dic = {}, {}, {}
    candidates = []     # not named candidate_list: that is a module-level function

    # .items() works on Python 2 and 3 (iteritems() exists on Python 2 only).
    for stem, values in matrix_stems_mwis.items():
        if values[3] < 0.0:
            candidates.append((stem[0], stem[1], values[0], values[1],
                               -1*round(values[3], 2), "hp"))
    for pk_stem, pk_energy in pk_recursive_dic.items():
        candidates.append((pk_stem[0], pk_stem[1], pk_stem[4], pk_stem[7],
                           -1*round(pk_energy[0], 2), "pk",
                           pk_stem[2], pk_stem[3], pk_stem[4], pk_stem[5],
                           pk_stem[6], pk_stem[7], pk_stem[8]))
    for stem, values in bulge_internal_dic.items():
        candidates.append((stem[0], stem[1], values[0], values[1],
                           -1*round(values[2], 2), "ib"))
    for stem, values in multiloops.items():
        candidates.append((stem[0], stem[1], values[0], values[1],
                           -1*round(values[2], 2), "ml"))
    for stem, values in best_khps.items():
        candidates.append((stem[0], stem[1], values[1], 0.0,
                           -1*round(values[1], 2), "khp"))

    if not candidates:
        return mwis_dic, crossing_structures, secondary_structures

    # NOTE(review): plain tuple ordering is used.  If two candidates tie on
    # the first three fields and field 3 holds different types (it is taken
    # straight from values[1]), Python 3 raises TypeError here -- verify the
    # types stored in the input dictionaries.
    candidates.sort()

    # Scan sorted endpoints list; act whenever a hairpin is closed
    for endpoint in create_sorted_endpointlist(candidates):
        if endpoint[1] != 'r':
            continue
        index = endpoint[3]
        if candidates[index-1][5] != 'hp':
            continue
        nested, only_hp_ib_ml = find_nested(endpoint, candidates)
        if not nested or only_hp_ib_ml:
            continue
        # MWIS on the set of nested structure elements
        inner = MWIS(nested, create_sorted_endpointlist(nested))
        interval = candidates[index-1]
        energy = interval[4]
        for element in inner:       # Free energy sum, in selection order
            energy = energy + element[4]
        # Store updated free energy for outer stem
        candidates[index-1] = (interval[0], interval[1], interval[2],
                               interval[3], energy, interval[5])
        # Store inner structure elements in dictionary
        mwis_dic[interval[0], interval[1], interval[2]] = inner

    # Main MWIS calculation (endpoints rebuilt to pick up updated weights)
    result = MWIS(candidates, create_sorted_endpointlist(candidates))

    # Search for detected pseudoknots and kissing hairpins
    for element in result:
        tag = element[5]
        if tag == 'khp' or tag == 'pk':
            crossing_structures[element] = element[4]
        if tag == 'hp' or tag == 'ib' or tag == 'ml':
            secondary_structures[element] = element[4]
        if tag == 'hp':     # Hairpin loop can have nested elements
            crossing_structures, secondary_structures = print_recursive(
                element, mwis_dic, crossing_structures, secondary_structures)
    return mwis_dic, crossing_structures, secondary_structures

# Task : Given the MWIS result, look for pseudoknots and kissing hairpins recursively.
# Pseudoknots and kissing hairpins can be nested in hairpin loops.
def print_recursive(element, mwis_dic, crossing_structures, secondary_structures):
    """Collect the elements stored below a hairpin, descending recursively.

    The key (element[0], element[1], element[2]) is looked up in mwis_dic.
    Each stored inner element is filed by its tag (field 5): 'khp'/'pk'
    into crossing_structures, 'hp'/'ib'/'ml' into secondary_structures,
    using its weight (field 4) as value.  Inner hairpins that have an
    mwis_dic entry of their own are processed the same way.

    Both dictionaries are updated in place and also returned as
    (crossing_structures, secondary_structures).
    """
    stem_key = (element[0], element[1], element[2])
    if find_in_dic(mwis_dic, stem_key) == 0.0:
        # Nothing recorded inside this stem
        return crossing_structures, secondary_structures

    for inner in mwis_dic[stem_key]:
        tag = inner[5]
        if tag in ('khp', 'pk'):
            crossing_structures[inner] = inner[4]
        elif tag in ('hp', 'ib', 'ml'):
            secondary_structures[inner] = inner[4]
            if tag == 'hp':
                inner_key = (inner[0], inner[1], inner[2])
                if find_in_dic(mwis_dic, inner_key) != 0.0:
                    crossing_structures, secondary_structures = print_recursive(
                        inner_key, mwis_dic, crossing_structures, secondary_structures)
    return crossing_structures, secondary_structures

def recursive_elements(recursive_loop, pk_structure, i, bulge_internal_dic, multiloops):
    """Write the nested elements of one loop into a dot-bracket string.

    recursive_loop -- iterable of element tuples; fields 0/1 are the
                      sequence positions of the element, field 2 the stem
                      length and field 6 the tag ('hp', 'ib' or 'ml').
    pk_structure   -- structure string to update.
    i              -- sequence position that corresponds to index 0 of
                      pk_structure; subtracted from the element positions.
    bulge_internal_dic, multiloops -- dicts keyed by (left, right) whose
                      value[1] is the structure string of the element.

    'hp' elements get field-2 many '(' written inwards from the left end
    and as many ')' written inwards from the right end.  For 'ib' and 'ml'
    the stored structure (with every ':' removed) replaces the span from
    the left to the right end.  Other tags are ignored.

    Returns the updated structure string.
    """
    for element in recursive_loop:
        start = element[0] - i
        end = element[1] - i
        tag = element[6]
        if tag == 'hp':
            # range() instead of the Python 2-only xrange()
            for _ in range(element[2]):
                pk_structure = pk_structure[0:start] + '(' + pk_structure[start+1:]
                pk_structure = pk_structure[0:end] + ')' + pk_structure[end+1:]
                start = start + 1
                end = end - 1
        elif tag == 'ib' or tag == 'ml':
            source = bulge_internal_dic if tag == 'ib' else multiloops
            structure = source[element[0], element[1]][1]
            structure = structure.replace(':', '')     # Cut off dangling ends
            pk_structure = pk_structure[0:start] + structure + pk_structure[end+1:]
    return pk_structure

# Task: Print structures for pseudoknots with both regular and interrupted stems
def _apply_recursive_loops(pk_structure, loops, offset, bulge_internal_dic, multiloops):
    # Overlay the nested elements of every non-empty loop onto the structure.
    for loop in loops:
        if loop:
            pk_structure = recursive_elements(loop, pk_structure, offset, bulge_internal_dic, multiloops)
    return pk_structure


def _overwrite_run(structure, position, symbol, count, step):
    # Overwrite `count` single characters with `symbol`, moving by `step` each time.
    for _ in range(count):
        structure = structure[0:position] + symbol + structure[position+1:]
        position = position + step
    return structure


def khp_structures(seq, crossing_structures, best_khps, matrix_stems, stems_shortened_dic, bulge_internal_dic, multiloops, pk_recursive_dic, pk_dic_ib):
    """Build sequence and dot-bracket structure for every crossing element.

    crossing_structures -- dict whose keys are the selected element tuples
                           (field 5 is the tag 'khp' or 'pk').
    best_khps           -- dict whose values are
                           (kissing_hairpin, energy, loop1, ..., loop5) with
                           kissing_hairpin = ((i, j), len1, (k, l), len2,
                           (m, n), len3).
    pk_recursive_dic    -- looked up for pseudoknots with marker 'r'
                           (regular stems); value[1..3] are the loops.
    pk_dic_ib           -- same layout, used for every other marker
                           ('iS1' / 'iS2': stem S1 / S2 is interrupted).
    bulge_internal_dic, multiloops -- structures of interrupted stems and
                           of nested elements.
    matrix_stems, stems_shortened_dic -- accepted but not used here.

    Elements are processed in sorted order.  Returns a list of
        [pk[0], pk[1], -1 * float(pk[4]), tag, sequence, structure]
    entries.

    Raises ValueError when a 'khp' element has no entry in best_khps with
    the same outer positions and energy.
    """
    pseudoknot_list = []
    for pk in sorted(set(crossing_structures)):
        if pk[5] == 'khp':
            i, n, energy = pk[0], pk[1], pk[2]
            match = None
            for khp_value in best_khps.values():
                hairpin = khp_value[0]
                if hairpin[0][0] == i and hairpin[4][1] == n and khp_value[1] == energy:
                    match = khp_value
            if match is None:
                # Without this check the data of a previously processed
                # kissing hairpin would be reused (or an unbound local hit).
                raise ValueError("no kissing hairpin in best_khps for interval "
                                 "(%s, %s) with energy %s" % (i, n, energy))
            kissing_hairpin = match[0]
            i, j, stemlength1 = kissing_hairpin[0][0], kissing_hairpin[0][1], kissing_hairpin[1]
            k, l, stemlength2 = kissing_hairpin[2][0], kissing_hairpin[2][1], kissing_hairpin[3]
            m, n, stemlength3 = kissing_hairpin[4][0], kissing_hairpin[4][1], kissing_hairpin[5]
            # The recursive structure elements of loops L1..L5
            recursive_loops = (match[2], match[3], match[4], match[5], match[6])

            # Loop lengths between the three stems
            l1 = k - (i + stemlength1)
            l2 = (j - stemlength1) - (k + stemlength2) + 1
            l3 = (m - 1) - j
            l4 = (l - stemlength2) - (m + stemlength3) + 1
            l5 = (n - stemlength3) - l

            pk_seq = seq[int(i-1):int(n)]
            # Core structure: S1 '(', S2 '[', S1 ')', S3 '(', S2 ']', S3 ')'
            pk_structure = ('(' * stemlength1 + '.' * l1 +
                            '[' * stemlength2 + '.' * l2 +
                            ')' * stemlength1 + '.' * l3 +
                            '(' * stemlength3 + '.' * l4 +
                            ']' * stemlength2 + '.' * l5 +
                            ')' * stemlength3)

            # Now add recursive structure elements
            pk_structure = _apply_recursive_loops(pk_structure, recursive_loops, i,
                                                  bulge_internal_dic, multiloops)
            pseudoknot_list.append([pk[0], pk[1], -1*float(pk[4]), pk[5], pk_seq, pk_structure])

        elif pk[5] == 'pk':
            i, j, stemlength1 = pk[6], pk[7], pk[8]
            k, l, stemlength2 = pk[9], pk[10], pk[11]
            marker = pk[12]
            key = i, l, i, j, stemlength1, k, l, stemlength2, marker
            # Regular pseudoknots and those with an interrupted stem live in
            # different dictionaries
            entry = pk_recursive_dic[key] if marker == 'r' else pk_dic_ib[key]
            recursive_loops = (entry[1], entry[2], entry[3])

            looplength1 = k - (i + stemlength1)
            looplength2 = (j - stemlength1 + 1) - (k + stemlength2)
            looplength3 = (l - stemlength2) - j
            pk_seq = seq[int(i-1):int(l)]

            if marker == 'r':
                # Core pseudoknot structure with regular stems
                pk_structure = ('(' * stemlength1 + '.' * looplength1 +
                                '[' * stemlength2 + '.' * looplength2 +
                                ')' * stemlength1 + '.' * looplength3 +
                                ']' * stemlength2)
            elif marker == 'iS1':
                # Case 1: stem S1 is interrupted; take its stored structure
                structure_stem1 = find_in_dic(bulge_internal_dic, (i, j))[1]
                structure_stem1 = structure_stem1.replace(':', '')  # Delete dangling ends
                # Write the 5' side of stem S2 into it, starting at k
                pk_structure = _overwrite_run(structure_stem1, k - i, '[', stemlength2, 1)
                pk_structure = pk_structure + '.' * looplength3 + ']' * stemlength2
            elif marker == 'iS2':
                # Case 2: stem S2 is interrupted, change brackets '(' to '[' and ')' to ']'
                structure_stem2 = find_in_dic(bulge_internal_dic, (k, l))[1]
                structure_stem2 = structure_stem2.replace(':', '')  # Delete dangling ends
                structure_stem2 = structure_stem2.replace('(', '[').replace(')', ']')
                pk_structure = '(' * stemlength1 + '.' * looplength1 + structure_stem2
                # Write the 3' side of stem S1 into it, ending at j
                pk_structure = _overwrite_run(pk_structure, j - i, ')', stemlength1, -1)
            else:
                pk_structure = ''

            # Now add recursive structure elements
            pk_structure = _apply_recursive_loops(pk_structure, recursive_loops, i,
                                                  bulge_internal_dic, multiloops)
            pseudoknot_list.append([pk[0], pk[1], -1*float(pk[4]), pk[5], pk_seq, pk_structure])

    return pseudoknot_list

