##############################
##### seed eQTL analysis #####
##############################

#### prepare the script working environment ####
  # Start from a clean session with a fixed RNG seed so reruns are reproducible.
  remove(list = ls())
  gc()
  set.seed(1000)

  # Set working directory ####
  # The project root can be overridden through the SEED_QTL_DIR environment
  # variable (e.g. on another machine); the author's path is the fallback.
  work.dir <- Sys.getenv("SEED_QTL_DIR",
                         unset = "C:/Users/harta005/Projects/seed-germination-qtl")
  setwd(work.dir)

  # dependencies ####
  if (!requireNamespace("BiocManager", quietly = TRUE))
    install.packages("BiocManager")
  # BiocManager::install("Mfuzz")
  # BiocManager::install("topGO")
  # BiocManager::install("org.At.tair.db")
  # BiocManager::install("Rgraphviz")
  # BiocManager::install("org.At.eg.db")
  library(doParallel)
  library(dplyr)
  library(ggplot2)
  library(heritability)
  library(topGO)
  library(org.At.tair.db)
  # unused libraries ####
  # library(Mfuzz)
  # library(Biobase)
  # library(gplots)
  # library(RColorBrewer)
  # library(limma)
  # library(factoextra)
  # library(stats)
  # library(amap)
  # library(forcats)
  # library(tidyr)
  # library(VennDiagram)
  # library(gridExtra)
  # library(RCy3)
  # library(corrplot)
  # library(reshape2)
  # library(Hmisc)
  # library(igraph)
  # library(threejs)
  # library(MASS)
  # library(UpSetR)

  
  
  
  
# load required function ####
  # Source every R script in functions/ (helper functions from Mark, e.g. the
  # eQTL table / peak finder / heritability helpers used below).
  # chdir = TRUE evaluates each file from inside functions/, like the original
  # setwd() did, but restores the working directory even when a file fails.
  # Only *.R / *.r files are sourced, so stray non-R files or an empty
  # directory no longer break the loop.
  function.files <- list.files(path = "functions", pattern = "\\.[Rr]$",
                               full.names = TRUE)
  for (function.file in function.files) {
    source(function.file, chdir = TRUE)
  }
  
  # Write a single-marker mapping result to six tab-separated tables:
  #   <filename>_lod.txt, _eff.txt, _lodxeff.txt, _marker.txt, _data.txt, _map.txt
  #
  # map1.output: list with elements LOD and Effect (trait x marker matrices),
  #   Trait (trait x strain matrix), Map (marker x strain matrix) and Marker
  #   (marker table whose first column holds the marker names).
  # filename: path prefix for the output files.
  # Returns NULL invisibly; called for its side effect.
  write.EleQTL <- function(map1.output, filename) {
    # All traits present in the LOD matrix, in their original order.
    traits <- rownames(map1.output$LOD)
    marker.names <- map1.output$Marker[, 1]

    # Keep the rows of `x` whose names are among `traits` and relabel it.
    # drop = FALSE keeps a one-trait result as a 1-row matrix instead of
    # collapsing it to a vector, on which rownames<- would fail.
    select.traits <- function(x, new.colnames) {
      x <- x[rownames(x) %in% traits, , drop = FALSE]
      rownames(x) <- traits
      colnames(x) <- new.colnames
      x
    }

    lod <- select.traits(map1.output$LOD, marker.names)
    eff <- select.traits(map1.output$Effect, marker.names)
    # Signed LOD: LOD score carrying the direction of the allelic effect.
    lod.eff <- lod * sign(eff)
    dat <- select.traits(map1.output$Trait, colnames(map1.output$Map))

    map <- map1.output$Map
    rownames(map) <- marker.names

    marker <- map1.output$Marker

    tables <- list(lod = lod, eff = eff, lodxeff = lod.eff,
                   marker = marker, data = dat, map = map)
    for (suffix in names(tables)) {
      write.table(tables[[suffix]],
                  file = paste0(filename, "_", suffix, ".txt"),
                  sep = "\t", quote = FALSE)
    }
    invisible(NULL)
  } # function to convert mapping result to tables
  # Single-marker association test: fits lm(trait ~ marker) and returns
  # c(LOD, effect), where LOD = -log10(p-value of the marker slope) and
  # effect = the slope estimate.
  # If the marker slope cannot be estimated (e.g. a monomorphic marker, whose
  # coefficient is aliased and therefore missing from summary()$coefficients),
  # c(NA, NA) is returned instead of failing with "subscript out of bounds".
  map.per.marker <- function(trait, marker) {
    coefs <- summary(lm(trait ~ marker))$coefficients
    if (nrow(coefs) < 2L) {
      return(c(NA_real_, NA_real_))
    }
    lod <- -log10(coefs[2, 4]) # column 4: Pr(>|t|)
    eff <- coefs[2, 1]         # column 1: Estimate
    c(lod, eff)
  } # map the QTL at a marker location
  # Genome-wide single-marker scan of one trait.
  #
  # trait: numeric vector, one value per strain.
  # markers: marker x strain genotype matrix (rows are markers).
  # Returns a matrix with one row per marker and columns 'LOD' and 'Eff', as
  # produced by map.per.marker().
  map.all.marker <- function(trait, markers) {
    n.markers <- nrow(markers)
    lod.out <- rep(NA_real_, n.markers)
    eff.out <- rep(NA_real_, n.markers)
    previous.genotype <- NULL
    out.tmp <- NULL
    for (i in seq_len(n.markers)) {
      genotype <- as.numeric(markers[i, ])
      # Neighbouring markers often share the exact same genotype pattern, so
      # the previous fit is reused. identical() also compares NA positions:
      # markers that differ only in missing calls use different samples in
      # lm() and must be refitted.
      if (is.null(previous.genotype) || !identical(genotype, previous.genotype)) {
        out.tmp <- map.per.marker(trait, markers[i, ])
        previous.genotype <- genotype
      }
      lod.out[i] <- out.tmp[1]
      eff.out[i] <- out.tmp[2]
    }
    output.lod <- cbind(lod.out, eff.out)
    colnames(output.lod) <- c('LOD', 'Eff')
    output.lod
  } # map the QTL using genome wide markers
  # Permutation-based genome-wide significance threshold for one trait.
  #
  # trait: numeric vector of trait values (one per strain).
  # strain.map: marker x strain genotype matrix.
  # n.perm: number of permutations.
  # Returns the 95% quantile of the per-permutation maximum LOD (a named
  # numeric of length 1, as returned by quantile()).
  threshold.determination <- function(trait, strain.map, n.perm){
    # n.perm copies of the trait (one per row), shuffled by permutate.traits()
    # from the sourced helper functions.
    perm.trait <- permutate.traits(t(replicate(n.perm, trait)))

    # One row per permutation, one column per marker; assumes
    # fast.lod.all.marker() returns one LOD per marker - TODO confirm.
    lod.perm <- matrix(NA_real_, nrow(perm.trait), nrow(strain.map))
    for (i in seq_len(nrow(perm.trait))) {
      lod.perm[i, ] <- fast.lod.all.marker(perm.trait[i, ], strain.map)
    }

    # Genome-wide maximum per permutation. na.rm = TRUE so that a marker
    # without a LOD (NA) does not turn the whole permutation into NA, which
    # would make quantile() stop with an error.
    max.lod <- apply(lod.perm, 1, max, na.rm = TRUE)
    quantile(x = max.lod, probs = 0.95)
  } # determine threshold for a QTL

