import os
import sys
import math
import functions

# Hairpin loop entropy term, indexed by hairpin loop length 0..30.
# evaluate_pk_with_IB caps lengths above 30 at 30 before looking them up.
# Lengths 0-2 map to 0.0.
hairpin_dic = dict(enumerate([
    0.0, 0.0, 0.0, 4.1, 4.9, 4.4, 4.7, 5.0, 5.1, 5.2, 5.3,  # 0-10
    5.4, 5.5, 5.6, 5.7, 5.8, 5.8, 5.9, 5.9, 6.0, 6.1, 6.1,  # 11-21
    6.2, 6.2, 6.3, 6.3, 6.3, 6.4, 6.4, 6.5, 6.5,            # 22-30
]))
# Task: Energy evaluation for pseudoknots with energy parameters LongPK and one interrupted stem
def evaluate_pk_with_IB(pk_with_IB, matrix_stems, stems_shortened_dic, init, penalty, seq):
    """Evaluate H-type pseudoknot candidates in which one stem is interrupted.

    The marker ``pk[8]`` names the interrupted stem: ``'iS1'`` means S1 is
    interrupted and S2 is the regular stem; ``'iS2'`` is the reverse.

    Positional layout as used below (NOTE(review): inferred from the indexing,
    confirm against the producer of ``pk_with_IB``):

    * ``pk[2], pk[3], pk[4]`` appear to be S1 start, end and length.
      ``pk[5], pk[6], pk[7]`` are the same for S2.
    * ``values[0]``: stem energy used when the regular stem is a full-length
      entry of ``matrix_stems``.
    * ``values[1]``: energy to which a shortened regular stem's stacking energy
      is added. It is also an upper bound the result must beat.
    * ``values[2:5]``: loop lengths L1, L2, L3.
    * ``values[5]``: hairpin loop length of the interrupted stem.
    * ``values[6]``, ``values[7]``: ``left``/``right`` offsets.
    * ``values[8]``: a second upper bound the result must beat.
    * ``values[9]``, ``values[10]``: compared against 10 for the long-stem
      loop-length rules.

    Energy = stem energy - hairpin entropy + penalty * (L1 + L2 + L3) + init
             + 0.75 * coaxial stacking (only when L2 == 0).

    A regular stem that is filtered, not found, or has an unknown marker is
    scored with the rejection value 100.0. Such a candidate can never be kept;
    previously it reused the stem energy of the preceding candidate or raised
    NameError.

    Returns:
        dict mapping ``pk`` to ``(pk_energy, l1, l2, l3, energy_stems, left,
        right, marker)`` for candidates whose energy is negative and lower than
        both ``values[1]`` and ``values[8]``.
    """
    pk_dic_ib = {}
    for pk, values in pk_with_IB.items():
        marker = pk[8]
        # Hairpin loop entropy of the interrupted stem s_ib (subtracted from the
        # stem energy). The table only covers lengths up to 30, so longer loops
        # are capped there.
        entropy = hairpin_dic[min(values[5], 30)]
        if marker == 'iS1':
            # S1 is interrupted, S2 is the regular stem.
            energy_stems = _regular_stem_energy(
                values, pk[5], pk[6], pk[7], matrix_stems, stems_shortened_dic, entropy)
        elif marker == 'iS2':
            # S2 is interrupted, S1 is the regular stem.
            energy_stems = _regular_stem_energy(
                values, pk[2], pk[3], pk[4], matrix_stems, stems_shortened_dic, entropy)
        else:
            # Unknown marker: reject rather than reuse a previous candidate's value.
            energy_stems = 100.0

        l1, l2, l3 = values[2], values[3], values[4]
        pk_energy = energy_stems + penalty * (l1 + l2 + l3) + init
        left, right = values[6], values[7]

        # Coaxial stacking is only considered when L2 == 0, e.g.
        #   ACGGaUUGUguCCGUAAUcACA
        #   (((((.[[[[)))))...]]]]   -> stack AU|GC = -2.10
        cs = _coaxial_stack_energy(pk, marker, left, right, seq) if l2 == 0 else 0.0

        # Loop-length rules when values[9] or values[10] >= 10:
        #   interrupted S1 requires L3 >= 5 nt; interrupted S2 requires L1 >= 2 nt.
        # NOTE(review): an earlier comment stated "L3 >= 6 nt" and "> 10 bp"; the
        # thresholds actually enforced by the code (l3 < 5, >= 10) are kept.
        if marker == 'iS1' and (values[9] >= 10 or values[10] >= 10) and l3 < 5:
            pk_energy = 100.0
        if marker == 'iS2' and (values[9] >= 10 or values[10] >= 10) and l1 < 2:
            pk_energy = 100.0

        pk_energy = pk_energy + 0.75 * cs

        if pk_energy < 0.0 and pk_energy < values[1] and pk_energy < values[8]:
            # Key has same format as pseudoknots with regular stems.
            pk_dic_ib[pk] = pk_energy, l1, l2, l3, energy_stems, left, right, marker
    return pk_dic_ib


