The function converts a graph into a collection of nodewise models: each mediator or sink variable is expressed as a function of its parents. Depending on the assumed type of relationship (linear or non-linear), SEMml() fits an ML model to each node (variable) with non-zero incoming connectivity. Model fitting proceeds equation by equation, for r = 1, ..., R, where R is the number of mediator and sink nodes.
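To make the notion of mediator and sink nodes concrete, the sketch below classifies the nodes of a small hypothetical DAG and counts R, the number of nodes that would receive a model. The DAG is given as a base-R edge list rather than an igraph object; the node names and this representation are illustrative assumptions, not part of SEMml().

```r
# Hypothetical toy DAG as a two-column edge list (parent -> child)
edges <- rbind(c("x1", "x2"),
               c("x1", "x3"),
               c("x2", "x3"),
               c("x3", "x4"))
nodes <- sort(unique(c(edges)))

# In-degree (number of parents) and out-degree (number of children) per node
indeg  <- sapply(nodes, function(v) sum(edges[, 2] == v))
outdeg <- sapply(nodes, function(v) sum(edges[, 1] == v))

# Sources have no parents; mediators have parents and children;
# sinks have parents but no children.
role <- ifelse(indeg == 0, "source",
               ifelse(outdeg > 0, "mediator", "sink"))

# Only mediators and sinks get a nodewise model, so R is their count
R <- sum(role != "source")
print(role)
print(R)
```

With an igraph object, the same counts could be taken from degree(g, mode = "in") and degree(g, mode = "out").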

SEMml(
  graph,
  data,
  train = NULL,
  algo = "sem",
  vimp = FALSE,
  thr = NULL,
  verbose = FALSE,
  ...
)



graph: An igraph object.


data: A matrix with rows corresponding to subjects and columns corresponding to graph nodes (variables).


train: A numeric vector of row indices identifying the training subset (default = NULL).
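A training subset can be drawn by sampling row indices, as in the example below; fixing the random seed, which the example does not do, makes the split reproducible. This is a base-R sketch only, and n = 160 simply mirrors the number of subjects in the ALS data.

```r
set.seed(1)                                   # fix the RNG for a reproducible split
n     <- 160                                  # subjects, as in dim(alsData$exprs)
train <- sample(seq_len(n), size = 0.5 * n)   # row indices of the training subset
test  <- setdiff(seq_len(n), train)           # held-out rows, i.e. data[-train, ]
```

The held-out rows are the ones the example later passes to predict() as data[-train, ].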


algo: The ML method used for nodewise network predictions. Six algorithms can be specified:

  • algo="sem" (default) for a linear SEM, see SEMrun.

  • algo="gam" for a generalized additive model, see gam.

  • algo="rf" for a random forest model, see ranger.

  • algo="xgb" for an XGBoost model, see xgboost.

  • algo="nn" for a small neural network model (1 hidden layer and 10 nodes), see nnet.

  • algo="dnn" for a large neural network model (1 hidden layer and 1000 nodes), see dnn.
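The difference between the linear "sem" option and the additive "gam" option can be pictured at the level of a single node's formula. The sketch below builds both with base R's reformulate(); the helper node_formula() is hypothetical, and the one-smooth-per-parent s() specification is an assumption about how a GAM might be set up, not a description of SEMml()'s internals.

```r
# Hypothetical helper: formula for one child node given its parents.
# "sem": linear terms; "gam": one smooth term s() per parent (mgcv-style syntax).
node_formula <- function(child, parents, algo = c("sem", "gam")) {
  algo  <- match.arg(algo)
  terms <- if (algo == "gam") paste0("s(", parents, ")") else parents
  reformulate(terms, response = child)
}

f_sem <- node_formula("x3", c("x1", "x2"), algo = "sem")
f_gam <- node_formula("x3", c("x1", "x2"), algo = "gam")
cat(deparse(f_sem), "\n")  # x3 ~ x1 + x2
cat(deparse(f_gam), "\n")  # x3 ~ s(x1) + s(x2)
```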


vimp: A logical value (default = FALSE). If TRUE, the variable importance is computed using: (i) the squared t-statistic or F-statistic of the model parameters for "sem" or "gam"; (ii) the variable importance from the importance or xgb.importance functions for "rf" or "xgb"; (iii) Olden's connection weights for "nn" or "dnn".
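For the "sem" option, the importance of each parent is the squared t-statistic of its coefficient. The sketch below computes that quantity for one simulated node, using base R's lm() as a stand-in for the SEM fit; the data and coefficients are invented for illustration.

```r
set.seed(42)
n  <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- 0.8 * x1 + 0.1 * x2 + rnorm(n)   # x1 is a strong parent, x2 a weak one

fit  <- lm(x3 ~ x1 + x2)                           # linear nodewise model for x3
tval <- summary(fit)$coefficients[-1, "t value"]   # drop the intercept row
vimp <- tval^2                                     # squared t-statistic per parent
print(round(vimp, 3))
```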


thr: A numeric value giving the threshold applied to the variable importance to color the graph. If thr = NULL (default), the threshold is set to thr = abs(mean(vimp)).
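The default threshold and the resulting edge colors can be reproduced on a toy vector. The importance values below are hypothetical (negative values can arise, for instance, from Olden's weights for "nn"); the color names gray50, red2 and royalblue3 are the ones that appear in the example output, and the abs(vimp) > thr rule follows the description of the returned graph.

```r
# Hypothetical importance values for four edges
vimp <- c(e1 = 0.5, e2 = -3, e3 = 2, e4 = 4)

thr <- abs(mean(vimp))   # default rule when thr = NULL: 3.5 / 4 = 0.875

# Edges above the threshold in absolute value are highlighted by sign
color <- ifelse(abs(vimp) > thr,
                ifelse(vimp > 0, "red2", "royalblue3"),
                "gray50")
print(thr)
print(color)
```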


verbose: A logical value. If FALSE (default), the processed graph is not plotted to the screen.


...: Currently ignored.


An S3 object of class "ML" is returned. It is a list of 5 objects:

  1. "fit", a list containing the estimated covariance matrix (Sigma), the estimated model errors (Psi), the fitting indices (fitIdx), and, if vimp = TRUE, the variable importance values (parameterEstimates).

  2. "Yhat", a matrix of predictions of sink and mediator graph nodes.

  3. "model", a list of all the fitted nodewise models (sem, gam, rf, xgb, nn or dnn).

  4. "graph", the induced DAG of the input graph mapped onto the data variables. If vimp = TRUE, the DAG edges are colored by the variable importance measure: edges with abs(vimp) > thr are highlighted in red (vimp > 0) or blue (vimp < 0).

  5. "data", the input training data subset, restricted to the variables mapped onto graph nodes.
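The fitting indices in fitIdx include an average mean squared error (amse) and its root (rmse); in the example output for the "sem" fit, rmse = 0.9037063 is indeed the square root of amse = 0.8166851. The sketch below assumes amse is the mean, over the predicted nodes, of each node's mean squared error between observed and predicted values; the toy matrices Y and Yhat are invented.

```r
# Toy observed values and predictions: two subjects (rows), two nodes (columns)
Y    <- matrix(c(1, 2, 3, 4), nrow = 2, dimnames = list(NULL, c("x2", "x3")))
Yhat <- matrix(c(1, 1, 3, 3), nrow = 2, dimnames = list(NULL, c("x2", "x3")))

mse  <- colMeans((Y - Yhat)^2)  # per-node mean squared error
amse <- mean(mse)               # averaged over the predicted nodes
rmse <- sqrt(amse)              # root of the average MSE
print(c(amse = amse, rmse = rmse))
```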


By mapping data onto the input graph, SEMml() creates a set of nodewise models from the directed links, i.e., edges pointing from a parent node to a child node, between pairs of nodes in the input graph that are causally relevant to each other. Each mediator or sink variable can then be characterized as a function of its parents. An ML model (sem, gam, rf, xgb, nn, dnn) is fitted to each variable with non-zero inbound connectivity, according to the assumed kind of relationship (linear or non-linear). With R denoting the number of mediator and sink nodes in the network, model fitting is carried out equation by equation, for r = 1, ..., R.
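The equation-by-equation procedure can be sketched in base R for a toy three-node DAG (x1 -> x2, x1 -> x3, x2 -> x3). Ordinary least squares via lm() stands in for every nodewise model, which corresponds most closely to the linear "sem" option; the simulated data, the parent-list representation and all names are assumptions made for illustration.

```r
set.seed(123)
n  <- 100
# Simulated data from the toy DAG
x1 <- rnorm(n)
x2 <- 0.7 * x1 + rnorm(n, sd = 0.5)
x3 <- 0.4 * x1 - 0.6 * x2 + rnorm(n, sd = 0.5)
dat <- data.frame(x1, x2, x3)

# Parent sets: only nodes with non-zero in-degree (mediators and sinks) appear
parents <- list(x2 = "x1", x3 = c("x1", "x2"))

# Fit one linear model per node, equation by equation (r = 1, ..., R)
models <- lapply(names(parents), function(node) {
  lm(reformulate(parents[[node]], response = node), data = dat)
})
names(models) <- names(parents)

# Collect fitted values into a Yhat-like matrix (one column per modelled node)
Yhat <- sapply(models, fitted)
```

Swapping lm() for a nonlinear learner would change only the body of the lapply() call; the loop over parent sets stays the same.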


Mario Grassi


# \donttest{
# Load Amyotrophic Lateral Sclerosis (ALS) data
data<- alsData$exprs; dim(data)
#> [1] 160 318
data<- transformData(data)$data
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
group<- alsData$group; table (group)
#> group
#>   0   1 
#>  21 139 
ig<- alsData$graph; gplot(ig)

#...with train-test (0.5-0.5) samples
train<- sample(1:nrow(data), 0.5*nrow(data))

start<- Sys.time()
# ... rf
#res1<- SEMml(ig, data, train, algo="rf", vimp=FALSE)
res1<- SEMml(ig, data, train, algo="rf", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#>  RF solver ended normally after 23 iterations 
#>  logL: -33.22824  srmr: 0.0859545 

# ... xgb
#res2<- SEMml(ig, data, train, algo="xgb", vimp=FALSE)
res2<- SEMml(ig, data, train, algo="xgb", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#>  XGB solver ended normally after 23 iterations 
#>  logL: 70.10035  srmr: 0.0014393 

# ... nn
#res3<- SEMml(ig, data, train, algo="nn", vimp=FALSE)
res3<- SEMml(ig, data, train, algo="nn", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#>  NN solver ended normally after 23 iterations 
#>  logL: -37.48083  srmr: 0.1987503 

# ... gam
#res4<- SEMml(ig, data, train, algo="gam", vimp=FALSE)
res4<- SEMml(ig, data, train, algo="gam", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#>  GAM solver ended normally after 23 iterations 
#>  logL: -46.77283  srmr: 0.3819281 
end<- Sys.time(); print(end - start)
#> Time difference of 9.565145 secs

# ... sem
#res5<- SEMml(ig, data, train, algo="sem", vimp=FALSE)
res5<- SEMml(ig, data, train, algo="sem", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#>  SEM solver ended normally after 23 iterations 
#>  logL: -47.90891  srmr: 0.3040025 

#str(res5, max.level=2)
res5$fit$fitIdx
#>        logL        amse        rmse   srmr.srmr 
#> -47.9089078   0.8166851   0.9037063   0.3040025 
res5$fit$parameterEstimates
#>          lhs op   rhs varImp
#> 1      10452  ~  6647  0.164
#> z5606   1432  ~  5606  9.823
#> z5608   1432  ~  5608 29.897
#> z7132   1616  ~  7132  2.654
#> z7133   1616  ~  7133  0.132
#> 11      4217  ~  1616  2.781
#> z1432   4741  ~  1432  7.768
#> z5600   4741  ~  5600  0.554
#> z5603   4741  ~  5603  0.745
#> z6300   4741  ~  6300  0.801
#> z5630   4741  ~  5630  3.339
#> z14321  4744  ~  1432  5.927
#> z56001  4744  ~  5600  0.563
#> z56031  4744  ~  5603  0.351
#> z63001  4744  ~  6300  0.347
#> z56301  4744  ~  5630  4.134
#> z6647   4747  ~  6647 25.892
#> z14322  4747  ~  1432  8.210
#> z56002  4747  ~  5600  0.050
#> z56032  4747  ~  5603  4.131
#> z63002  4747  ~  6300  0.973
#> z56302  4747  ~  5630  1.506
#> z581   54205  ~   581 22.353
#> z572   54205  ~   572  0.029
#> z596   54205  ~   596  0.080
#> z598   54205  ~   598  2.735
#> 12      5530  ~  6647  9.332
#> 13      5532  ~  6647 17.839
#> 14      5533  ~  6647  6.467
#> 15      5534  ~  6647  9.491
#> 16      5535  ~  6647 16.447
#> z56061  5600  ~  5606  2.159
#> z56081  5600  ~  5608  5.425
#> z56062  5603  ~  5606  4.948
#> z56082  5603  ~  5608  0.037
#> 17      5606  ~  4217  0.669
#> 18      5608  ~  4217 34.179
#> 19       596  ~  6647  0.416
#> z56063  6300  ~  5606  1.325
#> z56083  6300  ~  5608 14.431
#> 110    79139  ~  6647 19.233
#> 111      836  ~   842 22.853
#> 112    84134  ~  6647  7.515
#> z54205   842  ~ 54205 41.336
#> z317     842  ~   317  6.901

#Comparison of AMSE (in train data)
rf <- res1$fit$fitIdx[2];rf
#>      amse 
#> 0.2327284 
xgb<- res2$fit$fitIdx[2];xgb
#>         amse 
#> 0.0001865892 
nn <- res3$fit$fitIdx[2];nn
#>      amse 
#> 0.4966564 
gam<- res4$fit$fitIdx[2];gam
#>      amse 
#> 0.7525285 
sem<- res5$fit$fitIdx[2];sem
#>      amse 
#> 0.8166851 

#Comparison of SRMR (in train data)
rf <- res1$fit$fitIdx[4];rf
#>       srmr 
#> 0.08595454 
xgb<- res2$fit$fitIdx[4];xgb
#>        srmr 
#> 0.001439282 
nn <- res3$fit$fitIdx[4];nn
#>      srmr 
#> 0.1987503 
gam<- res4$fit$fitIdx[4];gam
#>      srmr 
#> 0.3819281 
sem<- res5$fit$fitIdx[4];sem
#> srmr.srmr 
#> 0.3040025 

#Comparison of VIMP (in train data)
table(E(res1$graph)$color) #rf
#> gray50   red2 
#>     28     17 
table(E(res2$graph)$color) #xgb
#> gray50   red2 
#>     26     19 
table(E(res3$graph)$color) #nn
#>     gray50       red2 royalblue3 
#>         31          4         10 
table(E(res4$graph)$color) #gam
#> gray50   red2 
#>     31     14 
table(E(res5$graph)$color) #sem
#> gray50   red2 
#>     31     14 

#Comparison of AMSE (in test data)
print(predict(res1, data[-train, ])$PE[1]) #rf
#>     amse 
#> 1.108783 
print(predict(res2, data[-train, ])$PE[1]) #xgb
#>     amse 
#> 1.548598 
print(predict(res3, data[-train, ])$PE[1]) #nn
#>     amse 
#> 1.643303 
print(predict(res4, data[-train, ])$PE[1]) #gam
#>     amse 
#> 0.912187 
print(predict(res5, data[-train, ])$PE[1]) #sem
#>      amse 
#> 0.8973153 

#...with a binary outcome (1=case, 0=control)

ig1<- mapGraph(ig, type="outcome"); gplot(ig1)

outcome<- ifelse(group == 0, -1, 1); table(outcome)
#> outcome
#>  -1   1 
#>  21 139 
data1<- cbind(outcome, data); data1[1:5,1:5]
#>      outcome        207         208      10000       284
#> ALS2       1 -1.8273895 -0.45307006 -0.1360061 0.4530701
#> ALS3       1 -2.5616910 -0.96201413  0.3160400 0.6762093
#> ALS4       1 -0.8003346  0.82216031 -1.1521227 0.5613048
#> ALS5       1 -2.1342965 -0.98709115  1.1521227 0.5064807
#> ALS6       1 -2.0111279  0.02393297  0.5987578 0.1360061

res6 <- SEMml(ig1, data1, train, algo="nn", vimp=TRUE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 24 : zoutcome 
#>  NN solver ended normally after 24 iterations 
#>  logL: -35.20595  srmr: 0.2007807 

table(E(res6$graph)$color) #nn with outcome
#>     gray50       red2 royalblue3 
#>         43          7          7 

mse6 <- predict(res6, data1[-train, ])
yobs <- group[-train]
yhat <- mse6$Yhat[ ,"outcome"]
benchmark(yobs, yhat, thr=0, F1=TRUE)
#>     ypred
#> yobs  0  1
#>    0  4  2
#>    1 18 56
#>         pre       rec        f1       mcc
#> 1 0.9655172 0.7567568 0.8484848 0.2497704
benchmark(yobs, yhat, thr=0, F1=FALSE)
#>     ypred
#> yobs  0  1
#>    0  4  2
#>    1 18 56
#>          sp        se  acc       mcc
#> 1 0.6666667 0.7567568 0.75 0.2497704
# }
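The metrics printed by benchmark() can be checked by hand from its confusion matrix (rows: observed, columns: predicted). The sketch below applies the standard definitions of precision, recall (sensitivity), F1, specificity, accuracy and the Matthews correlation coefficient to the counts shown in the example; it reproduces the printed values, but it is not benchmark()'s source code.

```r
# Confusion counts read off the benchmark() output
# (rows = observed yobs, columns = predicted ypred)
TN <- 4;  FP <- 2
FN <- 18; TP <- 56

pre <- TP / (TP + FP)                    # precision
rec <- TP / (TP + FN)                    # recall = sensitivity (se)
f1  <- 2 * pre * rec / (pre + rec)       # harmonic mean of pre and rec
sp  <- TN / (TN + FP)                    # specificity
acc <- (TP + TN) / (TP + TN + FP + FN)   # accuracy
mcc <- (TP * TN - FP * FN) /
  sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))

round(c(pre = pre, rec = rec, f1 = f1, sp = sp, acc = acc, mcc = mcc), 7)
```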