# load required dataset ####
  trait.matrix <- as.matrix(read.csv(file = 'files/trait-matrix.csv', row.names = 1))
  genetic.map <- as.matrix(read.csv(file = 'files/genetic-map.csv'))
  trait.matrix <- trait.matrix[, colnames(trait.matrix) %in% colnames(genetic.map)] # remove sample without genetic map
  trait.matrix <- trait.matrix[, 17:176] #remove parent sample
  # Align the genetic map to the trait matrix by sample name rather than by
  # position, so column j of both matrices is the same RIL sample and
  # ril.stage (derived from the trait columns) indexes both correctly.
  genetic.map <- genetic.map[, colnames(trait.matrix)]
  marker <- read.csv('files/marker.csv', row.names = 1)
  sample.list <- read.csv(file = 'files/sample-list.csv')
  sample.stage <- sample.list$stage
  gene.info <- read.csv('files/gene.info.csv', row.names = 1)
  ril.stage <- substr(x = colnames(trait.matrix), start = 8, stop = 10)
  stage <- 'pd'
  # Seed stages analysed by the loops below (the stages that have an eQTL
  # threshold in the peak finder). Kept if already defined elsewhere.
  if (!exists('seed.stage')) {
    seed.stage <- c('pd', 'ar', 'im', 'rp')
  }
  # Leave one core free, but never ask for fewer than one worker
  # (detectCores() may return 1 or NA).
  n.cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
  