def _regular_stem_energy(values, start, end, length, matrix_stems, stems_shortened_dic, entropy):
    """Return the stem energy term for the non-interrupted stem (start, end, length).

    * If ``(start, end, length)`` is a shortened stem, return
      ``values[1]`` + its stacking energy - ``entropy``.
    * If ``(start, end)`` is in ``matrix_stems`` with exactly ``length`` bp,
      return ``values[0] - entropy``.
    * Otherwise (the stem was shortened and then filtered, or it is unknown),
      return the rejection value 100.0.
    """
    shortened_key = (start, end, length)
    if functions.find_in_dic(stems_shortened_dic, shortened_key) != 0.0:
        return values[1] + stems_shortened_dic[shortened_key][2] - entropy
    stem = functions.find_in_dic(matrix_stems, (start, end))
    if stem == 0.0:
        return 100.0  # not a known stem
    if length == stem[0]:
        return values[0] - entropy
    return 100.0  # shortened stem that was filtered out


def _coaxial_stack_energy(pk, marker, left, right, seq):
    """Look up the coaxial stacking energy in ``stack_dic`` (0.0 for an unknown marker).

    Indices into ``seq`` follow the original 1-based arithmetic.

    'iS2' example (S1 regular):
        AACcUUCACCAAUUagGUUCAAAuAAGUGGU
        ((((:::[[[[.[[[))))::::]]].]]]]
    'iS1' example (S2 regular):
        AACCUUCcCCAAUUagGUUCAAAuAAGuGGU
        [[[[.[[[...((((]]].]]]]....))))
    """
    if marker == 'iS2':
        pair1 = str(seq[((pk[2] - 1) + pk[4] - 1)]) + str(seq[pk[3] - pk[4]])
        pair2 = str(seq[pk[5] + left - 1]) + str(seq[pk[5] + right - 1])
    elif marker == 'iS1':
        pair1 = str(seq[pk[2] + left - 1]) + str(seq[pk[2] + right - 1])
        pair2 = str(seq[((pk[5] - 1) + pk[7] - 1)]) + str(seq[pk[6] - pk[7]])
    else:
        return 0.0
    return functions.find_in_dic(stack_dic, pair1 + "|" + pair2)

# Coaxial stacking energies for two adjacent base pairs, keyed "P1|P2", where
# each pair is written as its two nucleotides (e.g. "AU|GC").
# evaluate_pk_with_IB looks these up when L2 == 0.
# Rows are grouped by the first pair.
stack_dic = {
    first + "|" + second: energy
    for first, row in (
        ("AU", (("AU", -0.9), ("CG", -2.2), ("GC", -2.1), ("GU", -0.6), ("UA", -1.1), ("UG", -1.4))),
        ("CG", (("AU", -2.1), ("CG", -3.3), ("GC", -2.4), ("GU", -1.4), ("UG", -2.1), ("UA", -2.1))),
        ("GC", (("AU", -2.4), ("CG", -3.4), ("GC", -3.3), ("GU", -1.5), ("UA", -2.2), ("UG", -2.5))),
        ("GU", (("AU", -1.3), ("CG", -2.5), ("GC", -2.1), ("GU", -0.5), ("UA", -1.4), ("UG", 1.3))),
        ("UA", (("AU", -1.3), ("CG", -2.4), ("GC", -2.1), ("GU", -1.0), ("UA", -0.9), ("UG", -1.3))),
        ("UG", (("AU", -1.0), ("CG", -1.5), ("GC", -1.4), ("GU", 0.3), ("UA", -0.6), ("UG", -0.5))),
    )
    for second, energy in row
}

