Skip to content
Snippets Groups Projects
Commit e49000d7 authored by Lannoy, Carlos de
Browse files

allow using pre-made index files in validate_16S

parent d92b49d1
No related branches found
No related tags found
No related merge requests found
......@@ -366,6 +366,8 @@ def get_validate_parser():
parser.add_argument('--nn-dir', required=True, type=lambda x: check_input_path(x),
help='Directory containing pre-made nns for single k-mers, to be used directly')
parser.add_argument('--ground-truth-16s', required=True, type=str, help='csv denoting which species reads belong to')
parser.add_argument('--index-files', type=lambda x: check_input_path(x),
help='If provided, do not generate index files for CV folds but use the ones in this directory.')
# parser.add_argument('--primed-nn-dir', required=True, type=str, help='Directory containing NNs trained on held-out set')
parser.add_argument('--dryrun', action='store_true')
return parser
......
......@@ -2,8 +2,6 @@ import os, sys, re
import pandas as pd
from os.path import basename
from shutil import copy
from pathlib import Path
from distutils.dir_util import copy_tree
from sklearn.model_selection import StratifiedKFold
from jinja2 import Template
from datetime import datetime
......@@ -90,17 +88,26 @@ for sp in species_list:
raise ValueError(f'No reads for target species {sp} listed in ground truth file!')
test_read_dir = parse_output_path(f'{args.out_dir}test_reads/', clean=True)
for fi, (train_num_idx, _) in enumerate(StratifiedKFold(n_splits=args.nb_folds,
shuffle=True).split(fast5_df.index,
fast5_df.species)):
for sp in species_list:
_ = parse_output_path(f'{test_read_dir}{sp}/fold_{fi}')
train_idx = fast5_df.index[train_num_idx]
fast5_df.loc[:, f'fold_{fi}'] = False
fast5_df.loc[train_idx, f'fold_{fi}'] = True
fast5_df.to_csv(f'{read_index_dir}index_fold{fi}.csv',
columns=['fn', 'species', f'fold_{fi}'],
header=['fn', 'species', f'fold'])
if args.index_files:
index_fn_list = sorted(os.listdir(args.index_files))
for fi, fn in enumerate(index_fn_list):
idx_df = pd.read_csv(f'{args.index_files}{fn}',index_col=0)
fast5_df.loc[:, f'fold_{fi}'] = idx_df.loc[fast5_df.index, 'fold']
copy(f'{args.index_files}{fn}', f'{read_index_dir}index_fold{fi}.csv')
for sp in species_list:
_ = parse_output_path(f'{test_read_dir}{sp}/fold_{fi}')
else:
for fi, (train_num_idx, _) in enumerate(StratifiedKFold(n_splits=args.nb_folds,
shuffle=True).split(fast5_df.index,
fast5_df.species)):
for sp in species_list:
_ = parse_output_path(f'{test_read_dir}{sp}/fold_{fi}')
train_idx = fast5_df.index[train_num_idx]
fast5_df.loc[:, f'fold_{fi}'] = False
fast5_df.loc[train_idx, f'fold_{fi}'] = True
fast5_df.to_csv(f'{read_index_dir}index_fold{fi}.csv',
columns=['fn', 'species', f'fold_{fi}'],
header=['fn', 'species', f'fold'])
print(f'{datetime.now()}: writing test read dirs...')
ff_idx_list = np.array_split(np.arange(len(fast5_df)), args.cores)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment