
/* C Header */

/*
    Copyright (C) 2022- Torsten Hothorn

    This file is part of the 'mvtnorm' R add-on package.

    'mvtnorm' is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, version 2.

    'mvtnorm' is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with 'mvtnorm'.  If not, see <http://www.gnu.org/licenses/>.


    DO NOT EDIT THIS FILE

    Edit 'lmvnorm_src.w' and run 'nuweb -r lmvnorm_src.w'
*/

#include <R.h>
#include <Rmath.h>
#include <Rinternals.h>
#include <Rdefines.h>
#include <Rconfig.h>
#include <R_ext/Lapack.h> /* for dtptri */
/* colSumsdnorm */

/*
 * Column-wise sums of standard normal log-densities.
 *
 * z is read as N consecutive groups of J doubles (one group per column).
 * For every group the function returns
 *     -0.5 * (J * log(2 * pi) + sum_j z_j^2),
 * i.e. the sum over j of the log-density of N(0, 1) evaluated at z_j.
 *
 * z     doubles, at least N * J values are read
 * N     integer scalar, number of columns
 * J     integer scalar, number of rows
 *
 * Returns a numeric vector of length N.
 */
SEXP R_ltMatrices_colSumsdnorm (SEXP z, SEXP N, SEXP J) {
    const int ncols = INTEGER(N)[0];
    const int nrows = INTEGER(J)[0];
    /* constant part of the log-density, shared by all columns */
    const double Jl2pi = nrows * log(2 * PI);
    SEXP ans;

    PROTECT(ans = allocVector(REALSXP, ncols));
    double *out = REAL(ans);
    const double *col = REAL(z);

    /* col always points at the first element of the current column */
    for (int c = 0; c < ncols; c++, col += nrows) {
        double sumsq = 0.0;
        for (int r = 0; r < nrows; r++)
            sumsq += pow(col[r], 2);
        out[c] = - 0.5 * (Jl2pi + sumsq);
    }

    UNPROTECT(1);
    return(ans);
}

/* solve */

/*
 * Invert lower triangular matrices, or solve lower triangular linear
 * systems, using Lapack / BLAS routines for packed storage.
 *
 * C          doubles with the lower triangular elements of either a single
 *            matrix (LENGTH(C) equals the length of one matrix) or of N
 *            matrices stored one after the other
 * y          R_NilValue to compute inverses (dtptri); otherwise doubles read
 *            as N right-hand sides of J elements each (dtpsv)
 * N          integer scalar, number of matrices / right-hand sides
 * J          integer scalar, dimension of the matrices
 * diag       logical, TRUE if C contains the diagonal elements; if FALSE a
 *            unit diagonal is inserted because the packed format handed to
 *            Lapack always contains the diagonal
 * transpose  logical, solve with t(C) instead of C; only used when y is given
 *
 * Returns, when y is R_NilValue, a matrix whose columns are the inverses in
 * packed form (diagonal always included), one column per matrix in C.
 * Otherwise a J x N matrix holding the solutions.
 *
 * Raises an R error if dtptri reports a non-zero info code.
 */
