@classmethod
def from_pdb_id(cls, pdb_name: str, chain: str, dssp_dir="caretta_tmp",
                extract_all_features: bool = True, force_overwrite: bool = False) -> "Structure":
    """
    Makes a Structure object from a PDB identifier and a chain identifier.

    The structure is parsed with ``pd.parsePDB``, restricted to protein atoms of the
    requested chain, and the selection is written out with ``pd.writePDB`` so that
    feature extraction can run on a file containing only that chain.

    Parameters
    ----------
    pdb_name
        identifier handed to ``pd.parsePDB``; also used as the output name for
        ``pd.writePDB`` and as the name of the resulting Structure
    chain
        chain identifier to keep (inserted into the selection string ``chain <chain>``)
    dssp_dir
        directory passed to ``feature_extraction.get_features`` (converted with ``str``)
    extract_all_features
        if False, ``get_features`` is called with ``only_dssp=True``
    force_overwrite
        forwarded unchanged to ``feature_extraction.get_features``

    Returns
    -------
    Structure built from the alpha-carbon sequence and coordinates of the selected chain

    Raises
    ------
    ValueError
        if the structure cannot be parsed, has no protein atoms, or lacks the requested chain
    """
    # NOTE(review): `pd` is presumably ProDy, whose parsePDB/select return None on failure
    # instead of raising -- guard each step so callers get a meaningful error.
    parsed = pd.parsePDB(pdb_name)
    if parsed is None:
        raise ValueError(f"Could not parse a structure for {pdb_name!r}")
    protein = parsed.select("protein")
    if protein is None:
        raise ValueError(f"No protein atoms found in {pdb_name!r}")
    pdb = protein.select(f"chain {chain}")
    if pdb is None:
        raise ValueError(f"Chain {chain!r} not found among the protein atoms of {pdb_name!r}")

    # Persist only the selected chain; the returned path is what feature extraction reads.
    pdb_file = pd.writePDB(pdb_name, pdb)

    # Sequence and coordinates are both taken from the same alpha-carbon subset.
    alpha_atoms = pdb[helper.get_alpha_indices(pdb)]
    sequence = alpha_atoms.getSequence()
    coordinates = alpha_atoms.getCoords().astype(np.float64)

    features = feature_extraction.get_features(str(pdb_file), str(dssp_dir),
                                               only_dssp=not extract_all_features,
                                               force_overwrite=force_overwrite)
    return cls(pdb_name, pdb_file, sequence, helper.secondary_to_array(features["secondary"]), features,
               coordinates)