# QTL mapping ####
  # For every seed stage: single-marker eQTL mapping of all traits on the RILs
  # of that stage, parallelised over traits. Results are written as tables
  # (write.EleQTL) and as an .rds profile, which the peak finder reads back.
  for (stage in seed.stage) {
    stage.cols <- which(ril.stage == stage)
    map <- genetic.map[, stage.cols]
    map <- apply(map, 2, as.numeric) # make sure the alleles are treated as numeric
    trait <- trait.matrix[, stage.cols]

    # Each trait contributes a 'LOD' and an 'Eff' column (one row per marker)
    # to the cbind-ed output.
    cluster <- makeCluster(n.cores, type = "PSOCK")
    registerDoParallel(cluster)
    output <- tryCatch(
      foreach(i = seq_len(nrow(trait)), .combine = 'cbind') %dopar% {
        map.all.marker(trait = trait[i, ], markers = map)
      },
      finally = stopCluster(cluster) # release the workers even on error
    )

    pval.out <- t(output[, which(colnames(output) == 'LOD')])
    eff.out <- t(output[, which(colnames(output) == 'Eff')])

    colnames(pval.out) <- rownames(marker); rownames(pval.out) <- rownames(trait)
    colnames(eff.out) <- rownames(marker); rownames(eff.out) <- rownames(trait)

    qtl.profile <- list(
      LOD = round(pval.out, digits = 2),
      Effect = round(eff.out, digits = 3),
      Trait = trait,
      Map = map,
      Marker = marker
    )

    write.EleQTL(map1.output = qtl.profile,
                 filename = paste0("qtl-tables/table_single-stage-eqtl_", stage))
    # The peak finder reads this profile back, so it must be written here.
    saveRDS(object = qtl.profile,
            file = paste0("qtl-profiles/profile_single-stage-eqtl_", stage, ".rds"))
  }
  
# permutation and FDR determination 
  ## based on Benjamini-Yekutieli
  # setwd(dir = "qtl-permutations/")
  # cluster <- makeCluster(n.cores, type = "PSOCK")
  # registerDoParallel(cluster)
  # foreach(i = 1:100) %dopar% {
  #   qtl.perm <- map.1.perm(trait.matrix = trait, 
  #                          strain.map = map, 
  #                          strain.marker = marker,
  #                          n.perm = 1)
  #   save(qtl.perm, file = paste0("perm_single-stage-eqtl_", stage, i, ".RData"))
  # }
  # stopCluster(cluster)
  # 
  # filenames.perm <- dir()
  # filenames.perm <- filenames.perm[grep(paste("perm_single-stage-eqtl", stage, sep = "."), filenames.perm)]
  # 
  # FDR <- map.perm.fdr(map1.output = qtl.profile,
  #                     filenames.perm = filenames.perm,
  #                     FDR_dir = paste0(getwd(), "/"),
  #                     q.value = 0.05)
  # saveRDS(FDR, file = paste0('fdr_single-stage-eqtl_', stage, '.RDS'))
  # 
  # setwd(work.dir)
  # 

