# Imports
from collections import deque

import numpy as np
import pandas as pd
from Bio.PDB import PDBParser


# RNA base pairs that may form, in both orientations: the Watson-Crick pairs
# (C-G, A-U) and the G-U wobble pair. Each allowed pair contributes a score
# of 1 to the Nussinov matrix.
PAIRS = {pair: 1 for pair in ("CG", "GC", "GU", "UG", "AU", "UA")}


# Square table of free-energy values keyed by base pair on both axes.
# NOTE(review): these look like nearest-neighbour stacking energies
# (presumably kcal/mol), with the row as the outer pair and the column as the
# pair stacked inside it. TODO: confirm the source and orientation; the table
# is not symmetric (e.g. UG/UA = -0.9 but UA/UG = -1.0).
_STACK_PAIR_LABELS = ["CG", "GC", "GU", "UG", "AU", "UA"]
MFE = pd.DataFrame(
    data=[
        [-2.4, -3.3, -2.1, -1.4, -2.1, -2.1],
        [-3.3, -3.4, -2.5, -1.5, -2.2, -2.4],
        [-2.1, -2.5, 1.3, -0.5, -1.4, -1.3],
        [-1.4, -1.5, -0.5, 0.3, -0.6, -0.9],
        [-2.1, -2.2, -1.4, -0.6, -1.1, -0.9],
        [-2.1, -2.4, -1.3, -1.0, -0.9, -1.3],
    ],
    index=_STACK_PAIR_LABELS,
    columns=_STACK_PAIR_LABELS,
)


def seq_to_matrix(text=None):
    """
    Create an all-zero square matrix with one row and one column per character of text

    Args:
        text:  Sequence whose length fixes both matrix dimensions

    Returns:
        NumPy array of zeros with shape (len(text), len(text))
    """

    n_chars = len(text)
    return np.zeros(shape=(n_chars, n_chars))


def loop_matrix_diag(matrix=None, offset=3, min_loop=0):
    """
    List the (row, column) cells of a square matrix diagonal by diagonal

    Diagonals are identified by d = column - row. The first diagonal listed
    is (offset - 1) plus (min_loop - 1 if min_loop > 0, else min_loop). Every
    later diagonal up to the top-right corner cell follows, and each diagonal
    is walked from row 0 downwards. With the default offset and min_loop >= 0,
    a cell is listed only when column - row >= max(2, min_loop + 1), so
    min_loop values 0 and 1 give the same result.

    Args:
        matrix: Square matrix; only its length is used
        offset: 1-based index of the first diagonal before the min_loop shift, default: 3
        min_loop: Extra diagonals to skip (0 and 1 are equivalent), default: 0

    Returns:
        Diagonal paths of given matrix as a list of (row, column) tuples
    """

    size = len(matrix)

    # offset is 1-based; min_loop is shifted down by one only when positive
    shift = min_loop - 1 if min_loop > 0 else min_loop
    first_diagonal = (offset - 1) + shift

    return [
        (row, row + diagonal)
        for diagonal in range(first_diagonal, size)
        for row in range(size - diagonal)
    ]


def nussinov(sequence=None, min_loop_size=0):
    """
    Fill the Nussinov dynamic-programming matrix for an RNA sequence

    Cell (i, j) ends up holding the best total PAIRS score of a nested set of
    base pairs within sequence[i..j]. Cells are visited in the diagonal order
    returned by loop_matrix_diag, so each cell it reads on the diagonals
    below is already final. Cells on diagonals that loop_matrix_diag skips
    keep their initial 0, which is how the minimum loop size is enforced.

    Args:
        sequence: RNA sequence (bases must match the PAIRS keys, e.g. uppercase)
        min_loop_size: Minimum loop size, forwarded to loop_matrix_diag as min_loop

    Returns:
        Square score matrix (from seq_to_matrix); pass it with the same
        sequence to backtrack() to recover the structure
    """

    scores = seq_to_matrix(sequence)

    for i, j in loop_matrix_diag(scores, min_loop=min_loop_size):
        # Leave base j unpaired / leave base i unpaired
        skip_j = scores[i, j - 1]
        skip_i = scores[i + 1, j]

        # Pair i with j around the inner interval, if the bases can pair.
        # NOTE(review): scoring uses the flat PAIRS bonus; the MFE table is not used here.
        pair = sequence[i] + sequence[j]
        enclosed = scores[i + 1, j - 1]
        if pair in PAIRS:
            enclosed = enclosed + PAIRS[pair]

        # Best split into [i..k] and [k+1..j]; 0 when nothing beats it
        split = max([0, *(scores[i, k] + scores[k + 1, j] for k in range(i, j - 1))])

        scores[i, j] = max(skip_j, skip_i, enclosed, split)

    return scores