SEXP R_ltMatrices_solve (SEXP C, SEXP y, SEXP N, SEXP J, SEXP diag, SEXP transpose)
{

    SEXP ans, ansx;
    double *dans, *dansx, *dy;
    int i, j, k, info, nrow, ncol, nloop, jj, idx, ONE = 1;

    /* RC input */
    
    /* pointer to C matrices */
    double *dC = REAL(C);
    /* number of matrices */
    int iN = INTEGER(N)[0];
    /* dimension of matrices */
    int iJ = INTEGER(J)[0];
    /* C contains diagonal elements */
    Rboolean Rdiag = asLogical(diag);
    /* p = J * (J - 1) / 2 + diag * J */
    int len = iJ * (iJ - 1) / 2 + Rdiag * iJ;
    
    /* C length */
    
    int p;
    if (LENGTH(C) == len)
        /* C is constant for i = 1, ..., N */
        p = 0;
    else 
        /* C contains C_1, ...., C_N */
        p = len;
    

    char di, lo = 'L', tr = 'N';
    if (Rdiag) {
        /* non-unit diagonal elements */
        di = 'N';
    } else {
        /* unit diagonal elements */
        di = 'U';
    }

    /* t(C) instead of C */
    Rboolean Rtranspose = asLogical(transpose);
    if (Rtranspose) {
        /* t(C) */
        tr = 'T';
    } else {
        /* C */
        tr = 'N';
    }

    /* setup memory */
    
    /* return object: include unit diagonal elements if Rdiag == 0 */

    /* add diagonal elements (expected by Lapack) */
    nrow = (Rdiag ? len : len + iJ);
    ncol = (p > 0 ? iN : 1);
    PROTECT(ans = allocMatrix(REALSXP, nrow, ncol));
    dans = REAL(ans);

    /* without y, the packed matrices in ans are the result; with y, ans is
       only a work buffer and the solutions go to a separate J x N matrix */
    ansx = ans;
    dansx = dans;
    dy = dans;
    if (y != R_NilValue) {
        dy = REAL(y);
        PROTECT(ansx = allocMatrix(REALSXP, iJ, iN));
        dansx = REAL(ansx);
    }
    
    /* number of passes: dtptri overwrites its argument with the inverse.
       For a single constant matrix (p == 0) without right-hand sides the
       buffer is never advanced, so running the loop iN times would invert
       the already inverted matrix again and again; one pass is sufficient */
    nloop = iN;
    if (y == R_NilValue && p == 0 && iN > 1)
        nloop = 1;
    
    /* loop over matrices, ie columns of C  / y */    
    for (i = 0; i < nloop; i++) {

        /* copy elements */
        
        /* copy data and insert unit diagonal elements when necessary;
           a constant C (p == 0) is copied only once */
        if (p > 0 || i == 0) {
            jj = 0;     /* write position in dans */
            k = 0;      /* number of diagonal elements inserted so far */
            idx = 0;    /* position of the next diagonal element in dans */
            j = 0;      /* read position in dC */
            while(j < len) {
                if (!Rdiag && (jj == idx)) {
                    dans[jj] = 1.0;
                    /* column k holds iJ - k elements in packed storage */
                    idx = idx + (iJ - k);
                    k++;
                } else {
                    dans[jj] = dC[j];
                    j++;
                }
                jj++;
            }
            /* the last diagonal element follows the last element of C */
            if (!Rdiag) dans[idx] = 1.0;
        }

        /* dtpsv overwrites the right-hand side, so work on a copy of y */
        if (y != R_NilValue) {
            for (j = 0; j < iJ; j++)
                dansx[j] = dy[j];
        }
        
        /* call Lapack */
        
        if (y == R_NilValue) {
            /* compute inverse */
            F77_CALL(dtptri)(&lo, &di, &iJ, dans, &info FCONE FCONE);
            if (info != 0)
                error("Cannot solve ltmatices");
        } else {
            /* solve linear system */
            F77_CALL(dtpsv)(&lo, &tr, &di, &iJ, dans, dansx, &ONE FCONE FCONE FCONE);
            dansx += iJ;
            dy += iJ;
        }
        

        /* next matrix */
        if (p > 0) {
            dans += nrow;
            dC += p;
        }
    }

    /* return objects */
    
    if (y == R_NilValue) {
        UNPROTECT(1);
        /* note: ans always includes diagonal elements */
        return(ans);
    } else {
        UNPROTECT(2);
        return(ansx);
    }
    
}

/* tcrossprod */


/* IDX */

/* IDX(i, j, n, d): zero-based offset of element (i, j), with 1-based row i
   and column j, inside an n x n lower triangular matrix whose columns are
   stored one after the other in packed form. A non-zero d means the
   diagonal is part of the storage; d == 0 means it is left out, in which
   case only i > j addresses a stored element. Evaluates to 0 for i < j.
   Note: d is used as (!d) without inner parentheses, so pass a simple
   expression. */