# eQTL peak finder and table ####
  
  # Stage-specific genome-wide LOD thresholds; they are based on
  # multiple-testing correction using 100 permuted datasets.
  qtl.thresholds <- c(pd = 4.2, ar = 4.2, im = 4.1, rp = 4.3)

  # Curated trans-bands per stage. A trans-eQTL belongs to a band when it lies
  # on chromosome `chr` with from < qtl_bp <= to. Bands of one stage do not
  # overlap; if they did, the first matching row would win.
  trans.band.def <- data.frame(
    stage = c('pd', 'pd', 'ar', 'ar', 'im', 'im', 'rp', 'rp', 'rp', 'rp'),
    label = c('ch1:6-10Mb', 'ch3:8-12Mb', 'ch2:12-14Mb', 'ch3:2-4Mb',
              'ch5:6-8Mb', 'ch5:22-26Mb', 'ch1:0-2Mb', 'ch1:6-8Mb',
              'ch5:14-16Mb', 'ch5:24-26Mb'),
    chr = c(1, 3, 2, 3, 5, 5, 1, 1, 5, 5),
    from = c(6e6, 8e6, 12e6, 2e6, 6e6, 22e6, 0e6, 6e6, 14e6, 24e6),
    to = c(10e6, 12e6, 14e6, 4e6, 8e6, 26e6, 2e6, 8e6, 16e6, 26e6),
    stringsAsFactors = FALSE
  )

  # Trans-band scan: trans-eQTLs are counted in fixed windows per chromosome.
  window.nu <- 2e6               # window size (bp)
  window.mb <- window.nu / 1e6   # window size (Mb), used in band labels
  maxsize <- 100e6               # upper bound on chromosome length (bp)

  for (stage in seed.stage) {
    # eQTL peak finder ####
    threshold <- unname(qtl.thresholds[stage])
    if (is.na(threshold)) {
      stop('No eQTL threshold defined for stage ', stage)
    }
    qtl.profile <- readRDS(paste0('qtl-profiles/profile_single-stage-eqtl_', stage, '.rds'))
    qtl.peak <- mapping.to.list(map1.output = qtl.profile) %>%
      peak.finder(threshold = threshold)
    qtl.peak <- na.omit(qtl.peak)
    saveRDS(object = qtl.peak,
            file = paste0("qtl-peaks/peak_single-stage-eqtl_", stage, ".rds"))

    # eQTL table ####
    eqtl.table <- eQTL.table(peak.list.file = qtl.peak, trait.annotation = gene.info) %>%
      eQTL.table.addR2(QTL.prep.file = qtl.profile)
    eqtl.table$qtl_chromosome <- as.factor(eqtl.table$qtl_chromosome)
    eqtl.table$gene_chromosome <- as.factor(eqtl.table$gene_chromosome)

    # add heritability - single stage ####
    # h2.result and n.perm are only used by the commented-out permutation code.
    h2.result <- as.list(NULL)
    n.perm <- 100

    map <- genetic.map[, which(ril.stage == stage)]
    map <- apply(map, 2, as.numeric) # make sure the alleles are treated as numeric
    trait <- trait.matrix[, which(ril.stage == stage)]

    # Genotypes recoded for the kinship matrix: -1 becomes 0.
    map2 <- map
    #map2[map2 == 0] <- 0.5
    map2[map2 == -1] <- 0
    #map2[is.na(map2)] <- 0.5
    kinship.matrix <- emma.kinship(map2)
    colnames(kinship.matrix) <- colnames(map2); rownames(kinship.matrix) <- colnames(map2)

    cluster <- makeCluster(n.cores, type = 'PSOCK')
    h2 <- tryCatch({
      clusterExport(cl = cluster, c('trait', 'map2', 'marker_h2', 'kinship.matrix', 'h2.REML'))
      t(parApply(cl = cluster, X = trait, MARGIN = 1, FUN = h2.REML,
                 strain.names = colnames(map2), kinship.matrix = kinship.matrix,
                 Vg.factor = 1))
    }, finally = stopCluster(cluster)) # release the workers even on error

    h2 <- cbind.data.frame(trait = rownames(h2), h2)
    eqtl.table <- merge(x = eqtl.table, y = h2[, 1:2], by = 'trait', all.x = TRUE)
    #eqtl.table <- rename(.data = eqtl.table, h2_REML = h2)

    ###permutation
    # print(paste0('permuting', " ", stage))
    #
    # cluster <- makeCluster(n.cores, type = "PSOCK")
    # registerDoParallel(cluster)
    # perm.output <- foreach(i = 1:n.perm, .combine = 'cbind',
    #                        .export = c('trait.tmp', 'map.tmp', 'marker_h2', 'kinship.matrix', 'h2.REML')) %dopar% {
    #                          perm.trait <- t(apply(trait.tmp, 1, function(x){x <- x[order(runif(length(x)))];return(x)}))
    #                          t(apply(perm.trait, 1, h2.REML, strain.names = colnames(map.tmp2), kinship.matrix = kinship.matrix, Vg.factor = 1))[, 1]
    #                        }
    # stopCluster(cluster)
    # perm.result <- cbind(h2, FDR0.05_REML = apply(perm.output, 1, quantile, 0.95))
    #
    # h2.result[[stage]] <- perm.result
    #
    #
    # for (stage in development.stage) {
    #   h2.result[[stage]] <- cbind.data.frame(h2.result[[stage]], trait = rownames(h2.result[[stage]]))
    # }

    # trans-bands identification ####
    # Windows whose number of regulated genes exceeds the Poisson expectation
    # (mean count per window for that qtl_type) at p < 1e-4.
    # NOTE(review): ppois(n, lower.tail = FALSE) is P(X > n), not P(X >= n);
    # the test is therefore slightly conservative.
    transband.id <- mutate(eqtl.table, interval = findInterval(qtl_bp, seq(1, maxsize, by = window.nu))) %>%
      group_by(qtl_chromosome, interval, qtl_type) %>%
      summarise(n.ct = length(unique(trait))) %>%
      data.frame() %>%
      group_by(qtl_type) %>%
      mutate(exp.ct = mean(as.numeric(unlist(n.ct)))) %>%
      data.frame() %>%
      mutate(transband_significance = ppois(n.ct, lambda = exp.ct, lower.tail = FALSE)) %>%
      filter(transband_significance < 0.0001, qtl_type == "trans") %>%
      # Built inside mutate() so a stage without significant windows yields an
      # empty table instead of a length-1 paste0() result that cannot be
      # assigned to zero rows.
      mutate(transband_id = paste0("ch", qtl_chromosome, ":",
                                   (interval - 1) * window.mb, "-",
                                   interval * window.mb, "Mb"))
    transband.id$stage <- rep(stage, nrow(transband.id))

    saveRDS(object = transband.id,
            file = paste0("trans-bands/trans.band_", stage, ".rds"))

    # Label each trans-eQTL with the curated band it falls in ("none" otherwise).
    stage.bands <- trans.band.def[trans.band.def$stage == stage, ]
    trans_band <- rep('none', nrow(eqtl.table))
    for (b in seq_len(nrow(stage.bands))) {
      in.band <- eqtl.table$qtl_type == 'trans' &
        eqtl.table$qtl_chromosome == stage.bands$chr[b] &
        eqtl.table$qtl_bp > stage.bands$from[b] &
        eqtl.table$qtl_bp <= stage.bands$to[b]
      trans_band[which(in.band & trans_band == 'none')] <- stage.bands$label[b]
    }
    eqtl.table$trans_band <- trans_band

    write.csv(x = eqtl.table,
              file = paste0("qtl-tables/table_single-stage-eqtl_", stage, ".csv"),
              row.names = TRUE)
    saveRDS(object = eqtl.table,
            file = paste0("qtl-tables/table_single-stage-eqtl_", stage, ".rds"))
    # Inside a loop nothing auto-prints, so print the summary explicitly.
    print(table(eqtl.table$qtl_type, eqtl.table$trans_band != "none"))
  }
    # GO for trans bands - single stage ####
    # For every stage and every trans-band in its eQTL table, test the genes
    # regulated by that band for GO enrichment (topGO, weight algorithm,
    # Fisher test) against all measured genes, keep terms with p <= 0.01 and
    # write all results to one CSV.
    # NOTE(review): 'weight' p-values are not independent, so the BH FDR below
    # is only a rough guide.
    ontology <- 'BP' #c('BP', 'CC', 'MF')

    all.genes <- rownames(trait.matrix)
    go.columns <- c('GO.ID', 'Term', 'Annotated', 'Significant', 'Expected', 'Fisher',
                    'FDR', 'stage', 'transband')
    # Empty template, returned as-is when no band yields any result.
    go.template <- as.data.frame(matrix(data = NA, nrow = 0, ncol = length(go.columns)))
    colnames(go.template) <- go.columns
    go.results <- list()

    for (stage in seed.stage) {
      eqtl.table <- readRDS(paste0('qtl-tables/table_single-stage-eqtl_', stage, '.rds'))
      transband.id <- unique(eqtl.table$trans_band)
      transband.id <- transband.id[!transband.id %in% 'none']

      # Iterating over the values (not 1:length) skips stages without bands.
      for (band in transband.id) {
        gene.set <- eqtl.table[which(eqtl.table$trans_band == band), 'trait']
        gene.set <- factor(as.integer(all.genes %in% gene.set))
        names(gene.set) <- all.genes
        GOdata <- new("topGOdata",
                      description = "GOE for genes in regulated by trans bands", ontology = ontology,
                      allGenes = gene.set,
                      annot = annFUN.org, mapping = "org.At.tair.db")
        resultFisher <- runTest(GOdata, algorithm = 'weight', statistic = "fisher")
        result.df <- GenTable(GOdata, Fisher = resultFisher,
                              orderBy = "Fisher", ranksOf = "Fisher",
                              topNodes = length(score(resultFisher)))
        result.df$stage <- stage
        result.df$transband <- band
        # GenTable() returns p-values as text and prints very small ones as
        # "< 1e-30", which as.numeric() would turn into NA (and the filter
        # below would then drop the strongest terms). Use the exact scores.
        result.df$Fisher <- unname(score(resultFisher)[result.df$GO.ID])
        result.df$FDR <- p.adjust(p = result.df$Fisher, method = 'fdr')
        result.df <- result.df[order(result.df$FDR), ]
        result.df <- dplyr::filter(result.df, Fisher <= 0.01)
        go.results[[length(go.results) + 1]] <- result.df
      }
    }

    # Bind once at the end instead of growing a data frame inside the loop.
    transband.go <- do.call(rbind, c(list(go.template), go.results))

    write.csv(transband.go, paste0('files/trans-bands-go-', ontology, '.csv'))