# Import required packages
import os
from Bio.PDB import PDBParser, PDBList
from numpy import *
import pandas as pd
from pymol import cmd
from ImmuneBuilder import ABodyBuilder2


# Lookup table translating the three-letter residue names used in PDB files
# into one-letter amino acid codes (the 20 standard amino acids only)
amino_acids = dict(
    ALA="A", ARG="R", ASN="N", ASP="D", CYS="C",
    GLU="E", GLN="Q", GLY="G", HIS="H", ILE="I",
    LEU="L", LYS="K", MET="M", PHE="F", PRO="P",
    SER="S", THR="T", TRP="W", TYR="Y", VAL="V",
)


def get_pdb_file(identifier, overwrite_files=False):
    """
    Download PDB file from the protein data bank and save it in structures folder as <identifier>.pdb

    Args:
        identifier (Str): 4-letter PDB identifier
        overwrite_files (Bool): Overwrite already downloaded structure, default=False

    Returns:
        None

    Raises:
        FileNotFoundError: If the structure could not be downloaded
    """

    structures_dir = "./structures/"
    target_file = structures_dir + identifier + ".pdb"

    # Make sure the structures folder exists
    os.makedirs(structures_dir, exist_ok=True)

    # Reuse an existing download unless we were asked to overwrite it
    if os.path.isfile(target_file):
        if not overwrite_files:
            return None
        os.remove(target_file)

    # Ask BioPDB.PDBList to get the PDB file. It returns the path of the file it
    # wrote (named after the lower-cased identifier), or None when the download failed.
    pdbl = PDBList()
    downloaded_file = pdbl.retrieve_pdb_file(identifier,
                                             file_format="pdb",
                                             pdir=structures_dir,
                                             overwrite=overwrite_files)

    if downloaded_file is None or not os.path.isfile(downloaded_file):
        raise FileNotFoundError(f"Could not download PDB entry {identifier!r} in PDB format")

    # Rename the downloaded file to <identifier>.pdb
    os.replace(downloaded_file, target_file)

    return None


def get_chain_from_pdb(pdb_file, chain_names, choosen_model=0):
    """
    Extract the amino acid sequence of one or more chains from a PDB file

    Only residues whose three-letter name is a key of ``amino_acids`` are kept.
    All other residues are skipped: waters, ligands and non-standard residues.

    Args:
        pdb_file (Str): Name and location of PDB file
        chain_names (List): List of chain name(s) to extract
        choosen_model (Int): Which model number to get chain(s) from, default=0

    Returns:
        Dictionary mapping each requested chain name that exists in the model to its
        one-letter sequence. Chains that are missing from the model are left out.
    """

    # Use the file name without directory and extension as the structure id
    pdb_id = os.path.splitext(os.path.basename(pdb_file))[0]

    # Instantiate PDBParser
    parser = PDBParser(QUIET=True)

    # Parse the file that was passed in and select the requested model
    structure = parser.get_structure(pdb_id, pdb_file)
    model = structure[choosen_model]

    # Dictionary to store found chains
    chain_sequences = {}

    for chain_name in chain_names:

        # If the chain doesn't exist, skip it but keep looking for the remaining chains
        try:
            chain = model[chain_name]
        except KeyError:
            continue

        # Translate standard amino acid residues to one-letter codes, dropping everything else
        residue_names = (residue.get_resname() for residue in chain)
        chain_sequences[chain_name] = "".join(
            amino_acids[name] for name in residue_names if name in amino_acids
        )

    return chain_sequences