#define IDX(i, j, n, d) ((i) >= (j) ? (n) * ((j) - 1) - ((j) - 2) * ((j) - 1)/2 + (i) - (j) - (!d) * (j) : 0)


/*
 * Products C %*% t(C) (tcrossprod) or t(C) %*% C (crossprod) for N lower
 * triangular J x J matrices stored column-wise in packed form.
 *
 * C          doubles, N matrices stored one after the other
 * N          integer scalar, number of matrices
 * J          integer scalar, dimension of the matrices
 * diag       logical, TRUE if C contains the diagonal elements; otherwise
 *            a unit diagonal is assumed
 * diag_only  logical, TRUE to compute only the J diagonal elements of each
 *            product
 * transpose  logical, TRUE for the crossprod, FALSE for the tcrossprod
 *
 * Returns a J x N matrix (diag_only) or a J * (J + 1) / 2 x N matrix with
 * the lower triangle of each product, diagonal included, addressed through
 * IDX(., ., J, 1).
 */
SEXP R_ltMatrices_tcrossprod (SEXP C, SEXP N, SEXP J, SEXP diag, 
                              SEXP diag_only, SEXP transpose) {

    /* pointer to C matrices */
    double *dC = REAL(C);
    /* number of matrices */
    int iN = INTEGER(N)[0];
    /* dimension of matrices */
    int iJ = INTEGER(J)[0];
    /* C contains diagonal elements */
    Rboolean Rdiag = asLogical(diag);
    /* number of stored elements per matrix */
    int len = iJ * (iJ - 1) / 2 + Rdiag * iJ;

    Rboolean Rdiag_only = asLogical(diag_only);
    Rboolean Rtranspose = asLogical(transpose);

    /* rows of the result: diagonal only or full packed lower triangle */
    int nrow = (Rdiag_only ? iJ : iJ * (iJ + 1) / 2);

    SEXP ans;
    PROTECT(ans = allocMatrix(REALSXP, nrow, iN));
    double *dans = REAL(ans);

    for (int n = 0; n < iN; n++) {

        /* element (1, 1) is computed the same way in both modes */
        double first = 1.0;
        if (Rdiag)
            first = pow(dC[0], 2);
        if (Rtranspose) { /* crossprod: add squares of first column */
            for (int k = 1; k < iJ; k++)
                first += pow(dC[IDX(k + 1, 1, iJ, Rdiag)], 2);
        }
        dans[0] = first;

        if (Rdiag_only) {
            /* remaining diagonal elements */
            for (int i = 1; i < iJ; i++) {
                double acc = 0.0;
                if (Rtranspose) { /* crossprod: column i below the diagonal */
                    for (int k = i + 1; k < iJ; k++)
                        acc += pow(dC[IDX(k + 1, i + 1, iJ, Rdiag)], 2);
                } else {          /* tcrossprod: row i left of the diagonal */
                    for (int k = 0; k < i; k++)
                        acc += pow(dC[IDX(i + 1, k + 1, iJ, Rdiag)], 2);
                }
                /* contribution of the (possibly implicit unit) diagonal */
                if (Rdiag)
                    acc += pow(dC[IDX(i + 1, i + 1, iJ, Rdiag)], 2);
                else
                    acc += 1.0;
                dans[i] = acc;
            }
        } else {
            /* remaining elements (i, j), j <= i, of the lower triangle */
            for (int i = 1; i < iJ; i++) {
                for (int j = 0; j <= i; j++) {
                    double acc = 0.0;
                    if (Rtranspose) { /* crossprod: rows below row i */
                        for (int k = i + 1; k < iJ; k++)
                            acc += dC[IDX(k + 1, i + 1, iJ, Rdiag)] *
                                   dC[IDX(k + 1, j + 1, iJ, Rdiag)];
                    } else {          /* tcrossprod: columns left of j */
                        for (int k = 0; k < j; k++)
                            acc += dC[IDX(i + 1, k + 1, iJ, Rdiag)] *
                                   dC[IDX(j + 1, k + 1, iJ, Rdiag)];
                    }
                    /* term involving a diagonal element of C */
                    if (Rdiag) {
                        if (Rtranspose)
                            acc += dC[IDX(i + 1, i + 1, iJ, Rdiag)] *
                                   dC[IDX(i + 1, j + 1, iJ, Rdiag)];
                        else
                            acc += dC[IDX(i + 1, j + 1, iJ, Rdiag)] *
                                   dC[IDX(j + 1, j + 1, iJ, Rdiag)];
                    } else {
                        /* unit diagonal: factor 1 off the diagonal,
                           1 * 1 on it */
                        if (j < i)
                            acc += dC[IDX(i + 1, j + 1, iJ, Rdiag)];
                        else
                            acc += 1.0;
                    }
                    dans[IDX(i + 1, j + 1, iJ, 1)] = acc;
                }
            }
        }

        /* next matrix */
        dans += nrow;
        dC += len;
    }

    UNPROTECT(1);
    return(ans);
}

