Given the values of (observed) x-variables in a SEM, this function may be used to predict the values of (observed) y-variables. The predictive procedure consists of two steps: (1) construction of the topological layer (TL) ordering of the input graph; (2) prediction of the node y values in a layer, where the nodes included in the previous layers act as predictors x.
# S3 method for class 'SEM'
predict(object, newdata, verbose = FALSE, ...)
An object, as created by the function SEMrun() with the argument fit set to fit = 0 or fit = 1.
A matrix with new data, with rows corresponding to subjects and columns to variables. If object$fit is a model with the group variable (fit = 1), the first column of newdata must be the new group binary vector (0=control, 1=case).
A logical value. If FALSE (default), the processed graph will not be plotted to screen.
Currently ignored.
A list of 2 objects:
"PE", a vector of prediction errors, equal to the Mean Squared Error (MSE) of each out-of-bag prediction. The first value of PE is the average MSE (AMSE) over all predicted (mediator and sink) graph nodes.
"Yhat", the matrix of continuous predicted values of graph nodes (excluding source nodes) based on out-of-bag samples.
The function first creates a layer-based structure of the input graph. Then, a SEM-based predictive approach (de Rooij et al., 2023) is used to produce predictions while accounting for the graph structure organised in topological layers, j=1,...,L. In each iteration, the response variables y are the nodes in layer j and the predictors x are the nodes belonging to the previous j-1 layers. Predictions (for y given x) are based on the (joint y and x) model-implied variance-covariance matrix (Sigma) and mean vector (Mu) of the fitted SEM, and on the standard expression for the conditional mean of a multivariate normal distribution, E(y | x) = Mu_y + Sigma_yx Sigma_xx^(-1) (x - Mu_x). Thus, the layer structure described in the SEM is taken into consideration, which differs from ordinary least squares (OLS) regression.
de Rooij M, Karch JD, Fokkema M, Bakk Z, Pratiwi BC, and Kelderman H (2023). SEM-Based Out-of-Sample Predictions, Structural Equation Modeling: A Multidisciplinary Journal, 30:1, 132-148 <https://doi.org/10.1080/10705511.2022.2061494>
Grassi M, Palluzzi F, Tarantino B (2022). SEMgraph: An R Package for Causal Network Analysis of High-Throughput Data with Structural Equation Models. Bioinformatics, 38 (20), 4829–4830 <https://doi.org/10.1093/bioinformatics/btac567>
# load ALS data: graph, expression matrix (nonparanormal-transformed) and
# case/control group vector
ig <- alsData$graph
data <- alsData$exprs
data <- transformData(data)$data
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
group <- alsData$group
#...with train-test (0.5-0.5) samples
set.seed(123)
# sample.int() draws row indices directly; it yields the same indices as
# sample(1:nrow(data), ...) for the same seed, but avoids the 1:n and
# length-one sample() pitfalls
train <- sample.int(nrow(data), 0.5 * nrow(data))
# SEM fitting on the training rows (no group variable, i.e. fit = 0);
# alternative estimators:
# sem0 <- SEMrun(ig, data[train, ], algo = "lavaan", SE = "none")
# sem0 <- SEMrun(ig, data[train, ], algo = "ricf", n_rep = 0)
sem0 <- SEMrun(
  ig,
  data[train, ],
  SE = "none",
  limit = 1000
)
#> NLMINB solver ended normally after 1 iterations
#>
#> deviance/df: 6.263349 srmr: 0.3041398
#>
# predictors: source + mediator nodes; outcomes: mediator + sink nodes
res0 <- predict(sem0, newdata = data[-train, ])
print(res0$PE)
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.8973153 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166 0.6766548
#> 5530 5532 5533 5534 5535 842 1432 5600
#> 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749 0.5717884
#> 5603 6300 54205 5606 5608 596 4217 1616
#> 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833 1.1257768
# SEM fitting with the group variable (fit = 1); alternative estimators:
# sem1 <- SEMrun(ig, data[train, ], group[train], algo = "lavaan", SE = "none")
# sem1 <- SEMrun(ig, data[train, ], group[train], algo = "ricf", n_rep = 0)
sem1 <- SEMrun(
  ig,
  data[train, ],
  group[train],
  SE = "none",
  limit = 1000
)
#> NLMINB solver ended normally after 6 iterations
#>
#> deviance/df: 6.211221 srmr: 0.2861178
#>
# predictors: source + mediator nodes + group; outcomes: source + mediator +
# sink nodes. The group vector must be the first column of newdata.
res1 <- predict(sem1, newdata = cbind(group = group[-train], data[-train, ]))
print(res1$PE)
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.8963810 1.1258657 0.9073936 0.6861219 0.7001557 1.0070725 1.1045920 0.6652474
#> 5530 5532 5533 5534 5535 842 1432 5600
#> 0.8679917 0.7365461 1.2300132 0.6637344 0.6675348 0.7262147 0.8576185 0.6667256
#> 5603 6300 5630 54205 317 5606 5608 581
#> 1.1063933 0.6068862 1.0886300 1.0784325 0.7720406 0.8841224 0.8326926 1.0285427
#> 572 596 598 4217 6647 1616 7132 7133
#> 0.7379857 1.0832421 0.8036031 1.0731173 1.0383387 1.1300782 0.8195168 1.0913615
# \donttest{
#...with a binary outcome: group (1=case, 0=control) is recoded as an
# "outcome" node taking values 1 (case) and -1 (control)
ig1 <- mapGraph(ig, type = "outcome")
gplot(ig1)
outcome <- ifelse(group == 0, -1, 1)
table(outcome)
#> outcome
#> -1 1
#> 21 139
data1 <- cbind(outcome, data)
data1[1:5, 1:5]
#> outcome 207 208 10000 284
#> ALS2 1 -1.8273895 -0.45307006 -0.1360061 0.4530701
#> ALS3 1 -2.5616910 -0.96201413 0.3160400 0.6762093
#> ALS4 1 -0.8003346 0.82216031 -1.1521227 0.5613048
#> ALS5 1 -2.1342965 -0.98709115 1.1521227 0.5064807
#> ALS6 1 -2.0111279 0.02393297 0.5987578 0.1360061
sem10 <- SEMrun(ig1, data1[train, ], SE = "none", limit = 1000)
#> NLMINB solver ended normally after 1 iterations
#>
#> deviance/df: 6.122632 srmr: 0.3101125
#>
res10 <- predict(sem10, newdata = data1[-train, ], verbose = TRUE)
#> amse outcome 10452 84134 836 4747 4741 4744
#> 0.8828829 0.5509374 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166
#> 79139 5530 5532 5533 5534 5535 842 1432
#> 0.6766548 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749
#> 5600 5603 6300 54205 5606 5608 596 4217
#> 0.5717884 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833
#> 1616
#> 1.1257768
# compare observed test-set groups with the predicted outcome scores
yobs <- group[-train]
yhat <- res10$Yhat[, "outcome"]
benchmark(yobs, yhat)
#> ypred
#> yobs 0 1
#> 0 3 3
#> 1 14 60
#>
#> pre rec f1 mcc
#> 1 0.952381 0.8108108 0.8759124 0.2001211
#...with predictors = source nodes and outcomes = sink nodes only
ig2 <- mapGraph(ig, type = "source")
gplot(ig2)
sem02 <- SEMrun(ig2, data[train, ], SE = "none", limit = 1000)
#> NLMINB solver ended normally after 1 iterations
#>
#> deviance/df: 10.16978 srmr: 0.1295992
#>
res02 <- predict(sem02, newdata = data[-train, ], verbose = TRUE)
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.7550791 0.6146408 0.5114778 0.6960717 0.8871358 1.1343159 1.0664702 0.5043968
#> 5530 5532 5533 5534 5535
#> 0.6128182 0.6257037 1.1708051 0.6069758 0.6301368
# print(res02$PE)
#...with 10-iterations of 10-fold cross-validation samples
# For each fold, the SEM is fitted on the remaining (training) folds only and
# evaluated on the held-out fold, so no test row is used to fit the model.
# idx[[k]] is taken to hold the row indices of the k-th held-out fold
# (caret createFolds() convention).
n_rep <- 10
n_fold <- 10
pe <- vector("list", n_rep * n_fold)  # one PE vector per (repetition, fold)
for (r in seq_len(n_rep)) {
  set.seed(r)
  cat("rep = ", r, "\n")
  idx <- SEMdeep:::createFolds(y = data[, 1], k = n_fold)
  for (k in seq_len(n_fold)) {
    cat(" k-fold = ", k, "\n")
    held_out <- idx[[k]]
    semr <- SEMdeep:::quiet(
      SEMrun(ig, data[-held_out, ], SE = "none", limit = 1000)
    )
    resr <- predict(semr, newdata = data[held_out, ])
    pe[[(r - 1) * n_fold + k]] <- resr$PE
  }
}
# bind once at the end instead of growing a matrix with rbind() in the loop
res <- do.call(rbind, pe)
#> rep = 1
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 2
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 3
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 4
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 5
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 6
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 7
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 8
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 9
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
#> rep = 10
#> k-fold = 1
#> k-fold = 2
#> k-fold = 3
#> k-fold = 4
#> k-fold = 5
#> k-fold = 6
#> k-fold = 7
#> k-fold = 8
#> k-fold = 9
#> k-fold = 10
# average results: res has one row of prediction errors per fitted fold, so
# the column means give the cross-validated MSE per node (first column: AMSE)
colMeans(res)
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.8203834 0.9924008 0.9373265 0.7580843 0.5039775 0.7884417 0.8785680 0.7226705
#> 5530 5532 5533 5534 5535 842 1432 5600
#> 0.8494899 0.7238154 0.9889105 0.8135880 0.7818244 0.5714233 0.7531485 0.8071404
#> 5603 6300 54205 5606 5608 596 4217 1616
#> 0.9521244 0.7174839 0.7012493 0.9636981 0.7552133 0.9749872 0.9882541 0.9449986
# }