if __name__ == "__main__":
    # Data set
    antibodies = ["8ahn", "8cwi", "8f6l", "8bse", "7wsl", "8ffe", "7ox3", "8dfi", "7str", "7vke", "4n1c", "4xvt",
                  "7vux", "7amr", "6mvl", "6svl", "6wzk", "5bjz", "5ggv", "3sob", "3nps", "6k64", "6kyz", "5n88",
                  "5hi4", "6sge", "7kql", "8dce", "7mzg", "5ggs", "6k3m", "5u5m", "7jum", "7nfr", "3p0y", "7jmp",
                  "7mdp", "6cbv", "6k6a", "7z0x", "6ba5", "5wk3", "4fqi", "4ydl", "3mxw", "7e9b", "6ban", "5vkd",
                  "5uea", "3l5x", "5u3d", "2vxq", "7so5", "6b0g", "7lfb", "4xmp", "7lm8", "4tsb", "7bz5", "5ob5",
                  "6vy4", "6aod", "6was", "5umn", "7kmh", "7q0g", "7neh", "5uek", "4h8w", "3d85", "6nmt", "6k0y",
                  "6phb", "4hjg", "5u5f", "6a3w", "5ucb", "4m62", "7b3o", "6dkj", "7mzi", "7n4n", "6qb3", "4nzr",
                  "7rxp", "6qfc", "7lm9", "4ioi", "2nxy", "5n7w", "7sd5", "6b0s", "6iea", "5l6y", "7eow", "5ngv",
                  "2yc1", "3u2s", "2uzi", "2xwt", "2ny2", "5w0d", "6e56", "6o39", "5u6a", "5vag", "4dtg", "4i77",
                  "7kmi", "7n3c", "3se9", "7fcq", "7lsg", "6meh", "7d6y", "3se8", "5m2j", "4al8", "4j6r", "7d5b",
                  "4xvs", "7n3d", "7rp3", "5sy8"]

    # Accepted PDBs from the list above, as {pdb_id: {"H": sequence, "L": sequence}}
    accepted_pdbs = {}

    for antibody in antibodies:

        # Download PDB files using BioPDB; an entry that cannot be downloaded is skipped
        # instead of aborting the whole run
        try:
            get_pdb_file(antibody)
        except FileNotFoundError as err:
            print(f"Skipping {antibody}: {err}")
            continue

        # Construct file name variable of save location
        filename = "./structures/" + antibody + ".pdb"
        sequences = get_chain_from_pdb(filename, ["H", "L"])

        # Accept the PDB only if BOTH a non-empty H and a non-empty L sequence were found.
        # Unfortunately, some of the PDBs in the dataset use chain names other than H and L
        # for their heavy and light chains; I omit these from the exercise
        if sequences.get("H") and sequences.get("L"):
            accepted_pdbs[antibody] = sequences

    # Create an output directory, if it does not exist
    output_location = "./predictions/"
    os.makedirs(output_location, exist_ok=True)

    # Antibodies for which a predicted structure is available
    predicted_structures = []

    # Create ABodyBuilder2 once, and only if a prediction is actually needed,
    # rather than once per antibody
    antibody_builder = None

    for antibody, sequences in accepted_pdbs.items():

        output_file = os.path.join(output_location, antibody + ".pdb")

        # If the structure has already been predicted, reuse it
        if os.path.isfile(output_file):
            predicted_structures.append(antibody)
            continue

        if antibody_builder is None:
            antibody_builder = ABodyBuilder2()

        # In case the prediction fails, move on to the next antibody
        try:
            predict_antibody = antibody_builder.predict(sequences)
        except AssertionError:
            continue

        predict_antibody.save(output_file)
        predicted_structures.append(antibody)

    # RMSD of the predicted structure superpositioned onto the crystal structure, per PDB id
    rmsd_values = {}

    for structure in predicted_structures:

        # Object names of the structures in PyMol
        crystal_name = "crystal_" + structure
        prediction_name = "prediction_" + structure

        # Crystal and prediction file paths
        crystal_pdb = "./structures/" + structure + ".pdb"
        prediction_pdb = os.path.join(output_location, structure + ".pdb")

        cmd.load(filename=crystal_pdb, object=crystal_name)
        cmd.load(filename=prediction_pdb, object=prediction_name)

        # Select the H and L chains of both structures
        select_crystal_name = crystal_name + "_chains"
        select_prediction_name = prediction_name + "_chains"
        cmd.select(select_crystal_name, crystal_name + " and chain H+L")
        cmd.select(select_prediction_name, prediction_name + " and chain H+L")

        # cmd.align(mobile, target): move the prediction onto the crystal structure.
        # Element 0 of the result is the RMSD after PyMOL's outlier-rejection refinement.
        superposition = cmd.align(select_prediction_name, select_crystal_name, object="Superposition")
        rmsd_values[structure] = superposition[0]

        # Delete everything in PyMol before we continue
        cmd.delete("*")

    # Save RMSD values to file using Pandas DataFrame
    rmsd_df = pd.DataFrame.from_dict(rmsd_values, orient="index", columns=["RMSD"])
    rmsd_df.to_csv("./protein.csv", sep=";", index_label="PDB_ID")