/* mult */

/*
 * Multiply lower triangular matrices with vectors: column n of the result
 * is C_n %*% y_n.
 *
 * The elements of C are read row by row: the off-diagonal elements of row j
 * are contiguous, followed by the diagonal element of that row when diag is
 * TRUE.
 *
 * C     doubles with the elements of a single matrix (LENGTH(C) equals the
 *       length of one matrix, reused for every column) or of N matrices
 * y     doubles read as N vectors of J elements each
 * N     integer scalar, number of columns of y
 * J     integer scalar, dimension of the matrices
 * diag  logical, TRUE if C contains the diagonal elements; otherwise a unit
 *       diagonal is assumed
 *
 * Returns a J x N numeric matrix.
 */
SEXP R_ltMatrices_Mult (SEXP C, SEXP y, SEXP N, SEXP J, SEXP diag) {

    double *dy = REAL(y);
    /* pointer to C matrices */
    double *dC = REAL(C);
    /* number of matrices */
    int iN = INTEGER(N)[0];
    /* dimension of matrices */
    int iJ = INTEGER(J)[0];
    /* C contains diagonal elements */
    Rboolean Rdiag = asLogical(diag);
    /* number of stored elements per matrix */
    int len = iJ * (iJ - 1) / 2 + Rdiag * iJ;
    /* stride between matrices: 0 when one constant C serves all columns */
    int p = (LENGTH(C) == len ? 0 : len);

    SEXP ans;
    PROTECT(ans = allocMatrix(REALSXP, iJ, iN));
    double *dans = REAL(ans);

    for (int n = 0; n < iN; n++) {
        /* offset of the first stored element of the current row */
        int offset = 0;
        for (int row = 0; row < iJ; row++) {
            double acc = 0.0;
            for (int col = 0; col < row; col++)
                acc += dC[offset + col] * dy[col];
            if (Rdiag) {
                acc += dC[offset + row] * dy[row];
                offset += row + 1;
            } else {
                /* implicit unit diagonal */
                acc += dy[row];
                offset += row;
            }
            dans[row] = acc;
        }
        dC += p;
        dy += iJ;
        dans += iJ;
    }

    UNPROTECT(1);
    return(ans);
}

/* mult transpose */

/*
 * Multiply transposed lower triangular matrices with vectors: column n of
 * the result is t(C_n) %*% y_n.
 *
 * The elements of C are read column by column: the diagonal element of
 * column j (when diag is TRUE) is followed by the J - j - 1 elements below
 * it.
 *
 * C     doubles with the elements of a single matrix (LENGTH(C) equals the
 *       length of one matrix, reused for every column) or of N matrices
 * y     doubles read as N vectors of J elements each
 * N     integer scalar, number of columns of y
 * J     integer scalar, dimension of the matrices
 * diag  logical, TRUE if C contains the diagonal elements; otherwise a unit
 *       diagonal is assumed
 *
 * Returns a J x N numeric matrix.
 */