# Task: Construct recursive H-type pseudoknots with MWIS calculation and one interrupted stem
def recursive_pk(matrix_stems, bulge_internal_dic, multiloops, pk_dic_ib):
    """Attach the best set of nested elements found inside each pseudoknot loop.

    ``pk_dic_ib`` maps each pseudoknot key to ``(pk_energy, l1, l2, l3,
    energy_stems, left, right, marker)``, as produced by
    ``evaluate_pk_with_IB``.

    The three loop regions L1, L2 and L3 are derived from the key coordinates
    and the marker ``pk[8]``. When at least one loop is 9 nt or longer, the
    following are collected inside each loop:

    * hairpin stems (``matrix_stems``),
    * bulge/internal-loop elements (``bulge_internal_dic``),
    * multiloops (``multiloops``).

    ``recursive_pk_mwis`` then selects a maximum weight independent set for each
    loop. Candidates with non-positive energy are preferred; positive-energy
    stems are used only when no such candidate exists.

    Every value is replaced in place by ``(values_pk, result1, result2,
    result3)``. ``resultN`` is the selection for loop N, or ``[]`` if nothing
    qualified. Pseudoknots with an unknown marker get no nested elements. The
    same dict is returned.
    """
    for pk, values_pk in pk_dic_ib.items():
        l1, l2, l3 = values_pk[1], values_pk[2], values_pk[3]
        left, right = values_pk[5], values_pk[6]
        loops = _loop_regions(pk, pk[8], left, right)

        results = [[], [], []]
        if loops is not None and (l1 >= 9 or l2 >= 9 or l3 >= 9):
            for index, (start, end) in enumerate(loops):
                # L1 and L3 reject elements spanning the entire loop (steric
                # safeguard). L2 accepts them.
                negative, positive = _loop_candidates(
                    start, end, matrix_stems, bulge_internal_dic, multiloops,
                    strict=(index != 1))
                if negative:
                    results[index] = recursive_pk_mwis(pk, negative)
                elif positive:
                    results[index] = recursive_pk_mwis(pk, positive)
        pk_dic_ib[pk] = values_pk, results[0], results[1], results[2]
    return pk_dic_ib


def _loop_regions(pk, marker, left, right):
    """Return ``((start, end) of L1, of L2, of L3)``, or None for an unknown marker."""
    if marker == 'iS1':
        # First case, combine s_ib with normal stem s:
        # (((...((((.xxx...))))...)))........xxx
        return ((pk[0] + left + 1, pk[5] - 1),
                (pk[5] + pk[7], pk[0] + right - 1),
                (pk[3] + 1, pk[6] - pk[7]))
    if marker == 'iS2':
        # Second case, combine normal stem s with s_ib:
        # xxx........(((...((((.xxx...))))...)))
        return ((pk[2] + pk[4], pk[5] - 1),
                (pk[5] + left + 1, pk[3] - pk[4]),
                (pk[3] + 1, pk[0] + right - 1))
    return None


