The function converts a graph into a collection of nodewise models:
each mediator or sink variable is expressed as a function of its
parents. Based on the assumed type of relationship (linear or
non-linear), SEMml() fits a machine learning (ML) model to each node
(variable) with non-zero incoming connectivity. The model fitting is
repeated equation-by-equation (r = 1, ..., R), where R is the number
of mediator and sink nodes.
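In symbols, the nodewise decomposition can be written as below. The notation is introduced here for illustration and is not taken from the package: pa(r) denotes the parents of node r, and f_r is a linear combination of the parents in the linear case (algo = "sem"), or a non-linear learner (smooth terms, trees, or a neural network) for the other algorithms.

```latex
\begin{aligned}
Y_r &= f_r\big(Y_{\mathrm{pa}(r)}\big) + \varepsilon_r, \qquad r = 1, \dots, R,\\
f_r\big(Y_{\mathrm{pa}(r)}\big) &= \sum_{j \in \mathrm{pa}(r)} \beta_{rj}\, Y_j \qquad \text{(linear case)}.
\end{aligned}
```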
SEMml(
graph,
data,
train = NULL,
algo = "sem",
vimp = FALSE,
thr = NULL,
verbose = FALSE,
...
)
graph: An igraph object.
data: A matrix with rows corresponding to subjects, and columns to graph nodes (variables).
train: A numeric vector specifying the row indices corresponding to the training dataset (default = NULL).
algo: ML method used for nodewise network predictions. Six algorithms can be specified:
  algo="sem" (default) for a linear SEM, see SEMrun.
  algo="gam" for a generalized additive model, see gam.
  algo="rf" for a random forest model, see ranger.
  algo="xgb" for an XGBoost model, see xgboost.
  algo="nn" for a small neural network model (1 hidden layer and 10 nodes), see nnet.
  algo="dnn" for a large neural network model (1 hidden layer and 1000 nodes), see dnn.
vimp: A logical value (default = FALSE). If TRUE, compute the variable importance, considering: (i) the squared value of the t-statistic or F-statistic of the model parameters for "sem" or "gam"; (ii) the variable importance from the importance or xgb.importance functions for "rf" or "xgb"; (iii) Olden's connection weights for "nn" or "dnn".
thr: A numerical value indicating the threshold to apply to the variable importance to color the graph. If thr = NULL (default), the threshold is set to thr = abs(mean(vimp)).
verbose: A logical value. If FALSE (default), the processed graph will not be plotted to screen.
...: Currently ignored.
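The vimp and thr arguments above can be illustrated with a small base R sketch. It is not SEMgraph code: color_edges() is a hypothetical helper that applies the documented rule (edges with abs(vimp) > thr are red if vimp > 0, blue if vimp < 0, with thr = abs(mean(vimp)) by default), and the color names are assumed from the example output below. The first part computes the squared t-statistic importance described for the linear case on toy data with lm(). Because squared t-statistics are never negative, "sem" importance can only produce red or gray edges, which agrees with the example tables, where only the "nn" fits show royalblue3 edges.

```r
# Hypothetical helper, not part of SEMgraph: documented coloring rule
# "abs(vimp) > thr -> red (vimp > 0) or blue (vimp < 0)", with
# thr = abs(mean(vimp)) when thr is NULL. Color names assumed from the
# example output (gray50, red2, royalblue3).
color_edges <- function(vimp, thr = NULL) {
  if (is.null(thr)) thr <- abs(mean(vimp))
  ifelse(abs(vimp) > thr,
         ifelse(vimp > 0, "red2", "royalblue3"),
         "gray50")
}

# (i) Linear equation y ~ x1 + x2 on toy data: importance of each incoming
# edge as the squared t-statistic of the parent's coefficient.
set.seed(1)
n  <- 100
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- 0.8 * x1 - 0.1 * x2 + rnorm(n, sd = 0.5)
tval <- summary(lm(y ~ x1 + x2))$coefficients[c("x1", "x2"), "t value"]
vimp_sem <- tval^2
color_edges(vimp_sem)   # squared t-statistics: red or gray only

# (iii) A signed importance vector (connection weights can be negative)
# may produce all three colors.
color_edges(c(a = 2.5, b = -3.0, c = 0.05))
```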
An S3 object of class "ML" is returned. It is a list of 5 objects:
"fit", a list of ML model objects, including: the estimated covariance matrix (Sigma), the estimated model errors (Psi), the fitting indices (fitIdx), and the variable importance values (parameterEstimates), if vimp = TRUE.
"Yhat", a matrix of predictions of sink and mediator graph nodes.
"model", a list of all the fitted nodewise models (sem, gam, rf, xgb, nn or dnn).
"graph", the induced DAG of the input graph mapped onto data variables. If vimp = TRUE, the DAG edges are colored based on the variable importance measure: edges with abs(vimp) > thr are highlighted in red (vimp > 0) or blue (vimp < 0).
"data", the input training data subset mapped onto graph nodes.
By mapping data onto the input graph, SEMml() creates a set of
nodewise models based on the directed links, i.e., edges pointing
in the same direction, between pairs of nodes in the input graph
that are causally relevant to each other. Each mediator or sink
variable can then be characterized as a function of its parents.
An ML model (sem, gam, rf, xgb, nn, dnn) is fitted to each variable
with non-zero inbound connectivity, taking into account the kind of
relationship (linear or non-linear). With R denoting the number of
mediator and sink nodes in the network, model fitting is performed
equation-by-equation (r = 1, ..., R).
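The equation-by-equation procedure can be sketched in base R for the linear case. This is a hypothetical illustration, not SEMgraph's implementation: the graph is assumed to be a 0/1 adjacency matrix A (A[i, j] = 1 for an edge i -> j) rather than an igraph object, every node with incoming edges is regressed on its parents with stats::lm(), and node names are assumed to be syntactically valid R names so they can appear in formulas. The last lines compute an average MSE over the predicted nodes and its square root, assumed here to mirror the amse and rmse fit indices; in the example output, rmse 0.9037063 is indeed the square root of amse 0.8166851, but SEMgraph may compute amse differently (e.g., on transformed data).

```r
# Hypothetical sketch, not SEMgraph's implementation: nodewise linear
# fitting with stats::lm(). A[i, j] = 1 encodes a directed edge i -> j.
fit_nodewise <- function(A, data) {
  stopifnot(identical(colnames(A), colnames(data)))
  targets <- colnames(A)[colSums(A) > 0]       # mediators and sinks (R of them)
  models <- list()
  for (node in targets) {                      # r = 1, ..., R
    parents <- rownames(A)[A[, node] == 1]
    f <- reformulate(parents, response = node) # e.g. y ~ x + m
    models[[node]] <- lm(f, data = as.data.frame(data))
  }
  models
}

# Toy DAG: x -> m, m -> y, x -> y (x is a source, m a mediator, y a sink)
nodes <- c("x", "m", "y")
A <- matrix(0, 3, 3, dimnames = list(nodes, nodes))
A["x", "m"] <- 1
A["m", "y"] <- 1
A["x", "y"] <- 1

set.seed(42)
n <- 200
x <- rnorm(n)
m <- 0.7 * x + rnorm(n, sd = 0.3)
y <- 0.5 * m - 0.4 * x + rnorm(n, sd = 0.3)
dat <- cbind(x = x, m = m, y = y)

fits <- fit_nodewise(A, dat)
names(fits)                    # "m" "y": one model per node with incoming edges

# Predictions for mediator and sink nodes, and an assumed amse/rmse analogue
Yhat <- sapply(fits, fitted)
amse <- mean(colMeans((dat[, colnames(Yhat)] - Yhat)^2))
c(amse = amse, rmse = sqrt(amse))
```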
Grassi M, Palluzzi F, Tarantino B (2022). SEMgraph: An R Package for Causal Network Analysis of High-Throughput Data with Structural Equation Models. Bioinformatics, 38 (20), 4829–4830 <https://doi.org/10.1093/bioinformatics/btac567>
Hastie T, Tibshirani R (1990). Generalized Additive Models. London: Chapman and Hall.
Breiman L (2001). Random Forests. Machine Learning, 45(1), 5–32.
Chen T, Guestrin C (2016). XGBoost: A Scalable Tree Boosting System. Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining.
Ripley BD (1996). Pattern Recognition and Neural Networks. Cambridge.
Redell N (2019). Shapley Decomposition of R-Squared in Machine Learning Models. arXiv: Methodology.
# \donttest{
library(SEMgraph)
library(igraph)  # for E(), used below to inspect edge colors
# Load the Amyotrophic Lateral Sclerosis (ALS) data
data<- alsData$exprs; dim(data)
#> [1] 160 318
data<- transformData(data)$data
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
group<- alsData$group; table(group)
#> group
#> 0 1
#> 21 139
ig<- alsData$graph; gplot(ig)
#...with train-test (0.5-0.5) samples
set.seed(123)
train<- sample(1:nrow(data), 0.5*nrow(data))
start<- Sys.time()
# ... rf
#res1<- SEMml(ig, data, train, algo="rf", vimp=FALSE)
res1<- SEMml(ig, data, train, algo="rf", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#>
#> RF solver ended normally after 23 iterations
#>
#> logL: -33.22824 srmr: 0.0859545
#>
# ... xgb
#res2<- SEMml(ig, data, train, algo="xgb", vimp=FALSE)
res2<- SEMml(ig, data, train, algo="xgb", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#>
#> XGB solver ended normally after 23 iterations
#>
#> logL: 70.10035 srmr: 0.0014393
#>
# ... nn
#res3<- SEMml(ig, data, train, algo="nn", vimp=FALSE)
res3<- SEMml(ig, data, train, algo="nn", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#>
#> NN solver ended normally after 23 iterations
#>
#> logL: -37.48083 srmr: 0.1987503
#>
# ... gam
#res4<- SEMml(ig, data, train, algo="gam", vimp=FALSE)
res4<- SEMml(ig, data, train, algo="gam", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#>
#> GAM solver ended normally after 23 iterations
#>
#> logL: -46.77283 srmr: 0.3819281
#>
end<- Sys.time()
print(end-start)
#> Time difference of 9.565145 secs
# ... sem
#res5<- SEMml(ig, data, train, algo="sem", vimp=FALSE)
res5<- SEMml(ig, data, train, algo="sem", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#>
#> SEM solver ended normally after 23 iterations
#>
#> logL: -47.90891 srmr: 0.3040025
#>
#str(res5, max.level=2)
res5$fit$fitIdx
#> logL amse rmse srmr.srmr
#> -47.9089078 0.8166851 0.9037063 0.3040025
res5$fit$parameterEstimates
#> lhs op rhs varImp
#> 1 10452 ~ 6647 0.164
#> z5606 1432 ~ 5606 9.823
#> z5608 1432 ~ 5608 29.897
#> z7132 1616 ~ 7132 2.654
#> z7133 1616 ~ 7133 0.132
#> 11 4217 ~ 1616 2.781
#> z1432 4741 ~ 1432 7.768
#> z5600 4741 ~ 5600 0.554
#> z5603 4741 ~ 5603 0.745
#> z6300 4741 ~ 6300 0.801
#> z5630 4741 ~ 5630 3.339
#> z14321 4744 ~ 1432 5.927
#> z56001 4744 ~ 5600 0.563
#> z56031 4744 ~ 5603 0.351
#> z63001 4744 ~ 6300 0.347
#> z56301 4744 ~ 5630 4.134
#> z6647 4747 ~ 6647 25.892
#> z14322 4747 ~ 1432 8.210
#> z56002 4747 ~ 5600 0.050
#> z56032 4747 ~ 5603 4.131
#> z63002 4747 ~ 6300 0.973
#> z56302 4747 ~ 5630 1.506
#> z581 54205 ~ 581 22.353
#> z572 54205 ~ 572 0.029
#> z596 54205 ~ 596 0.080
#> z598 54205 ~ 598 2.735
#> 12 5530 ~ 6647 9.332
#> 13 5532 ~ 6647 17.839
#> 14 5533 ~ 6647 6.467
#> 15 5534 ~ 6647 9.491
#> 16 5535 ~ 6647 16.447
#> z56061 5600 ~ 5606 2.159
#> z56081 5600 ~ 5608 5.425
#> z56062 5603 ~ 5606 4.948
#> z56082 5603 ~ 5608 0.037
#> 17 5606 ~ 4217 0.669
#> 18 5608 ~ 4217 34.179
#> 19 596 ~ 6647 0.416
#> z56063 6300 ~ 5606 1.325
#> z56083 6300 ~ 5608 14.431
#> 110 79139 ~ 6647 19.233
#> 111 836 ~ 842 22.853
#> 112 84134 ~ 6647 7.515
#> z54205 842 ~ 54205 41.336
#> z317 842 ~ 317 6.901
gplot(res5$graph)
#Comparison of AMSE (in train data)
rf <- res1$fit$fitIdx[2];rf
#> amse
#> 0.2327284
xgb<- res2$fit$fitIdx[2];xgb
#> amse
#> 0.0001865892
nn <- res3$fit$fitIdx[2];nn
#> amse
#> 0.4966564
gam<- res4$fit$fitIdx[2];gam
#> amse
#> 0.7525285
sem<- res5$fit$fitIdx[2];sem
#> amse
#> 0.8166851
#Comparison of SRMR (in train data)
rf <- res1$fit$fitIdx[4];rf
#> srmr
#> 0.08595454
xgb<- res2$fit$fitIdx[4];xgb
#> srmr
#> 0.001439282
nn <- res3$fit$fitIdx[4];nn
#> srmr
#> 0.1987503
gam<- res4$fit$fitIdx[4];gam
#> srmr
#> 0.3819281
sem<- res5$fit$fitIdx[4];sem
#> srmr.srmr
#> 0.3040025
#Comparison of VIMP (in train data)
table(E(res1$graph)$color) #rf
#>
#> gray50 red2
#> 28 17
table(E(res2$graph)$color) #xgb
#>
#> gray50 red2
#> 26 19
table(E(res3$graph)$color) #nn
#>
#> gray50 red2 royalblue3
#> 31 4 10
table(E(res4$graph)$color) #gam
#>
#> gray50 red2
#> 31 14
table(E(res5$graph)$color) #sem
#>
#> gray50 red2
#> 31 14
#Comparison of AMSE (in test data)
print(predict(res1, data[-train, ])$PE[1]) #rf
#> amse
#> 1.108783
print(predict(res2, data[-train, ])$PE[1]) #xgb
#> amse
#> 1.548598
print(predict(res3, data[-train, ])$PE[1]) #nn
#> amse
#> 1.643303
print(predict(res4, data[-train, ])$PE[1]) #gam
#> amse
#> 0.912187
print(predict(res5, data[-train, ])$PE[1]) #sem
#> amse
#> 0.8973153
#...with a binary outcome (1=case, 0=control)
ig1<- mapGraph(ig, type="outcome"); gplot(ig1)
outcome<- ifelse(group == 0, -1, 1); table(outcome)
#> outcome
#> -1 1
#> 21 139
data1<- cbind(outcome, data); data1[1:5,1:5]
#> outcome 207 208 10000 284
#> ALS2 1 -1.8273895 -0.45307006 -0.1360061 0.4530701
#> ALS3 1 -2.5616910 -0.96201413 0.3160400 0.6762093
#> ALS4 1 -0.8003346 0.82216031 -1.1521227 0.5613048
#> ALS5 1 -2.1342965 -0.98709115 1.1521227 0.5064807
#> ALS6 1 -2.0111279 0.02393297 0.5987578 0.1360061
res6 <- SEMml(ig1, data1, train, algo="nn", vimp=TRUE)
#> 1 : z10452
#> 2 : z1432
#> 3 : z1616
#> 4 : z4217
#> 5 : z4741
#> 6 : z4744
#> 7 : z4747
#> 8 : z54205
#> 9 : z5530
#> 10 : z5532
#> 11 : z5533
#> 12 : z5534
#> 13 : z5535
#> 14 : z5600
#> 15 : z5603
#> 16 : z5606
#> 17 : z5608
#> 18 : z596
#> 19 : z6300
#> 20 : z79139
#> 21 : z836
#> 22 : z84134
#> 23 : z842
#> 24 : zoutcome
#>
#> NN solver ended normally after 24 iterations
#>
#> logL: -35.20595 srmr: 0.2007807
#>
gplot(res6$graph)
table(E(res6$graph)$color)
#>
#> gray50 red2 royalblue3
#> 43 7 7
mse6 <- predict(res6, data1[-train, ])
yobs <- group[-train]
yhat <- mse6$Yhat[ ,"outcome"]
benchmark(yobs, yhat, thr=0, F1=TRUE)
#> ypred
#> yobs 0 1
#> 0 4 2
#> 1 18 56
#>
#> pre rec f1 mcc
#> 1 0.9655172 0.7567568 0.8484848 0.2497704
benchmark(yobs, yhat, thr=0, F1=FALSE)
#> ypred
#> yobs 0 1
#> 0 4 2
#> 1 18 56
#>
#> sp se acc mcc
#> 1 0.6666667 0.7567568 0.75 0.2497704
# }
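The metrics printed by benchmark() above can be recomputed by hand from its 2 x 2 table with the standard definitions, treating class 1 as positive and predicting class 1 when yhat > thr. The sketch below uses a hypothetical helper, binary_metrics(), which is not SEMgraph's benchmark(); the toy vectors are rebuilt only from the counts in the printed table (4, 2, 18, 56). That it reproduces every printed value suggests, but does not prove, that benchmark() uses the same definitions.

```r
# Hypothetical helper, not SEMgraph's benchmark(): standard binary
# classification metrics, class 1 positive, predict 1 when yhat > thr.
binary_metrics <- function(yobs, yhat, thr = 0) {
  ypred <- as.integer(yhat > thr)
  tp <- as.numeric(sum(yobs == 1 & ypred == 1))
  fn <- as.numeric(sum(yobs == 1 & ypred == 0))
  fp <- as.numeric(sum(yobs == 0 & ypred == 1))
  tn <- as.numeric(sum(yobs == 0 & ypred == 0))
  pre <- tp / (tp + fp)            # precision
  rec <- tp / (tp + fn)            # recall, i.e. sensitivity (se)
  c(pre = pre,
    rec = rec,
    f1  = 2 * pre * rec / (pre + rec),
    sp  = tn / (tn + fp),          # specificity
    acc = (tp + tn) / (tp + tn + fp + fn),
    mcc = (tp * tn - fp * fn) /
          sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
}

# Toy vectors with the same counts as the printed table:
# yobs = 0: 4 predicted 0, 2 predicted 1; yobs = 1: 18 predicted 0, 56 predicted 1
yobs_toy <- c(rep(0, 4), rep(0, 2), rep(1, 18), rep(1, 56))
yhat_toy <- c(rep(-1, 4), rep(1, 2), rep(-1, 18), rep(1, 56))
round(binary_metrics(yobs_toy, yhat_toy, thr = 0), 7)
# pre 0.9655172, rec 0.7567568, f1 0.8484848, sp 0.6666667, acc 0.75, mcc 0.2497704
```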