Given the values of the (observed) x-variables in a SEM, this function can be used to predict the values of the (observed) y-variables. The predictive procedure consists of two steps: (1) construction of the topological layer (TL) ordering of the input graph, as sketched below; (2) prediction of the y values of the nodes in each layer, with the nodes of the previous layers acting as predictors x.
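For intuition, step (1) can be sketched with igraph, assuming a DAG with named vertices: source nodes get layer 0, and every other node gets one plus the maximal layer of its parents. The helper topo_layers() below is illustrative only, not the package's internal routine.

library(igraph)

topo_layers <- function(dag) {
  v <- names(topo_sort(dag, mode = "out"))       # topological ordering
  layer <- setNames(integer(length(v)), v)       # source nodes start at layer 0
  for (u in v) {
    pa <- names(neighbors(dag, u, mode = "in"))  # parents of u
    if (length(pa) > 0) layer[u] <- max(layer[pa]) + 1
  }
  split(names(layer), layer)                     # list: layer -> node names
}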

# S3 method for class 'SEM'
predict(object, newdata, verbose = FALSE, ...)

Arguments

object

An object of class "SEM", as created by the function SEMrun() with the argument fit set to fit = 0 or fit = 1.

newdata

A matrix with new data, with rows corresponding to subjects and columns to variables. If object$fit is a model with the group variable (fit = 1), the first column of newdata must be the new group binary vector (0=control, 1=case).

verbose

A logical value. If FALSE (default), the processed graph will not be plotted to screen.

...

Currently ignored.

Value

A list of 2 objects:

  1. "PE", a vector of prediction errors, equal to the Mean Squared Error (MSE) of each out-of-bag prediction. The first value of PE is the AMSE, i.e. the MSE averaged over all (sink and mediator) graph nodes; see the sketch after this list.

  2. "Yhat", the matrix of continuous predicted values of the graph nodes (excluding source nodes), based on out-of-bag samples.
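For intuition, PE could be reproduced from Yhat along these lines (a sketch, assuming Yobs holds the observed out-of-bag values for the same columns as Yhat):

mse  <- colMeans((Yobs - Yhat)^2)   # one MSE per predicted node
amse <- mean(mse)                   # the first element of PE
PE   <- c(amse = amse, mse)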

Details

The function first creates a layer-based structure of the input graph. Then, a SEM-based predictive approach (de Rooij et al., 2023) is used to produce predictions while accounting for the graph structure organised in topological layers, j=1,...,L. In each iteration, the response variables y are the nodes in layer j and the predictors x are the nodes belonging to the previous layers 1,...,j-1. Predictions (for y given x) are based on the (joint y and x) model-implied variance-covariance (Sigma) matrix and mean vector (Mu) of the fitted SEM, and the standard expression for the conditional mean of a multivariate normal distribution. Thus, the layer structure of the SEM is taken into account, which differs from ordinary least squares (OLS) regression.
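Concretely, with x collecting the nodes of layers 1,...,j-1 and y the layer-j nodes, the conditional mean is E(y|x) = Mu_y + Sigma_yx Sigma_xx^(-1) (x - Mu_x). A minimal sketch of this step follows; predict_layer() is illustrative, not the exported API, and Sigma and Mu are assumed to be the model-implied moments with node names attached.

predict_layer <- function(Sigma, Mu, X, yn, xn) {
  # regression weights implied by the joint normal model: Sigma_yx Sigma_xx^(-1)
  B  <- Sigma[yn, xn, drop = FALSE] %*% solve(Sigma[xn, xn, drop = FALSE])
  # center the predictors at their model-implied means
  Xc <- sweep(X[, xn, drop = FALSE], 2, Mu[xn])
  # conditional mean: Mu_y + B (x - Mu_x), one row per subject
  sweep(Xc %*% t(B), 2, Mu[yn], "+")
}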

References

de Rooij M, Karch JD, Fokkema M, Bakk Z, Pratiwi BC, Kelderman H (2023). SEM-Based Out-of-Sample Predictions. Structural Equation Modeling: A Multidisciplinary Journal, 30(1), 132-148 <https://doi.org/10.1080/10705511.2022.2061494>

Grassi M, Palluzzi F, Tarantino B (2022). SEMgraph: An R Package for Causal Network Analysis of High-Throughput Data with Structural Equation Models. Bioinformatics, 38 (20), 4829–4830 <https://doi.org/10.1093/bioinformatics/btac567>

Author

Mario Grassi mario.grassi@unipv.it

Examples


# load ALS data
ig<- alsData$graph
data<- alsData$exprs
data<- transformData(data)$data
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
group<- alsData$group

#...with train-test (0.5-0.5) samples
set.seed(123)
train<- sample(1:nrow(data), 0.5*nrow(data))

# SEM fitting
#sem0<- SEMrun(ig, data[train,], algo="lavaan", SE="none")
#sem0<- SEMrun(ig, data[train,], algo="ricf", n_rep=0)
sem0<- SEMrun(ig, data[train,], SE="none", limit=1000)
#> NLMINB solver ended normally after 1 iterations 
#> 
#> deviance/df: 6.263349  srmr: 0.3041398 
#> 

# predictors, source+mediator; outcomes, mediator+sink

res0<- predict(sem0, newdata=data[-train,]) 
print(res0$PE)
#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.8973153 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166 0.6766548 
#>      5530      5532      5533      5534      5535       842      1432      5600 
#> 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749 0.5717884 
#>      5603      6300     54205      5606      5608       596      4217      1616 
#> 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833 1.1257768 

# SEM fitting
#sem1<- SEMrun(ig, data[train,], group[train], algo="lavaan", SE="none")
#sem1<- SEMrun(ig, data[train,], group[train], algo="ricf", n_rep=0)
sem1<- SEMrun(ig, data[train,], group[train], SE="none", limit=1000)
#> NLMINB solver ended normally after 6 iterations 
#> 
#> deviance/df: 6.211221  srmr: 0.2861178 
#> 

# predictors, source+mediator+group; outcomes, source+mediator+sink

res1<- predict(sem1, newdata=cbind(group,data)[-train,]) 
print(res1$PE)
#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.8963810 1.1258657 0.9073936 0.6861219 0.7001557 1.0070725 1.1045920 0.6652474 
#>      5530      5532      5533      5534      5535       842      1432      5600 
#> 0.8679917 0.7365461 1.2300132 0.6637344 0.6675348 0.7262147 0.8576185 0.6667256 
#>      5603      6300      5630     54205       317      5606      5608       581 
#> 1.1063933 0.6068862 1.0886300 1.0784325 0.7720406 0.8841224 0.8326926 1.0285427 
#>       572       596       598      4217      6647      1616      7132      7133 
#> 0.7379857 1.0832421 0.8036031 1.0731173 1.0383387 1.1300782 0.8195168 1.0913615 

# \donttest{
#...with a binary outcome (1=case, 0=control), recoded as -1/1 for SEM fitting

ig1<- mapGraph(ig, type="outcome"); gplot(ig1)

outcome<- ifelse(group == 0, -1, 1); table(outcome)
#> outcome
#>  -1   1 
#>  21 139 
data1<- cbind(outcome, data); data1[1:5,1:5]
#>      outcome        207         208      10000       284
#> ALS2       1 -1.8273895 -0.45307006 -0.1360061 0.4530701
#> ALS3       1 -2.5616910 -0.96201413  0.3160400 0.6762093
#> ALS4       1 -0.8003346  0.82216031 -1.1521227 0.5613048
#> ALS5       1 -2.1342965 -0.98709115  1.1521227 0.5064807
#> ALS6       1 -2.0111279  0.02393297  0.5987578 0.1360061

sem10 <- SEMrun(ig1, data1[train,], SE="none", limit=1000)
#> NLMINB solver ended normally after 1 iterations 
#> 
#> deviance/df: 6.122632  srmr: 0.3101125 
#> 
res10<- predict(sem10, newdata=data1[-train,], verbose=TRUE) 

#>      amse   outcome     10452     84134       836      4747      4741      4744 
#> 0.8828829 0.5509374 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166 
#>     79139      5530      5532      5533      5534      5535       842      1432 
#> 0.6766548 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749 
#>      5600      5603      6300     54205      5606      5608       596      4217 
#> 0.5717884 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833 
#>      1616 
#> 1.1257768 

yobs<- group[-train]
yhat<- res10$Yhat[,"outcome"]
benchmark(yobs, yhat)
#>     ypred
#> yobs  0  1
#>    0  3  3
#>    1 14 60
#> 
#>        pre       rec        f1       mcc
#> 1 0.952381 0.8108108 0.8759124 0.2001211
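The predicted outcome values in Yhat are continuous, so benchmark() evidently dichotomizes them before cross-tabulation. Assuming a zero threshold (the natural cut for the -1/1 coding used above), the confusion matrix can be reproduced by hand:

ypred <- ifelse(yhat > 0, 1, 0)  # assumed threshold at 0
table(yobs, ypred)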

#...with predictors, source nodes; outcomes, sink nodes
ig2<- mapGraph(ig, type= "source"); gplot(ig2)


sem02 <- SEMrun(ig2, data[train,], SE="none", limit=1000)
#> NLMINB solver ended normally after 1 iterations 
#> 
#> deviance/df: 10.16978  srmr: 0.1295992 
#> 
res02<- predict(sem02, newdata=data[-train,], verbose=TRUE) 

#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.7550791 0.6146408 0.5114778 0.6960717 0.8871358 1.1343159 1.0664702 0.5043968 
#>      5530      5532      5533      5534      5535 
#> 0.6128182 0.6257037 1.1708051 0.6069758 0.6301368 
#print(res02$PE)

#...with 10 repetitions of 10-fold cross-validation samples

res<- NULL
for (r in 1:10) {
  set.seed(r)
  cat("rep = ", r, "\n")
  idx <- SEMdeep:::createFolds(y=data[,1], k=10)
  for (k in 1:10) {
   cat("  k-fold = ", k, "\n")
   semr<- SEMdeep:::quiet(SEMrun(ig, data[-idx[[k]], ], SE="none", limit=1000))
   resr<- predict(semr, newdata=data[idx[[k]], ])
   res<- rbind(res, resr$PE)
  }
}
#> rep =  1 
#>   k-fold =  1 
#>   k-fold =  2 
#>   k-fold =  3 
#>   k-fold =  4 
#>   k-fold =  5 
#>   k-fold =  6 
#>   k-fold =  7 
#>   k-fold =  8 
#>   k-fold =  9 
#>   k-fold =  10 
#> ...
#average results
apply(res, 2, mean)
#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.8203834 0.9924008 0.9373265 0.7580843 0.5039775 0.7884417 0.8785680 0.7226705 
#>      5530      5532      5533      5534      5535       842      1432      5600 
#> 0.8494899 0.7238154 0.9889105 0.8135880 0.7818244 0.5714233 0.7531485 0.8071404 
#>      5603      6300     54205      5606      5608       596      4217      1616 
#> 0.9521244 0.7174839 0.7012493 0.9636981 0.7552133 0.9749872 0.9882541 0.9449986 
# }