def _loop_candidates(start, end, matrix_stems, bulge_internal_dic, multiloops, strict):
    """Collect MWIS candidates lying inside the loop region ``start..end``.

    Returns ``(negative, positive)``, two lists of tuples of the form
    ``(i, j, v0, v1, weight, energy, tag)``:

    * ``negative`` holds hairpin stems ("hp") with energy <= 0 plus every
      bulge/internal ("ib") and multiloop ("ml") element, weighted by -energy.
    * ``positive`` holds hairpin stems with energy > 0, weighted by 1 / energy.

    With ``strict`` set, an element spanning exactly the whole region is
    rejected. Each element is tested with its own coordinates.
    """
    def fits(i, j):
        if i < start or j > end:
            return False
        return not strict or i > start or j < end

    negative, positive = [], []
    for stem, values in matrix_stems.items():
        if not fits(stem[0], stem[1]):
            continue
        energy = values[3]
        if energy <= 0.0:
            negative.append((stem[0], stem[1], values[0], values[1],
                             -1 * round(energy, 2), energy, "hp"))
        else:
            # NOTE(review): an energy in (0, 0.005) rounds to 0.0 and would
            # raise ZeroDivisionError here.
            positive.append((stem[0], stem[1], values[0], values[1],
                             1.0 / round(energy, 2), energy, "hp"))
    for tag, elements in (("ib", bulge_internal_dic), ("ml", multiloops)):
        for key, values in elements.items():
            if fits(key[0], key[1]):
                negative.append((key[0], key[1], values[0], 0.0,
                                 -1 * round(values[2], 2), values[2], tag))
    return negative, positive

# Task: MWIS calculation for recursive H-type pseudoknot construction
def recursive_pk_mwis(pk, candidate_list):
    """Select a maximum weight independent set from ``candidate_list``.

    ``candidate_list`` is sorted in place. The endpoint list produced by
    ``functions.create_sorted_endpointlist`` is then passed together with it to
    ``functions.MWIS``, whose result is returned unchanged. ``pk`` is not used.
    """
    candidate_list.sort()
    return functions.MWIS(candidate_list,
                          functions.create_sorted_endpointlist(candidate_list))

# Task: Energy re-evaluation for pseudoknots with one interrupted stem
def re_evaluate_pk_with_IB(pk_dic_ib, init, penalty):
    """Re-score pseudoknots after nested loop elements have been attached.

    ``pk_dic_ib`` maps each pseudoknot to ``(values_pk, list1, list2, list3)``,
    where:

    * ``values_pk[0]`` is the current energy,
    * ``values_pk[1:4]`` are the loop lengths L1-L3,
    * ``values_pk[4]`` is the stem energy,
    * ``listN`` holds the elements selected inside loop N. Each element is read
      as ``element[0]``/``element[1]`` (span ends) and ``element[5]`` (free
      energy).

    When at least one list is non-empty, the energy is recomputed as

        stems + penalty * (eff_L1 + eff_L2 + eff_L3) + sum of element energies + init

    Each element shortens the effective length of its loop by its span
    ``end - start + 1`` and adds back one. The recomputed value replaces the old
    one only if it is lower; otherwise the entry becomes
    ``(values_pk[0], [], [], [])``. Entries without elements become
    ``(values_pk[0], list1, list2, list3)``.

    NOTE(review): the recomputed energy omits the coaxial-stacking term that
    evaluate_pk_with_IB included in ``values_pk[0]``. Confirm this is intended.

    The dict is updated in place and returned.
    """
    for pk_stem, values in pk_dic_ib.items():
        base = values[0]
        nested = (values[1], values[2], values[3])
        if not any(nested):
            pk_dic_ib[pk_stem] = base[0], values[1], values[2], values[3]
            continue

        effective_lengths, loop_energies = [], []
        for loop_length, elements in zip((base[1], base[2], base[3]), nested):
            remaining = loop_length
            free_energy = 0.0
            for element in elements or ():
                remaining -= element[1] - element[0] + 1
                free_energy += element[5]
                remaining += 1  # add one per nested helix
            effective_lengths.append(remaining)
            loop_energies.append(free_energy)

        loop_entropy = penalty * (effective_lengths[0] + effective_lengths[1] + effective_lengths[2])
        candidate = (base[4] + loop_entropy
                     + loop_energies[0] + loop_energies[1] + loop_energies[2]
                     + init)
        # Keep nested-element energies only when they stabilise the pseudoknot.
        if candidate < base[0]:
            pk_dic_ib[pk_stem] = candidate, values[1], values[2], values[3]
        else:
            pk_dic_ib[pk_stem] = base[0], [], [], []

    return pk_dic_ib
