Commit e5196843 authored by Akdel's avatar Akdel
Browse files

debug (removed njit)

parent 12cdbd70
......@@ -13,19 +13,19 @@ from caretta import psa_numba as psa
from caretta import rmsd_calculations, helper
@nb.njit
# @nb.njit
def get_common_coordinates(coords_1, coords_2, aln_1, aln_2, gap=-1):
assert aln_1.shape == aln_2.shape
pos_1, pos_2 = helper.get_common_positions(aln_1, aln_2, gap)
return coords_1[pos_1], coords_2[pos_2]
@nb.njit(parallel=True)
# @nb.njit
def make_pairwise_dtw_score_matrix(coords_array, secondary_array, lengths_array, gamma,
gap_open_penalty: float, gap_extend_penalty: float,
gap_open_sec, gap_extend_sec):
pairwise_matrix = np.zeros((coords_array.shape[0], coords_array.shape[0]))
for i in nb.prange(pairwise_matrix.shape[0] - 1):
for i in range(pairwise_matrix.shape[0] - 1):
for j in range(i + 1, pairwise_matrix.shape[1]):
dtw_aln_1, dtw_aln_2, score = psa.get_pairwise_alignment(coords_array[i, :lengths_array[i]], coords_array[j, :lengths_array[j]],
secondary_array[i, :lengths_array[i]], secondary_array[j, :lengths_array[j]],
......@@ -45,7 +45,7 @@ def make_pairwise_dtw_score_matrix(coords_array, secondary_array, lengths_array,
return pairwise_matrix
@nb.njit(parallel=True)
# @nb.njit(parallel=True)
def make_pairwise_rmsd_score_matrix(coords_array, secondary_array, lengths_array, gamma,
gap_open_penalty: float, gap_extend_penalty: float,
gap_open_sec, gap_extend_sec):
......@@ -70,7 +70,7 @@ def make_pairwise_rmsd_score_matrix(coords_array, secondary_array, lengths_array
return pairwise_matrix
@nb.njit
# @nb.njit
def _get_alignment_data(coords_1, coords_2, secondary_1, secondary_2, gamma,
gap_open_sec, gap_extend_sec,
gap_open_penalty: float, gap_extend_penalty: float):
......@@ -92,7 +92,7 @@ def _get_alignment_data(coords_1, coords_2, secondary_1, secondary_2, gamma,
return aln_coords_1, aln_coords_2, aln_sec_1, aln_sec_2, dtw_aln_1, dtw_aln_2
@nb.njit
# @nb.njit
def get_mean_coords_extra(aln_coords_1: np.ndarray, aln_coords_2: np.ndarray) -> np.ndarray:
"""
Mean of two coordinate sets (of the same shape)
......@@ -116,7 +116,7 @@ def get_mean_coords_extra(aln_coords_1: np.ndarray, aln_coords_2: np.ndarray) ->
return mean_coords
@nb.njit
# @nb.njit
def get_mean_secondary(aln_sec_1: np.ndarray, aln_sec_2: np.ndarray, gap=0) -> np.ndarray:
"""
Mean of two coordinate sets (of the same shape)
......@@ -165,6 +165,22 @@ class Structure:
features = feature_extraction.get_features(str(pdb_file), str(dssp_dir), only_dssp=only_dssp, force_overwrite=force_overwrite)
return cls(pdb_name, pdb_file, sequence, helper.secondary_to_array(features["secondary"]), features, coordinates)
@classmethod
def mutate_from_pdb_file(cls, pdb_file: typing.Union[str, Path], dssp_dir="caretta_tmp",
extract_all_features=True, force_overwrite=False, randomness=0.1):
pdb_name = helper.get_file_parts(pdb_file)[1]
pdb = pd.parsePDB(str(pdb_file)).select("protein")
alpha_indices = helper.get_alpha_indices(pdb)
sequence = pdb[alpha_indices].getSequence()
coordinates = pdb[alpha_indices].getCoords().astype(np.float64)
coordinates += np.random.normal(coordinates.mean(), randomness, coordinates.shape)
only_dssp = (not extract_all_features)
features = feature_extraction.get_features(str(pdb_file), str(dssp_dir), only_dssp=only_dssp,
force_overwrite=force_overwrite)
return cls(pdb_name, pdb_file, sequence, helper.secondary_to_array(features["secondary"]), features,
coordinates)
@classmethod
def from_pdb_id(cls, pdb_name: str, chain: str, dssp_dir="caretta_tmp",
extract_all_features=True, force_overwrite=False):
......@@ -242,10 +258,12 @@ class StructureMultiple:
output_feature_filename=Path("./result_features.pkl"),
output_class_filename=Path("./result_class.pkl")):
lengths_array = np.array([len(s.sequence) for s in structures])
# print(lengths_array)
max_length = np.max(lengths_array)
coords_array = np.zeros((len(structures), max_length, structures[0].coords.shape[1]))
secondary_array = np.zeros((len(structures), max_length))
for i in range(len(structures)):
# print(i, structures[i].secondary.shape)
coords_array[i, :lengths_array[i]] = structures[i].coords
secondary_array[i, :lengths_array[i]] = structures[i].secondary
if not Path(output_pdb_folder).exists():
......@@ -295,6 +313,61 @@ class StructureMultiple:
output_class_filename)
return msa_class
@classmethod
def mutate_multiple_from_single_pdb_file(cls, input_pdb, filename,
dssp_dir=Path("./caretta_tmp/"), num_threads=20,
extract_all_features=False,
consensus_weight=1.,
output_fasta_filename=Path("./result.fasta"),
output_pdb_folder=Path("./result_pdb/"),
output_feature_filename=Path("./result_features.pkl"),
output_class_filename=Path("./result_class.pkl"),
overwrite_dssp=False, number=10, noise_level=0.1):
from shutil import copyfile
if not Path(output_pdb_folder).exists():
Path(output_pdb_folder).mkdir()
cleaned_pdb_folder = Path(helper.get_file_parts(output_pdb_folder)[0]) / "cleaned_pdb"
if not cleaned_pdb_folder.exists():
cleaned_pdb_folder.mkdir()
for i in range(number):
copyfile(filename, str(cleaned_pdb_folder / f"n_{i}.pdb"))
if not Path(dssp_dir).exists():
Path(dssp_dir).mkdir()
pdbs = [pd.parsePDB(cleaned_pdb_folder / f"n_{i}.pdb").select("protein") for i in range(number)]
alpha_indices = [helper.get_alpha_indices(pdb) for pdb in pdbs]
sequences = [pdbs[i][alpha_indices[i]].getSequence() for i in range(len(pdbs))]
coordinates = [np.hstack((pdbs[i][alpha_indices[i]].getCoords().astype(np.float64) +
np.random.normal(0,
np.mean(pdbs[i][alpha_indices[i]].getCoords().astype(
np.float64)) * noise_level,
pdbs[i][alpha_indices[i]].getCoords().astype(np.float64).shape),
np.zeros((len(sequences[i]), 1)) + consensus_weight))
for i in range(len(pdbs))]
only_dssp = (not extract_all_features)
print("Extracting features...")
features = feature_extraction.get_features_multiple([cleaned_pdb_folder / f"n_{i}.pdb" for i in range(number)],
str(dssp_dir),
num_threads=num_threads, only_dssp=only_dssp,
force_overwrite=True)
structures = []
name = "a"
for i in range(len(pdbs)):
name += name
pdb_name = str(name) # helper.get_file_parts(pdb_files[i])[1]
# print(len(sequences[i]))
# print(len(features[i]["secondary"]))
structures.append(Structure(pdb_name,
None,
sequences[i],
helper.secondary_to_array(features[i]["secondary"]),
features[i],
coordinates[i]))
msa_class = StructureMultiple.from_structures(structures, output_fasta_filename, output_pdb_folder,
output_feature_filename,
output_class_filename)
return msa_class
def get_pairwise_matrix(self, gap_open_penalty, gap_extend_penalty, gamma=0.03, gap_open_sec=1., gap_extend_sec=0.1):
return make_pairwise_rmsd_score_matrix(self.coords_array,
self.secondary_array,
......@@ -307,10 +380,10 @@ class StructureMultiple:
def align_from_pdb_files(input_pdb,
dssp_dir="caretta_tmp", num_threads=20, extract_all_features=True,
gap_open_penalty=1., gap_extend_penalty=0.01, consensus_weight=1.,
write_fasta=True, output_fasta_filename=Path("./result.fasta"),
write_pdb=True, output_pdb_folder=Path("./result_pdb/"),
write_features=True, output_feature_filename=Path("./result_features.pkl"),
write_class=True, output_class_filename=Path("./result_class.pkl"),
write_fasta=False, output_fasta_filename=Path("./result.fasta"),
write_pdb=False, output_pdb_folder=Path("./result_pdb/"),
write_features=False, output_feature_filename=Path("./result_features.pkl"),
write_class=False, output_class_filename=Path("./result_class.pkl"),
overwrite_dssp=False):
"""
Caretta aligns protein structures and returns a sequence alignment, a set of aligned feature matrices, superposed PDB files, and
......@@ -390,7 +463,7 @@ class StructureMultiple:
self.alignment = {self.structures[0].name: "".join([self.structures[0].sequence[i] if i != -1 else '-' for i in dtw_1]),
self.structures[1].name: "".join([self.structures[1].sequence[i] if i != -1 else '-' for i in dtw_2])}
return self.alignment
print("start pairwise")
if pw_matrix is None:
pw_matrix = make_pairwise_dtw_score_matrix(self.coords_array,
self.secondary_array,
......@@ -603,3 +676,14 @@ class StructureMultiple:
transformation = pd.Transformation(rotation_matrix.T, translation_matrix)
pdb = pd.applyTransformation(transformation, pdb)
pd.writePDB(str(output_pdb_folder / f"{name}.pdb"), pdb)
if __name__ == "__main__":
from os import listdir
from sys import argv
print(listdir("./"))
pdb_folder = "/mnt/local_scratch/akdel001/caretta/test_data/"
pdb_file = "/mnt/local_scratch/akdel001/caretta/test_data/1kdu.pdb"
al = StructureMultiple.from_pdb_files(pdb_folder)
alignment = al.align(1, 0.01)
......@@ -5,7 +5,7 @@ from caretta import dynamic_time_warping as dtw
from caretta import rmsd_calculations, helper
@nb.njit
# @nb.njit
# @numba_cc.export('make_signal_index', 'f64[:](f64[:], i64)')
def make_signal_index(coords, index):
centroid = coords[index]
......@@ -15,7 +15,7 @@ def make_signal_index(coords, index):
return distances
@nb.njit
# @nb.njit
# @numba_cc.export('dtw_signals_index', '(f64[:], f64[:], i64, i64, i64)')
def dtw_signals_index(coords_1, coords_2, index, size=30, overlap=1):
signals_1 = np.zeros(((coords_1.shape[0] - size) // overlap, size))
......@@ -46,7 +46,7 @@ def dtw_signals_index(coords_1, coords_2, index, size=30, overlap=1):
return coords_1, coords_2
@nb.njit
# @nb.njit
# @numba_cc.export('get_dtw_signal_score_pos', '(f64[:], f64[:], f64, i64, f64, f64)')
def get_dtw_signal_score_pos(coords_1, coords_2, gamma, index, gap_open_penalty, gap_extend_penalty):
coords_1[:, :3], coords_2[:, :3] = dtw_signals_index(coords_1[:, :3], coords_2[:, :3], index)
......@@ -61,7 +61,7 @@ def get_dtw_signal_score_pos(coords_1, coords_2, gamma, index, gap_open_penalty,
return rmsd_calculations.get_caretta_score(common_coords_1, common_coords_2, gamma, False), pos_1, pos_2
@nb.njit
# @nb.njit
# @numba_cc.export('get_secondary_distance_matrix', '(i8[:], i8[:], i8)')
def get_secondary_distance_matrix(secondary_1, secondary_2, gap=0):
score_matrix = np.zeros((secondary_1.shape[0], secondary_2.shape[0]))
......@@ -75,7 +75,7 @@ def get_secondary_distance_matrix(secondary_1, secondary_2, gap=0):
return score_matrix
@nb.njit
# @nb.njit
# @numba_cc.export('get_secondary_rmsd_pos', '(i8[:], i8[:], f64[:], f64[:], f64, f64, f64)')
def get_secondary_rmsd_pos(secondary_1, secondary_2, coords_1, coords_2, gamma, gap_open_sec, gap_extend_sec):
distance_matrix = get_secondary_distance_matrix(secondary_1, secondary_2)
......@@ -87,14 +87,14 @@ def get_secondary_rmsd_pos(secondary_1, secondary_2, coords_1, coords_2, gamma,
return rmsd_calculations.get_caretta_score(common_coords_1, common_coords_2, gamma, False), pos_1, pos_2
@nb.njit
# @nb.njit
# @numba_cc.export('get_pairwise_alignment', '(f64[:], f64[:], i8[:], i8[:], f64, f64, f64, f64, f64, i64)')
def get_pairwise_alignment(coords_1, coords_2,
secondary_1, secondary_2,
gamma,
gap_open_sec, gap_extend_sec,
gap_open_penalty,
gap_extend_penalty, max_iter=100):
gap_extend_penalty, max_iter=3):
rmsd_1, pos_1_1, pos_2_1 = get_dtw_signal_score_pos(coords_1, coords_2, gamma, 0, gap_open_penalty, gap_extend_penalty)
rmsd_2, pos_1_2, pos_2_2 = get_secondary_rmsd_pos(secondary_1, secondary_2, coords_1[:, :3], coords_2[:, :3], gamma, gap_open_sec, gap_extend_sec)
rmsd_3, pos_1_3, pos_2_3 = get_dtw_signal_score_pos(coords_1, coords_2, gamma, -1, gap_open_penalty, gap_extend_penalty)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment