# Import required modules
from os import rename, path
from numpy import *
from numpy.linalg import svd, det
from Bio.PDB import PDBParser, PDBList


# PART 1
# Center a matrix by its center of mass
def center_by_mass(m=False):
    """
    Center a matrix (of coordinates) by its center of mass

    Every row is treated as one coordinate axis and every column as one point.
    The mean of each row is subtracted from that row, so each row of the
    result sums to (approximately) zero. All points carry the same weight.

    :param m: Input {d,n} coordinates (numpy matrix, 2-D ndarray or nested sequence)
    :return: Centered coordinates of the same shape; a numpy matrix stays a
             numpy matrix, any other input is returned as an ndarray
    :raises ValueError: If no matrix is given
    """

    # Check that a matrix is used. False is the "not supplied" sentinel; None is
    # rejected the same way instead of failing later with an AttributeError.
    if m is False or m is None:
        raise ValueError("You didn't specify a matrix")

    # Keeps numpy matrices untouched, turns nested sequences into arrays
    coordinates = asanyarray(m)

    # Center by mass. The per-row mean is forced into a {d,1} column so that it
    # broadcasts across the columns: a numpy matrix already produces a column
    # here, but a plain ndarray produces a flat {d} vector that would be
    # subtracted along the wrong axis (or fail to broadcast at all).
    center_of_mass = reshape(coordinates.sum(1) / coordinates.shape[1], (-1, 1))
    centered = coordinates - center_of_mass

    # Return centered matrix
    return centered


# The RMSD algorithm superposition of two structures
# noinspection SpellCheckingInspection
def rmsd(structure_1, structure_2):
    """
    Calculate the RMSD of two structures after optimal superposition

    Both coordinate sets are centered on their center of mass. Structure 2 is
    then rotated onto structure 1 with the rotation derived from a singular
    value decomposition of their correlation matrix, and the root mean square
    distance between the paired points is returned. If the SVD yields a
    roto-reflection, the best proper rotation is used instead.

    :param structure_1: 3D coordinates of structure 1 ({3,n} matrix or array-like)
    :param structure_2: 3D coordinates of structure 2 ({3,n} matrix or array-like)
    :return: RMSD score
    :raises ValueError: If the structures differ in shape or contain no points
    """

    # The algebra below relies on numpy matrix semantics ('*' is the matrix
    # product, '.A' the underlying array), so coerce plain arrays/lists first
    structure_1 = asmatrix(structure_1)
    structure_2 = asmatrix(structure_2)

    # Points are paired column by column, so both sets must match exactly
    if structure_1.shape != structure_2.shape:
        raise ValueError("Structures must have the same shape, got %s and %s"
                         % (structure_1.shape, structure_2.shape))
    if structure_1.shape[1] == 0:
        raise ValueError("Structures must contain at least one point")

    # Center the structures coordinates by the center of mass
    structure_1_centered = center_by_mass(structure_1)
    structure_2_centered = center_by_mass(structure_2)

    # Calculate R (correlation matrix between the two centered structures)
    r = structure_2_centered * transpose(structure_1_centered)

    # Singular Value Decomposition (the singular values themselves are not
    # needed); numpy returns the right-hand factor already transposed
    v, _, w = svd(r)
    w = transpose(w)
    v = transpose(v)

    # Rotation matrix
    u = w * v

    # Check for reflection (roto-reflection): a negative determinant means u
    # mirrors the structure. Flipping the sign of the last (smallest) singular
    # direction gives the best proper rotation instead.
    if det(u) < 0:
        z = diag([1] * (r.shape[0] - 1) + [-1])
        u = w * z * v

    # Rotate structure 2 onto structure 1 by applying U
    structure_2_rotated = u * structure_2_centered

    # Calculate RMSD (formula): root of the summed squared deviations divided
    # by the number of points. The array method is used so the result does not
    # depend on which 'sum' the star import left in scope.
    squared_deviation = ((structure_1_centered.A - structure_2_rotated.A) ** 2).sum()
    rmsd_calc = sqrt(squared_deviation / structure_1_centered.shape[1])

    # Return the value of RMSD
    return rmsd_calc


