## Copyright (C) 2024 Rodney A. Sparapani

## This file is part of nftbart.
## predict.nft2mi.R

## nftbart is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 2 of the License, or
## (at your option) any later version.

## nftbart is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

## Author contact information
## Rodney A. Sparapani: rsparapa@mcw.edu

## Predict from an NFT BART fit to multiply imputed data (class 'nft2mi').
##
## Each imputation i is predicted separately: the i-th training matrices and
## the i-th per-imputation fit components are swapped into a copy of `object`,
## which is re-classed as 'nft2' and handed to predict() (i.e., the
## predict.nft2 method).  The posterior draws of all imputations are then
## stacked row-wise and the pooled mean and lower/upper quantiles are
## recomputed over the stacked draws.
##
## Arguments are those of predict.nft2, except:
##   xftest, xstest  one matrix used for every imputation, or a list holding
##                   one matrix per imputation
##   soffset         one value (recycled) or one value per imputation
## The event grid (K/events) is resolved once here so that every imputation
## is evaluated on the same grid.
##
## Returns the first imputation's result with the draws of the remaining
## imputations appended (f.test, s.test, surv.test, pdf.test, haz.test, or
## the *.fpd draws when FPD=TRUE), the pooled *.mean/*.lower/*.upper
## summaries, `soffset` and `elapsed` with one entry per imputation, and
## `elapsed.sum`.
predict.nft2mi = function(
                       ## data
                       object,
                       xftest=object$xftrain,
                       xstest=object$xstrain,
                       ## multi-threading
                       tc=getOption("mc.cores", 1), ##OpenMP thread count
                       ## current process fit vs. previous process fit
                       XPtr=FALSE, ## external pointers not working here
                       ## predictions
                       K=0,
                       events=object$events,
                       FPD=FALSE,
                       probs=c(0.025, 0.975),
                       take.logs=TRUE,
                       na.rm=FALSE,
                       RMST.max=NULL,
                       ##seed=NULL,
                       ## default settings for NFT:BART/HBART/DPM
                       fmu=object$NFT$fmu,
                       soffset=object$soffset,
                       drawDPM=object$drawDPM,
                       ## etc.
                       ...)
{
    if(is.null(object)) stop("No fitted model specified!\n")

    ## xftest/xstest: a single matrix shared by all imputations, or a list
    ## with one matrix per imputation.  A data.frame is also a list, but it
    ## is one design, not a list of designs.
    xftest.list <- NULL
    xstest.list <- NULL

    if(is.list(xftest) && !is.data.frame(xftest)) {
        xftest.list <- xftest
        xftest <- xftest.list[[1]]
    }

    if(is.list(xstest) && !is.data.frame(xstest)) {
        xstest.list <- xstest
        xstest <- xstest.list[[1]]
    }

    ## dimension checks against the first imputation's training data
    n <- nrow(object$xftrain[[1]])
    np <- nrow(xftest)
    if(np!=nrow(xstest))
        stop('The number of rows in xftest and xstest must be the same!')
    pf <- ncol(object$xftrain[[1]])
    if(pf!=ncol(xftest))
        stop('The number of columns in xftrain and xftest must be the same!')
    ps <- ncol(object$xstrain[[1]])
    if(ps!=ncol(xstest))
        stop('The number of columns in xstrain and xstest must be the same!')
    ## FPD needs np to be a whole multiple of the n training rows
    if(FPD && np %% n != 0)
        stop('The number of FPD blocks must be an integer')

    ## Resolve K and the event grid once, so every imputation is predicted
    ## on the same grid.
    if(length(RMST.max)>0) {
        K <- 0
    } else if(length(K)==0) {
        K <- 0
        take.logs <- FALSE
    } else if(K>0) {
        if(length(events)==0) {
            ## NOTE(review): this grid comes from object$times with
            ## take.logs=FALSE, i.e., the times are assumed to already be on
            ## the log scale (an earlier version used object$z.train.mean)
            ## -- confirm
            events <- unique(quantile(object$times, probs=(1:K)/(K+1)))
            attr(events, 'names') <- NULL
            take.logs <- FALSE
            K <- length(events)
        } else if(length(events)!=K) {
            stop("K and the length of events don't match")
        }
    } else if(K==0 && length(events)>0) {
        if(is.matrix(events)) {
            if(FPD)
                stop("Friedman's partial dependence function: can't be used with a matrix of events")
            K <- ncol(events)
        } else K <- length(events)
    }

    if(K>0 && take.logs) {
        events <- log(events)
        ## The logs have now been taken.  predict.nft2 takes the logs of
        ## `events` itself when take.logs=TRUE, so switch it off to avoid
        ## passing log(log(events)) downstream.
        take.logs <- FALSE
    }

    object. <- object
    mult.impute <- object$mult.impute
    ## one soffset per imputation
    if(length(soffset) == 1) soffset <- rep(soffset, length.out=mult.impute)

    ## per-imputation components of the fit; element i of each is swapped
    ## into `object` before predicting imputation i
    ptr <- c('ots', 'oid', 'ovar', 'oc', 'otheta',
             'sts', 'sid', 'svar', 'sc', 'stheta',
             'f.trees', 's.trees', 's.train.mask',
             'dpmu', 'dpsd', 'dpmu.', 'dpsd.', 'dpwt.')

    res.list <- vector('list', mult.impute)

    for(i in seq_len(mult.impute)) {
        object$xftrain <- object.$xftrain[[i]]
        object$xstrain <- object.$xstrain[[i]]

        if(is.list(xftest.list)) xftest <- xftest.list[[i]]
        if(is.list(xstest.list)) xstest <- xstest.list[[i]]

        ## exact-name lookup; a missing component stays absent (NULL)
        for(var in ptr) object[[var]] <- object.[[var]][[i]]

        attr(object, 'class') <- 'nft2'

        res.list[[i]] <- predict(
                       ## data
                       object,
                       xftest=xftest,
                       xstest=xstest,
                       ## multi-threading
                       tc=tc, ##OpenMP thread count
                       ## current process fit vs. previous process fit
                       XPtr = FALSE, ## external pointers not working here
                       ## predictions
                       K=K,
                       events=events, ## already on the final scale
                       FPD=FPD,
                       probs=probs,
                       take.logs=take.logs,
                       na.rm=na.rm,
                       RMST.max=RMST.max,
                       ##seed=NULL,
                       ## default settings for NFT:BART/HBART/DPM
                       fmu=fmu,
                       soffset=soffset[i],
                       drawDPM=drawDPM,
                       ## etc.
                       ...)
    }

    ## draw matrices (rows = posterior draws) to pool across imputations
    if(FPD) {
        draws <- c('surv.fpd', 'pdf.fpd', 'haz.fpd')
    } else {
        draws <- c('f.test', 's.test', 'surv.test', 'pdf.test', 'haz.test')
    }

    res <- res.list[[1]]

    ## stack the draws of every imputation; also works for mult.impute == 1
    for(nm in draws)
        res[[nm]] <- do.call(rbind, lapply(res.list, `[[`, nm))

    ## seq_len()[-1] is empty for a single imputation (2:1 would not be)
    for(i in seq_len(mult.impute)[-1]) {
        res$soffset[i] <- res.list[[i]]$soffset
        res$elapsed[i] <- res.list[[i]]$elapsed
    }

    lower <- min(probs)
    upper <- max(probs)

    ## pooled posterior summaries, column by column, over the stacked draws
    for(nm in draws) {
        x <- res[[nm]]
        if(length(x) == 0) next ## quantity not returned by predict.nft2
        res[[paste0(nm, '.mean')]] <- apply(x, 2, mean, na.rm=na.rm)
        res[[paste0(nm, '.lower')]] <-
            apply(x, 2, quantile, probs=lower, na.rm=na.rm)
        res[[paste0(nm, '.upper')]] <-
            apply(x, 2, quantile, probs=upper, na.rm=na.rm)
    }

    res$elapsed.sum <- sum(res$elapsed)

    res
}