SEXP R_ltMatrices_Mult_transpose (SEXP C, SEXP y, SEXP N, SEXP J, SEXP diag) {

    double *dy = REAL(y);
    /* pointer to C matrices */
    double *dC = REAL(C);
    /* number of matrices */
    int iN = INTEGER(N)[0];
    /* dimension of matrices */
    int iJ = INTEGER(J)[0];
    /* C contains diagonal elements */
    Rboolean Rdiag = asLogical(diag);
    /* number of stored elements per matrix */
    int len = iJ * (iJ - 1) / 2 + Rdiag * iJ;
    /* stride between matrices: 0 when one constant C serves all columns */
    int p = (LENGTH(C) == len ? 0 : len);

    SEXP ans;
    PROTECT(ans = allocMatrix(REALSXP, iJ, iN));
    double *dans = REAL(ans);

    for (int n = 0; n < iN; n++) {
        /* read position inside the current matrix */
        int pos = 0;
        for (int col = 0; col < iJ; col++) {
            /* number of elements below the diagonal in this column */
            int below = iJ - col - 1;
            double acc = 0.0;
            if (Rdiag) {
                acc += dC[pos] * dy[col];
                pos++;
            } else {
                /* implicit unit diagonal */
                acc += dy[col];
            }
            for (int k = 0; k < below; k++)
                acc += dC[pos + k] * dy[col + k + 1];
            pos += below;
            dans[col] = acc;
        }
        dC += p;
        dy += iJ;
        dans += iJ;
    }

    UNPROTECT(1);
    return(ans);
}

/* chol */

/*
 * Cholesky factorisation of N symmetric J x J matrices in packed storage.
 *
 * Each matrix occupies J * (J + 1) / 2 consecutive doubles of Sigma and is
 * copied to the result before Lapack's dpptrf (uplo = 'L') factorises the
 * copy in place; Sigma itself is not modified.
 *
 * Sigma  doubles, N packed matrices stored one after the other
 * N      integer scalar, number of matrices
 * J      integer scalar, dimension of the matrices
 *
 * Returns a J * (J + 1) / 2 x N matrix with the packed factors.
 * Raises an R error when dpptrf returns a non-zero info code.
 */
SEXP R_syMatrices_chol (SEXP Sigma, SEXP N, SEXP J) {

    int iJ = INTEGER(J)[0];
    /* number of packed elements per matrix */
    int pJ = iJ * (iJ + 1) / 2;
    int iN = INTEGER(N)[0];
    int info = 0;
    char lo = 'L';
    SEXP ans;

    PROTECT(ans = allocMatrix(REALSXP, pJ, iN));
    double *dest = REAL(ans);
    double *src = REAL(Sigma);

    for (int n = 0; n < iN; n++, src += pJ, dest += pJ) {

        /* dpptrf works in place, so factorise a copy */
        for (int e = 0; e < pJ; e++)
            dest[e] = src[e];

        F77_CALL(dpptrf)(&lo, &iJ, dest, &info FCONE);

        /* error() does not return, so at most one branch is reported */
        if (info > 0)
            error("the leading minor of order %d is not positive definite",
                  info);
        if (info < 0)
            error("argument %d of Lapack routine %s had invalid value",
                  -info, "dpptrf");
    }

    UNPROTECT(1);
    return(ans);
}

/* vec trick */


/* IDX */

/* IDX(i, j, n, d): zero-based offset of element (i, j), with 1-based row i
   and column j, inside an n x n lower triangular matrix whose columns are
   stored one after the other in packed form. A non-zero d means the
   diagonal is part of the storage; d == 0 means it is left out, in which
   case only i > j addresses a stored element. Evaluates to 0 for i < j.
   Note: d is used as (!d) without inner parentheses, so pass a simple
   expression. */
#define IDX(i, j, n, d) ((i) >= (j) ? (n) * ((j) - 1) - ((j) - 2) * ((j) - 1)/2 + (i) - (j) - (!d) * (j) : 0)