def get_pdb_file(identifier):
    """
    Download PDB file from the protein data bank and save it as <identifier>.pdb

    Nothing is downloaded when <identifier>.pdb already exists in the current
    working directory.

    :param identifier: 4-letter identifier
    :return: None
    """

    target = identifier + ".pdb"

    # Check if file already exists
    if path.isfile(path.join(".", target)):
        return None

    # Instantiate PDBList
    pdbl = PDBList()

    # Ask BioPDB.PDBList to get PDB file and remember where it was written
    downloaded = pdbl.retrieve_pdb_file(identifier,
                                        file_format="pdb",
                                        pdir=".",
                                        overwrite=True)

    # PDBList names the file after the lower-cased identifier, so rebuilding
    # the path from the identifier as given (e.g. "1LCD") points at a file
    # that does not exist on case-sensitive file systems. Prefer the path
    # PDBList reports; fall back to the lower-cased name if none is returned.
    if not downloaded:
        downloaded = path.join(".", "pdb" + identifier.lower() + ".ent")

    # Rename the downloaded file for easier use...
    rename(downloaded, target)

    # Done
    return None


def get_ca_coordinates(model):
    """
    Collect the coordinates of every atom named "CA" in a model

    Atoms are selected by name only, in the order model.get_atoms() yields them.

    :param model: Object whose get_atoms() yields atoms offering get_name() and get_coord()
    :return: {3,n} single-precision matrix with one column per selected atom
    """

    # Keep the coordinates of the atoms whose name is exactly "CA"
    # NOTE(review): any other atom that happens to be named "CA" would be
    # picked up as well - confirm the input only contains alpha carbons
    ca_positions = [
        candidate.get_coord()
        for candidate in model.get_atoms()
        if candidate.get_name() == "CA"
    ]

    # Build the {n,3} matrix (one row per atom) and hand back its {3,n} transpose
    return matrix(ca_positions, 'f').T


# If this is the main file, run the code below
if __name__ == "__main__":
    # Part 1
    # Two hand-written {3,5} sample coordinate sets (single precision)
    sample_a = matrix(
        [[18.92238689, 9.18841188, 8.70764463, 9.38130981, 8.53057997],
         [1.12391951, 0.8707568, 1.01214183, 0.59383894, 0.65155349],
         [0.46106398, 0.62858099, -0.02625641, 0.35264203, 0.53670857]],
        'f')

    sample_b = matrix(
        [[1.68739355, 1.38774297, 2.1959675, 1.51248281, 1.70793414],
         [8.99726755, 8.73213223, 8.86804272, 8.31722197, 8.9924607],
         [1.1668153, 1.1135669, 1.02279055, 1.06534992, 0.54881902]],
        'f')

    # RMSD of the two samples after superposition
    sample_score = rmsd(sample_a, sample_b)
    print("RMSD calculated to be %.5f for sample a and b" % sample_score)

    # Part 2
    # Make sure 1LCD.pdb is present in the working directory
    get_pdb_file("1LCD")

    # Parse the file with a PDBParser created with QUIET=True
    pdb_parser = PDBParser(QUIET=True)
    entry = pdb_parser.get_structure("1LCD", "1LCD.pdb")

    # Chain "A" of the first and second model (models are 0-indexed)
    first_chain = entry[0]["A"]
    second_chain = entry[1]["A"]

    # {3,n} matrices holding the alpha carbon coordinates of each model
    first_ca_matrix = get_ca_coordinates(first_chain)
    second_ca_matrix = get_ca_coordinates(second_chain)

    # RMSD of the two models after superposition
    model_score = rmsd(first_ca_matrix, second_ca_matrix)
    print("RMSD calculated to be %.5f for the two models" % model_score)
