From 5b1631c0d2406c56e683a302cd90b8b0ce31f81a Mon Sep 17 00:00:00 2001 From: Sam Overduin <sam.overduin@wur.nl> Date: Tue, 3 Dec 2019 17:49:07 +0100 Subject: [PATCH] Improved taxonomy assignment & implemented taxonomy transition filter. barcode_index.hpp: - Initialized TaxIdEncoder with "0" taxatree - Added func TaxIdEncoder::ToTaxaTreeVector - Bugfix EdgeEntry::GetTaxonomy() barcode_info_extractor.hpp: - Added funcs: ToTaxaTreeVector, TaxaTreeFromTaxId, GetTaxaTreeFromEdge barcode_index_construction.cpp: - Moved ToTaxaTreeVector to extractor & barcode_index - Improved majority_vote_lca to actually use LCA construction_callers.cpp & hpp: - Added TaxaBreakConstructorCaller to break mismatched taxa in contracted assembly graph read_cloud_connection_conditions.cpp: - Added TaxaBreakPredicate that returns false when transition (scaffold_edge) has incompatible taxonomy scaffold_graph_construction_pipeline.cpp: - Added TaxaBreak routine to basic mode & scaffolding mode --- src/common/barcode_index/barcode_index.hpp | 19 ++++- .../barcode_index/barcode_info_extractor.hpp | 16 ++++ .../construction_callers.cpp | 24 ++++++ .../construction_callers.hpp | 23 +++++ .../read_cloud_connection_conditions.cpp | 38 +++++++++ .../read_cloud_connection_conditions.hpp | 16 ++++ .../scaffold_graph_construction_pipeline.cpp | 5 ++ .../spades/barcode_index_construction.cpp | 85 ++++++++++++------- 8 files changed, 191 insertions(+), 35 deletions(-) diff --git a/src/common/barcode_index/barcode_index.hpp b/src/common/barcode_index/barcode_index.hpp index 9a4adb35..3cc5cf1b 100644 --- a/src/common/barcode_index/barcode_index.hpp +++ b/src/common/barcode_index/barcode_index.hpp @@ -84,7 +84,8 @@ namespace barcode_index { TaxIdEncoder(): codes_(), codes_rev_() - { } + { string empty_taxa_tree = "0"; //initialise with empty taxatree + AddTaxaTree(empty_taxa_tree);} void AddTaxaTree(string &taxatree) { auto it = codes_.find(taxatree); @@ -106,6 +107,20 @@ namespace barcode_index { return 
std::stoi(taxid_str); } + std::vector<TaxId> ToTaxaTreeVector(const std::string& taxa_tree_string, const char sep='.') const { + std::vector<TaxId> taxa_tree_vect; + std::string taxa; + TaxId taxid; + std::stringstream tree_stream(taxa_tree_string); // Insert the string into a stream + while(getline(tree_stream, taxa, sep)) { + // string to uint64_t + std::stringstream taxa_stream(taxa); + taxa_stream >> taxid; + taxa_tree_vect.push_back(taxid); + } + return taxa_tree_vect; + } + TaxId GetCode (const string& taxatree) const { VERIFY(codes_.find(taxatree) != codes_.end()); return codes_.at(taxatree); @@ -590,7 +605,7 @@ namespace barcode_index { taxonomy_ = taxid; } - TaxId GetTaxonomy() { + TaxId GetTaxonomy() const { return taxonomy_; } diff --git a/src/common/barcode_index/barcode_info_extractor.hpp b/src/common/barcode_index/barcode_info_extractor.hpp index 8fc3fea5..aaec31df 100644 --- a/src/common/barcode_index/barcode_info_extractor.hpp +++ b/src/common/barcode_index/barcode_info_extractor.hpp @@ -153,6 +153,22 @@ namespace barcode_index { return index_.edge_to_entry_.at(edge_id_).taxid_distribution_.at(taxid).GetCount(); } + std::vector<TaxId> ToTaxaTreeVector(const std::string& taxa_tree_string, const char sep='.') const { + return index_.taxatree_codes_.ToTaxaTreeVector(taxa_tree_string, sep); + } + + string TaxaTreeFromTaxId(TaxId& taxid) const { + return index_.taxatree_codes_.GetTaxaTree(taxid); + } + + string GetTaxaTreeFromEdge(const EdgeId& edge_id_) const { + //INFO("EdgeId: " << edge_id_); + TaxId taxid = index_.edge_to_entry_.at(edge_id_).GetTaxonomy(); + //INFO("GetTaxonomy: "<< taxid); + string taxatree = TaxaTreeFromTaxId(taxid); + return taxatree; + } + typename taxid_distribution_t::const_iterator taxid_iterator_begin(const EdgeId &edge) const { auto entry_it = index_.GetEntryHeadsIterator(edge); return entry_it->second.taxid_begin(); // second is EdgeEntry diff --git 
a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.cpp b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.cpp
index 31c314b2..2e92199f 100644
--- a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.cpp
+++ b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.cpp
@@ -146,12 +146,14 @@ CompositeConnectionConstructorCaller::CompositeConnectionConstructorCaller(
       gp_(gp), main_extractor_(main_extractor), long_edge_extractor_(barcode_extractor),
       unique_storage_(unique_storage), search_parameter_pack_(search_parameter_pack),
       scaff_con_configs_(scaff_con_configs), max_threads_(max_threads), scaffolding_mode_(scaffolding_mode) {}
+
 EdgeSplitConstructorCaller::EdgeSplitConstructorCaller(
     const Graph &g_,
     std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor_,
     std::size_t max_threads_)
     : IterativeScaffoldGraphConstructorCaller("Conjugate filter"),
       g_(g_), barcode_extractor_(barcode_extractor_), max_threads_(max_threads_) {}
+
 std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> EdgeSplitConstructorCaller::GetScaffoldGraphConstuctor(
     const ScaffolderParams &params,
     const ScaffoldGraph &scaffold_graph) const {
@@ -164,6 +166,28 @@ std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> EdgeSplitCons
                                                                                 max_threads_);
     return constructor;
 }
+
+// Caller for the "Taxonomy based filter" stage: builds a predicate filter driven by TaxaBreakPredicate.
+TaxaBreakConstructorCaller::TaxaBreakConstructorCaller(
+    const Graph &g_,
+    std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_,
+    std::size_t max_threads_)
+    : IterativeScaffoldGraphConstructorCaller("Taxonomy based filter"),
+      g_(g_), barcode_extractor_(barcode_extractor_), max_threads_(max_threads_) {}
+
+std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> TaxaBreakConstructorCaller::GetScaffoldGraphConstuctor(
+    const read_cloud::ScaffolderParams &params,
+    const ScaffoldGraph &scaffold_graph) const {
+    auto predicate = std::make_shared<TaxaBreakPredicate>(g_, barcode_extractor_);
+    auto constructor =
+        std::make_shared<path_extend::scaffolder::PredicateScaffoldGraphFilter>(g_,
+                                                                                scaffold_graph,
+                                                                                predicate,
+                                                                                max_threads_);
+    return constructor;
+}
+
+
 TransitiveConstructorCaller::TransitiveConstructorCaller(const Graph &g_,
                                                          std::size_t max_threads_)
     : IterativeScaffoldGraphConstructorCaller("Transitive filter"),
diff --git a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.hpp b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.hpp
index 9c232bf1..93a30bd0 100644
--- a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.hpp
+++ b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/construction_callers.hpp
@@ -159,6 +159,29 @@ class EdgeSplitConstructorCaller : public IterativeScaffoldGraphConstructorCalle
     std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor_;
     std::size_t max_threads_;
 };
+/** ConstructorCaller that filters wrong taxonomic transitions.
+ * Eg 1.2.3 -/-> 1.2.4 (mismatch in taxonomy so wrong),
+ *    1.2.3 --> 0 (taxonomy undefined so not wrong),
+ *    1.2.3 --> 1.2 (within same hierarchy so not wrong).
+ */
+
+class TaxaBreakConstructorCaller : public IterativeScaffoldGraphConstructorCaller {
+public:
+    using IterativeScaffoldGraphConstructorCaller::ScaffoldGraph;
+    TaxaBreakConstructorCaller(const Graph &g_,
+                               std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_,
+                               std::size_t max_threads_);
+
+    std::shared_ptr<scaffolder::ScaffoldGraphConstructor> GetScaffoldGraphConstuctor(
+        const read_cloud::ScaffolderParams &params,
+        const ScaffoldGraph &scaffold_graph) const override;
+
+private:
+    const Graph &g_;
+    std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_;
+    std::size_t max_threads_;
+};
+
 /** ConstructorCaller that filters transitive connections.
  */
 class TransitiveConstructorCaller : public IterativeScaffoldGraphConstructorCaller {
diff --git a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.cpp b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.cpp
index 55176ea1..f98535cd 100644
--- a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.cpp
+++ b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.cpp
@@ -110,6 +110,44 @@ bool ReadCloudMiddleDijkstraPredicate::Check(const scaffold_graph::ScaffoldGraph
     return false;
 }
 
+TaxaBreakPredicate::TaxaBreakPredicate(
+    const Graph &g_,
+    std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor)
+    : g_(g_),
+      barcode_extractor_(barcode_extractor) {}
+
+// Returns true when the transition is taxonomically compatible: one end has no assigned
+// taxonomy ("0"), or one lineage equals / is an ancestor of the other ("1.2" and "1.2.3").
+// Lineages are compared on '.' boundaries, so "1.2.3" is not matched by "1.2.30" or "11.2.3".
+bool TaxaBreakPredicate::Check(const ScaffoldEdge &scaffold_edge) const {
+    static const std::string null_taxa = "0";
+    // True when `ancestor` equals `descendant` or is a prefix of it that ends right before a '.'.
+    auto is_same_lineage = [](const std::string &ancestor, const std::string &descendant) {
+        return descendant.compare(0, ancestor.size(), ancestor) == 0 &&
+               (descendant.size() == ancestor.size() || descendant[ancestor.size()] == '.');
+    };
+
+    auto first = scaffold_edge.getStart();
+    auto second = scaffold_edge.getEnd();
+    DEBUG("In TaxaBreakPredicate, Edge_id: " << scaffold_edge.getId() << " length: " << scaffold_edge.getLength());
+    auto seq1 = first.GetSequence(g_);
+    auto seq2 = second.GetSequence(g_);
+    //TODO: implement taxonomy in barcode_extractor and vertexes!!
+    // Look each lineage up once and reuse it for both the log line and the check.
+    const std::string taxatree_1 = barcode_extractor_->GetTaxaTreeFromEdge(first.int_id());
+    const std::string taxatree_2 = barcode_extractor_->GetTaxaTreeFromEdge(second.int_id());
+    DEBUG("Vertex_1_id: " << first.int_id() << " Taxonomy: " << taxatree_1 << " Length: " << seq1.size());
+    DEBUG("Vertex_2_id: " << second.int_id() << " Taxonomy: " << taxatree_2 << " Length: " << seq2.size());
+
+    if (taxatree_1 == null_taxa || taxatree_2 == null_taxa ||
+        is_same_lineage(taxatree_1, taxatree_2) || is_same_lineage(taxatree_2, taxatree_1)) {
+        DEBUG("Taxatree match!");
+        return true;
+    }
+    DEBUG("Taxatree mismatch!");
+    return false;
+}
+
 EdgeSplitPredicate::EdgeSplitPredicate(
     const Graph &g_,
     std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor,
diff --git a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.hpp b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.hpp
index fdb22480..6d786421 100644
--- a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.hpp
+++ b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/read_cloud_connection_conditions.hpp
@@ -107,6 +107,22 @@ class CompositeConnectionPredicate : public ScaffoldEdgePredicate {
     DECL_LOGGER("CompositeConnectionPredicate");
 };
 
+class TaxaBreakPredicate : public ScaffoldEdgePredicate {
+public:
+    using ScaffoldEdgePredicate::ScaffoldEdge;
+    typedef barcode_index::BarcodeId BarcodeId;
+    TaxaBreakPredicate(const Graph &g_,
+                       std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor);
+
+    bool Check(const ScaffoldEdge &scaffold_edge) const override;
+
+private:
+    const Graph &g_;
+    std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_;
+
+    DECL_LOGGER("TaxaBreakPredicate");
+};
+
 class EdgeSplitPredicate : public ScaffoldEdgePredicate {
  public:
     using ScaffoldEdgePredicate::ScaffoldEdge;
diff --git a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/scaffold_graph_construction_pipeline.cpp b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/scaffold_graph_construction_pipeline.cpp
index 89a67482..1bb4b018 100644
--- a/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/scaffold_graph_construction_pipeline.cpp
+++
b/src/common/modules/path_extend/read_cloud_path_extend/scaffold_graph_construction/scaffold_graph_construction_pipeline.cpp @@ -99,6 +99,7 @@ std::vector<ScaffoldGraphConstructionPipeline::ResultT> ScaffoldGraphConstructio ScaffoldGraphPipelineConstructor::ScaffoldGraphPipelineConstructor(const ReadCloudConfigsT &configs, const Graph &g) : read_cloud_configs_(configs), g_(g) {} + std::shared_ptr<ScaffoldGraphPipelineConstructor::ScaffoldVertexExtractor> ScaffoldGraphPipelineConstructor::ConstructSimpleEdgeIndex( const std::set<ScaffoldGraphPipelineConstructor::ScaffoldVertex> &scaffold_vertices, ScaffoldGraphPipelineConstructor::BarcodeIndexPtr barcode_extractor, @@ -241,6 +242,8 @@ std::vector<std::shared_ptr<IterativeScaffoldGraphConstructorCaller>> FullScaffo unique_storage_, search_parameter_pack_, read_cloud_configs_.scaff_con, max_threads_, scaffolding_mode)); + iterative_constructor_callers.push_back( + std::make_shared<TaxaBreakConstructorCaller>(gp_.g, barcode_extractor_, max_threads_)); const size_t min_pipeline_length = read_cloud_configs_.long_edge_length_lower_bound; bool launch_full_pipeline = min_length_ > min_pipeline_length; @@ -316,6 +319,8 @@ std::vector<std::shared_ptr<IterativeScaffoldGraphConstructorCaller>> MergingSca split_scaffold_index_extractor, max_threads_)); iterative_constructor_callers.push_back(std::make_shared<TransitiveConstructorCaller>(g_, max_threads_)); + iterative_constructor_callers.push_back( + std::make_shared<TaxaBreakConstructorCaller>(g_, barcode_extractor_, max_threads_)); return iterative_constructor_callers; } } //path_extend diff --git a/src/projects/spades/barcode_index_construction.cpp b/src/projects/spades/barcode_index_construction.cpp index 4b33fe19..bcf0a63e 100644 --- a/src/projects/spades/barcode_index_construction.cpp +++ b/src/projects/spades/barcode_index_construction.cpp @@ -18,21 +18,10 @@ namespace debruijn_graph { return has_read_clouds; } - std::vector<TaxId> ToTaxaTreeVector(const std::string& 
taxa_tree_string, const char sep='.') {
-        std::vector<TaxId> taxa_tree_vect;
-        std::string taxa;
-        TaxId taxid;
-        std::stringstream tree_stream(taxa_tree_string); // Insert the string into a stream
-        while(getline(tree_stream, taxa, sep)) {
-            // string to uint64_t
-            std::stringstream taxa_stream(taxa);
-            taxa_stream >> taxid;
-            taxa_tree_vect.push_back(taxid);
-        }
-        return taxa_tree_vect;
-    }
 