/*
 * For each of N matrices compute op(C) %*% S %*% op(A) with two triangular
 * matrix products (BLAS dtrmm), where op() is the transpose if the
 * corresponding element of trans is TRUE and the identity otherwise.
 *
 * C, A   doubles, lower triangular J x J matrices in packed column-wise
 *        storage; elements are addressed with IDX(., ., J, 1), i.e. with
 *        the diagonal present
 * N      integer scalar, number of matrices
 * J      integer scalar, dimension of the matrices
 * S      doubles, read as N full J x J matrices (J * J elements each)
 * diag   logical, only enters the computation of the stride between
 *        consecutive matrices in C and A
 * trans  logical vector, element 1 applies to C and element 2 to A
 *
 * Returns a J * J x N matrix; column n holds the J x J result for matrix n.
 *
 * NOTE(review): A is advanced with the stride derived from LENGTH(C);
 * presumably callers guarantee that A and C hold the same number of
 * matrices - confirm against the R code.
 * NOTE(review): the J x J work array is a variable length array on the
 * stack; confirm J stays small enough for this to be safe.
 */
SEXP R_vectrick(SEXP C, SEXP N, SEXP J, SEXP S, SEXP A, SEXP diag, SEXP trans) {

    /* pointer to C matrices */
    double *dC = REAL(C);
    /* number of matrices */
    int iN = INTEGER(N)[0];
    /* dimension of matrices */
    int iJ = INTEGER(J)[0];
    /* C contains diagonal elements */
    Rboolean Rdiag = asLogical(diag);
    /* number of stored elements per matrix */
    int len = iJ * (iJ - 1) / 2 + Rdiag * iJ;
    /* stride between matrices: 0 when C is constant for all n */
    int p = (LENGTH(C) == len ? 0 : len);

    double *dS = REAL(S);
    double *dA = REAL(A);

    Rboolean RtC = LOGICAL(trans)[0];
    Rboolean RtA = LOGICAL(trans)[1];

    /* arguments for dtrmm */
    char siR = 'R', siL = 'L', lo = 'L', di = 'N';
    char trC = (RtC ? 'T' : 'N');
    char trA = (RtA ? 'T' : 'N');
    double ONE = 1.0;
    int iJ2 = iJ * iJ;

    /* full J x J work matrix; only the lower triangle is ever rewritten,
       so the upper triangle stays zero */
    double tmp[iJ2];
    for (int e = 0; e < iJ2; e++)
        tmp[e] = 0.0;

    SEXP ans = PROTECT(allocMatrix(REALSXP, iJ2, iN));
    double *dans = REAL(ans);

    int total = LENGTH(ans);
    for (int e = 0; e < total; e++)
        dans[e] = 0.0;

    for (int n = 0; n < iN; n++) {

        /* expand packed C into the lower triangle of tmp */
        for (int r = 0; r < iJ; r++) {
            for (int c = 0; c <= r; c++)
                tmp[c * iJ + r] = dC[IDX(r + 1, c + 1, iJ, 1L)];
        }

        /* S was already expanded in R code; B = S */
        for (int e = 0; e < iJ2; e++)
            dans[e] = dS[e];

        /* B := op(C) %*% B */
        F77_CALL(dtrmm)(&siL, &lo, &trC, &di, &iJ, &iJ, &ONE, tmp, &iJ, 
                        dans, &iJ FCONE FCONE FCONE FCONE);

        /* expand packed A into the lower triangle of tmp */
        for (int r = 0; r < iJ; r++) {
            for (int c = 0; c <= r; c++)
                tmp[c * iJ + r] = dA[IDX(r + 1, c + 1, iJ, 1L)];
        }

        /* B := B %*% op(A) */
        F77_CALL(dtrmm)(&siR, &lo, &trA, &di, &iJ, &iJ, &ONE, tmp, &iJ, 
                        dans, &iJ FCONE FCONE FCONE FCONE);

        /* next matrix */
        dans += iJ2;
        dC += p;
        dS += iJ2;
        dA += p;
    }

    UNPROTECT(1);
    return(ans);
}

