#!/usr/bin/python3
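"""Create a per-bin summary table (TSV) for metagenomic bins.

Combines assembly statistics computed from the bin FASTA files with MetaBAT2
bin depths and, when provided, CheckM and BUSCO reports.

Example:
  bins_summary.py -b bins_folder -d depths_file -B busco_summaries_folder -c checkm_report -o output_summary.tsv
"""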
import sys
import argparse
import os
import re
import pandas as pd

def parse_options():
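  """Parse and return the command-line arguments."""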
  usage = "\nbins_summary.py -b bins_folder -d depths_file -B busco_summaries_folder -c checkm_report -o output_summary.tsv"
  description = "Creates a summary of bins from different tools"
  parser = argparse.ArgumentParser(usage=usage, description=description)

  input_group = parser.add_argument_group('Required arguments')
  input_group.add_argument("-b","--bin_dir",     dest="bin_dir",       help="Folder with bins in fasta format", required=True, metavar="")
  input_group.add_argument("-d","--bin_depths",  dest="bin_depths",    help="MetaBAT2 aggregateDepths file", required=True, metavar="")
  input_group.add_argument("-B","--busco_dir",   dest="busco_dir",     help="Folder with BUSCO reports", required=False,  metavar="")
  input_group.add_argument("-c","--checkm",      dest="checkm_report", help="Checkm report file", required=False,  metavar="")
  input_group.add_argument("-o","--output_file", dest="output_file",   help="Output name", required=True,  metavar="")

  inputs = parser.parse_args()
  return inputs

def bin_stats(bin_dir):
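  """Compute assembly statistics for every bin FASTA file in bin_dir (number
  of contigs, total size, largest contig, N50 and GC%) and store them in the
  global dataframe. Also writes binContigs.tsv, mapping each contig to its bin."""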
  # Side-effect output file mapping each contig to the bin it belongs to.
  bin_contigs_file = open("binContigs.tsv", "w")

  bin_file_list = os.listdir(bin_dir)
  for bin in bin_file_list:
    bin_file_path = bin_dir+"/"+bin

    # TOTAL SIZE, CONTIG SIZES, GC%
    contig = ""
    total_size = 0
    contig_sizes = []
    contig_len = 0
    GC = 0
    with open(bin_file_path, "r") as fasta:
      for i, line in enumerate(fasta):
        line = line.strip()
        if line.startswith(">"):
          contig = line.lstrip(">")
          bin_contigs_file.write(bin + "\t" + contig + "\n")
          if i != 0:
            contig_sizes.append(contig_len)
            contig_len = 0
        else:
          seq = line.upper()
          GC += seq.count("G") + seq.count("C")
          total_size += len(seq)
          contig_len += len(seq)

    # The loop only records a contig's length when it reaches the next header,
    # so append the length of the last contig here.
    contig_sizes.append(contig_len)
    
    GC = round(GC/total_size*100,1)
    contigs = len(contig_sizes)
    largest_contig = max(contig_sizes)

    # N50: length of the contig at which the cumulative size of the largest
    # contigs reaches 50% of the total bin size.
    cumulative = 0
    for length in sorted(contig_sizes, reverse=True):
      cumulative += length
      if cumulative >= total_size * 0.5:
        n50 = length
        break

    df.loc[bin, ['Contigs','Size','Largest_contig','N50','GC']] = [contigs,total_size,largest_contig,n50,GC]
  bin_contigs_file.close()

def bin_depths(bin_depths):
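  """Read the MetaBAT2 aggregated depths file and store each bin's average
  depth in the global dataframe, using the basename of the bin path as key."""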
  with open(bin_depths, "r") as depths:
    for i, line in enumerate(depths):
      # Skip the header line; the first field is the bin file path, the
      # third its average depth.
      if i > 0:
        fields = line.strip().split()
        bin = fields[0].split("/")[-1]
        df.at[bin, 'avgDepth'] = fields[2]

def read_checkm_report(checkm_report):
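  """Parse a CheckM report table and store completeness, contamination and
  strain heterogeneity for each bin in the global dataframe. Bin IDs in the
  report are matched to '<bin_id>.fa' rows."""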
  with open(checkm_report, "r") as report:
    for line in report:
      line = line.strip()
      # Skip the header, separator and blank lines of the CheckM table.
      if not line or "Bin Id" in line or "-------" in line:
        continue
      fields = re.split('  +', line)
      df.loc[fields[0]+".fa", ['CheckM_Completeness','CheckM_Contamination','CheckM_Strain-heterogeneity']] = [fields[11]+"%",fields[12]+"%",fields[13]+"%"]

def read_busco(busco_dir):
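  """Parse BUSCO lineage-specific short summaries and store each bin's
  lineage and one-line completeness score in the global dataframe."""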
  busco_file_list = os.listdir(busco_dir)
  for busco_file in busco_file_list:
    # Only parse lineage-specific short summaries, named e.g.
    # short_summary.specific.<lineage>.<bin_name>.txt
    if "specific" in busco_file:
      bin = ".".join(busco_file.split(".")[3:-1])
      taxonomy = busco_file.split(".")[2]
      busco_file_path = busco_dir + "/" + busco_file
      with open(busco_file_path) as summary:
        for line in summary:
          line = line.strip()
          # The one-line score looks like "C:xx%[S:..,D:..],F:..,M:..,n:..".
          if line.startswith("C"):
            df.loc[bin, ['BUSCO_Taxonomy','BUSCO_score']] = [taxonomy, line]
            break


def main(argv):
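  """Build the per-bin summary dataframe and write it as a TSV file."""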
  inputs = parse_options()

  out_columns = ['Contigs','Size','Largest_contig','N50','GC','avgDepth','BUSCO_Taxonomy','BUSCO_score','CheckM_Completeness','CheckM_Contamination','CheckM_Strain-heterogeneity']
  bin_file_list = os.listdir(inputs.bin_dir)

  global df
  df = pd.DataFrame(columns=out_columns, index=bin_file_list)
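  # The parsing functions above fill this dataframe in place via the
  # module-level global.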

  # Fill the dataframe (df) with the data parsed from the input files.
  bin_stats(inputs.bin_dir)
  bin_depths(inputs.bin_depths)

  # The CheckM and BUSCO inputs are optional; only parse them when provided.
  if inputs.checkm_report:
    read_checkm_report(inputs.checkm_report)
  if inputs.busco_dir:
    read_busco(inputs.busco_dir)
  df.index.name = "Bin"
  df.to_csv(inputs.output_file, sep="\t")


if __name__ == "__main__":
  main(sys.argv[1:])