def backtrack(matrix=None, sequence=None):
    """
    Recover a dot-bracket structure from a filled Nussinov matrix

    Starting from the cell that covers the whole sequence, each visited cell
    (i, j) is explained by the first recurrence option whose value reproduces
    matrix[i, j]. The options are tried in this order:
      1. base j unpaired (continue at (i, j-1));
      2. base i unpaired (continue at (i+1, j));
      3. i paired with j (continue at (i+1, j-1));
      4. a split into (i, k) and (k+1, j).
    Each option is checked against the stored score. Accepting a pair only
    because the two bases *can* pair, or taking the first split regardless of
    its score, would yield a structure that does not achieve the matrix's
    optimum.

    Args:
        matrix: Score matrix produced by nussinov() for the same sequence
        sequence: RNA sequence the matrix was computed for

    Returns:
        Dot-bracket string of the recovered base pairs (predicted secondary structure)

    Raises:
        ValueError: if a cell cannot be explained by any recurrence option,
            i.e. the matrix does not belong to this sequence
    """

    # Cells still to explain, starting with the whole sequence (FIFO order)
    pending = deque([(0, len(sequence) - 1)])

    # Basepairs found
    basepairs = []

    while pending:
        i, j = pending.popleft()

        # Empty or single-base interval: nothing left to pair
        if i >= j:
            continue

        score = matrix[i, j]

        # Option 1: base j unpaired
        if matrix[i, j - 1] == score:
            pending.append((i, j - 1))
            continue

        # Option 2: base i unpaired
        if matrix[i + 1, j] == score:
            pending.append((i + 1, j))
            continue

        # Option 3: i pairs with j, and doing so actually produced this score
        pair = sequence[i] + sequence[j]
        if pair in PAIRS and matrix[i + 1, j - 1] + PAIRS[pair] == score:
            basepairs.append((i, j))
            pending.append((i + 1, j - 1))
            continue

        # Option 4: bifurcation, using the split point that reproduces the score
        # (k == i is option 2 and k == j - 1 is option 1, so only inner k remain)
        for k in range(i + 1, j - 1):
            if matrix[i, k] + matrix[k + 1, j] == score:
                pending.append((k + 1, j))
                pending.append((i, k))
                break
        else:
            raise ValueError(
                f"Cell ({i}, {j}) cannot be explained by the Nussinov recurrence; "
                "was the matrix computed by nussinov() for this sequence?"
            )

    # Return dot-bracket string of found basepairs (predicted secondary structure)
    return basepairs_to_string(basepairs, sequence)


def basepairs_to_string(pairs, sequence):
    """
    Render base pairs as a dot-bracket string the length of sequence

    Every position starts as "."; for each (open, close) pair, in order, the
    open position becomes "(" and the close position becomes ")".

    Args:
        pairs: Iterable of (open, close) index tuples into sequence
        sequence: RNA sequence; only its length is used

    Returns:
        Dot-bracket string with one character per base of sequence
    """

    chars = list("." * len(sequence))

    for open_pos, close_pos in pairs:
        chars[open_pos], chars[close_pos] = "(", ")"

    return "".join(chars)


if __name__ == "__main__":
    # Example RNA sequences whose secondary structure is predicted
    example_sequences = ["GUACGUGUGCGU", "AAACUUUCCCAGGG"]

    # Minimum loop size passed to nussinov()
    min_loop_bases = 0

    # Fold each sequence and print it directly above its dot-bracket structure
    for rna in example_sequences:
        scores = nussinov(rna, min_loop_bases)
        structure = backtrack(scores, rna)
        print(rna, structure, sep="\n")