-    TaxId majority_vote_lca(std::vector<string> &taxatree_str_vec, std::vector<size_t> &count_vec) {
+
+    TaxId majority_vote_lca(std::vector<string>& taxatree_str_vec, std::vector<size_t>& count_vec,
+                            const FrameBarcodeIndexInfoExtractor& extractor) {
         std::vector<std::vector<TaxId>> taxatree_vec;
         size_t longest_lineage = 0;
         size_t total_counts = 0;
@@ -40,7 +29,7 @@ namespace debruijn_graph {
         // transfer taxatree_strings to taxid_vectors.
         for ( const std::string& taxatree_str : taxatree_str_vec ) {
             if (taxatree_str != "0") {
-                std::vector<TaxId> taxid_vec = ToTaxaTreeVector(taxatree_str, '.');
+                std::vector<TaxId> taxid_vec = extractor.ToTaxaTreeVector(taxatree_str, '.');
                 taxatree_vec.push_back(taxid_vec);
                 longest_lineage = std::max(taxid_vec.size(), longest_lineage);
             }
@@ -57,33 +46,62 @@ namespace debruijn_graph {
             total_counts += count;
         }
         size_t min_majority = total_counts/2.0;
-        //min_majority = 3; //temporary for testing purposes ToDo: remove this one.
 
         // find index of taxid with highest count
         auto most_common_index = std::distance(count_vec.begin(),
                                                std::max_element(count_vec.begin(), count_vec.end()));
-        if (count_vec[most_common_index] > min_majority and total_counts > (0.2*(total_counts + null_counts))){
-            return taxatree_vec[most_common_index].back();
+        TaxId lca = 0;
+        VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size in max_lca");
+        if (count_vec.size() > 0) { //skipped when only taxa 0 was part of taxatree_str_vec.
+            if (total_counts > (0.20 * (total_counts + null_counts))) { //minimum 20% assigned taxids.
+                size_t i = 0; //depth in the lineage, starting at its first component.
+                TaxId proposed_lca = 0;
+                size_t prop_lca_counts = total_counts;
+                while (i < longest_lineage && prop_lca_counts >= min_majority) {
+                    prop_lca_counts = 0;
+                    if (i < taxatree_vec[most_common_index].size()){
+                        proposed_lca = taxatree_vec[most_common_index][i];
+                    } else {
+                        size_t max_sub_count = 0;
+                        for ( size_t j = 0; j < taxatree_vec.size(); ++j ) {
+                            //count_vec is indexed per lineage (j), not per depth (i): take the most counted lineage that reaches depth i.
+                            if ( i < taxatree_vec[j].size() && count_vec[j] > max_sub_count ) {
+                                proposed_lca = taxatree_vec[j][i];
+                                max_sub_count = count_vec[j];
+                            }
+                        }
+                    }
+                    size_t count_vec_pos = 0;
+                    for ( const auto &taxatree : taxatree_vec ) {
+                        if ( i < taxatree.size() && taxatree[i] == proposed_lca) {
+                            prop_lca_counts += count_vec[count_vec_pos];
+                        }
+                        count_vec_pos += 1;
+                    }
+                    if ( prop_lca_counts >= min_majority ) {
+                        lca = proposed_lca;
+                    }
+                    ++i;
+                }
+            }
         }
-        return 0;
+        return lca;
     }
 
-    TaxId last_common_ancestor(std::vector<string> &taxatree_vec, std::vector<size_t> &count_vec) {
-        VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size during lca");
+    TaxId last_common_ancestor(std::vector<string> &taxatree_vec, std::vector<size_t> &count_vec,
+                               const FrameBarcodeIndexInfoExtractor& extractor) {
+        VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size before lca");
         // wrapper to change underlying lca_algorithm.
TaxId lca = 0; - lca = majority_vote_lca(taxatree_vec, count_vec); + //INFO("Starting LCA"); + lca = majority_vote_lca(taxatree_vec, count_vec, extractor); + //INFO("Done LCA)"); return lca; } void assign_taxonomy_to_edges(barcode_index::FrameBarcodeIndex<debruijn_graph::DeBruijnGraph>& barcodeindex, const FrameBarcodeIndexInfoExtractor& extractor){ - //for edge in edge_iterator - // get taxid_lst + count_lst - // taxid_lst to taxatree_lst - // LCA = lca_method(taxatree_lst, count_lst) - // edge.SetTaxonomy(LCA) for ( auto &p : barcodeindex.edge_to_entry_ ) { std::vector<string> taxatree_vector; std::vector<size_t> count_vector; @@ -101,13 +119,14 @@ namespace debruijn_graph { count = extractor.GetTaxidCount(edge, taxid); count_vector.push_back(count); } - TaxId lca = last_common_ancestor(taxatree_vector, count_vector); - p.second.SetTaxonomy(lca); - INFO("EdgeId: " << edge); + DEBUG("EdgeId: " << edge << " Frame_size: " << p.second.GetFrameSize() << " # of frames: " << p.second.GetNumberOfFrames()); for (size_t i = 0; i != taxatree_vector.size(); i++ ){ - INFO("TaxaTree: " << taxatree_vector[i] << ", Count: " << count_vector[i]); + DEBUG("TaxaTree: " << taxatree_vector[i] << ", Count: " << count_vector[i]); } - INFO("Taxonomy: " << barcodeindex.edge_to_entry_.at(edge).GetTaxonomy()) + // Beware that lca function can mess with count_vector length so just use once. 
+ TaxId lca = last_common_ancestor(taxatree_vector, count_vector, extractor); + p.second.SetTaxonomy(lca); + DEBUG("Taxonomy: " << barcodeindex.edge_to_entry_.at(edge).GetTaxonomy()) } } @@ -137,7 +156,7 @@ namespace debruijn_graph { mapper_builder.FillMap(reads, graph_pack.index, graph_pack.kmer_mapper); INFO("Barcode index construction finished."); FrameBarcodeIndexInfoExtractor extractor(graph_pack.barcode_mapper, graph_pack.g); - assign_taxonomy_to_edges(graph_pack.barcode_mapper, extractor); // function remade as method of barcode_mapper + assign_taxonomy_to_edges(graph_pack.barcode_mapper, extractor); INFO("Taxonomy assigned to all edges"); size_t length_threshold = cfg::get().pe_params.read_cloud.long_edge_length_lower_bound; INFO("Average barcode coverage: " + std::to_string(extractor.AverageBarcodeCoverage(length_threshold))); -- GitLab