Predict method for ML objects.

# S3 method for class 'ML'
predict(object, newdata, verbose = FALSE, ...)

Arguments

object

A fitted model object returned by the SEMml() function.

newdata

A matrix of new data, with rows corresponding to subjects and columns to variables.

verbose

Print predicted out-of-sample MSE values (default = FALSE).

...

Currently ignored.

Value

A list of 2 objects:

  1. "PE", a vector of prediction errors, given as the Mean Squared Error (MSE) of the out-of-sample predictions for each node. The first value of PE is the AMSE, i.e., the MSE averaged over all (sink and mediator) graph nodes.

  2. "Yhat", the matrix of continuous predicted values of the graph nodes (excluding source nodes), computed on the out-of-sample data.

Author

Mario Grassi mario.grassi@unipv.it

Examples


# \donttest{
# Load Amyotrophic Lateral Sclerosis (ALS)
data<- alsData$exprs; dim(data)
#> [1] 160 318
data<- transformData(data)$data
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
ig<- alsData$graph; gplot(ig)


#...with train-test (0.5-0.5) samples
set.seed(123)
train<- sample(1:nrow(data), 0.5*nrow(data))

start<- Sys.time()
# ... rf
res1<- SEMml(ig, data, train, algo="rf", vimp=FALSE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 
#>  RF solver ended normally after 23 iterations 
#> 
#>  logL: -33.22824  srmr: 0.0859545 
#> 
mse1<- predict(res1, data[-train, ], verbose=TRUE)
#>      amse     10452      1432      1616      4217      4741      4744      4747 
#> 1.1087827 1.5751174 0.8407252 1.2992863 1.4710206 0.9542770 0.9797583 0.7995575 
#>     54205      5530      5532      5533      5534      5535      5600      5603 
#> 1.2948457 1.3477223 1.3374130 1.7462924 1.3393998 0.9579940 0.5885530 1.1292464 
#>      5606      5608       596      6300     79139       836     84134       842 
#> 1.1417869 1.1304838 1.2727327 0.6232077 0.8256068 0.8613661 1.1949609 0.7906486 

# ... xgb
res2<- SEMml(ig, data, train, algo="xgb", vimp=FALSE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 
#>  XGB solver ended normally after 23 iterations 
#> 
#>  logL: 70.10035  srmr: 0.0014393 
#> 
mse2<- predict(res2, data[-train, ], verbose=TRUE)
#>      amse     10452      1432      1616      4217      4741      4744      4747 
#> 1.5485977 2.3224700 1.0438189 2.0584053 2.1367318 0.9796140 1.1430960 0.7703339 
#>     54205      5530      5532      5533      5534      5535      5600      5603 
#> 1.4605175 2.0228881 2.0250110 2.2691126 2.0941554 1.5672913 0.8313812 1.2909388 
#>      5606      5608       596      6300     79139       836     84134       842 
#> 1.6402649 1.5021883 2.0930200 0.9203514 1.2559263 1.2960971 1.9018098 0.9923247 

# ... nn
res3<- SEMml(ig, data, train, algo="nn", vimp=FALSE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 
#>  NN solver ended normally after 23 iterations 
#> 
#>  logL: -36.07395  srmr: 0.194907 
#> 
mse3<- predict(res3, data[-train, ], verbose=TRUE)
#>      amse     10452      1432      1616      4217      4741      4744      4747 
#> 1.6357984 1.5177940 1.7251374 2.0771794 1.2273523 2.7357480 4.8539956 2.7193534 
#>     54205      5530      5532      5533      5534      5535      5600      5603 
#> 2.4691147 0.9825952 1.2431796 1.6832744 1.3556345 0.8070092 1.7176961 2.1512864 
#>      5606      5608       596      6300     79139       836     84134       842 
#> 0.9625083 1.1228321 1.3024315 0.9162324 1.1671646 0.7527348 1.0170158 1.1160940 

# ... gam
res4<- SEMml(ig, data, train, algo="gam", vimp=FALSE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 
#>  GAM solver ended normally after 23 iterations 
#> 
#>  logL: -46.77283  srmr: 0.3819281 
#> 
mse4<- predict(res4, data[-train, ], verbose=TRUE)
#>      amse     10452      1432      1616      4217      4741      4744      4747 
#> 0.9121870 1.1160997 0.8111562 1.1825384 1.0817833 1.0583429 1.1415647 0.6957230 
#>     54205      5530      5532      5533      5534      5535      5600      5603 
#> 1.1719439 0.9052094 0.8304518 1.3167643 0.8457815 0.7351789 0.5526870 1.0343893 
#>      5606      5608       596      6300     79139       836     84134       842 
#> 0.8886350 1.0385301 1.0676148 0.5683706 0.6435413 0.6957059 0.8558520 0.7424376 

# ... sem
res5<- SEMml(ig, data, train, algo="sem", vimp=FALSE)
#> 1 : z10452 
#> 2 : z1432 
#> 3 : z1616 
#> 4 : z4217 
#> 5 : z4741 
#> 6 : z4744 
#> 7 : z4747 
#> 8 : z54205 
#> 9 : z5530 
#> 10 : z5532 
#> 11 : z5533 
#> 12 : z5534 
#> 13 : z5535 
#> 14 : z5600 
#> 15 : z5603 
#> 16 : z5606 
#> 17 : z5608 
#> 18 : z596 
#> 19 : z6300 
#> 20 : z79139 
#> 21 : z836 
#> 22 : z84134 
#> 23 : z842 
#> 
#>  SEM solver ended normally after 23 iterations 
#> 
#>  logL: -47.90891  srmr: 0.3040025 
#> 
mse5<- predict(res5, data[-train, ], verbose=TRUE)
#>      amse     10452      1432      1616      4217      4741      4744      4747 
#> 0.8973153 1.1160997 0.8122749 1.1257768 1.0817833 1.0773819 1.0733166 0.7435081 
#>     54205      5530      5532      5533      5534      5535      5600      5603 
#> 1.1161717 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.5717884 1.0343893 
#>      5606      5608       596      6300     79139       836     84134       842 
#> 0.8886350 0.9082540 1.0675990 0.6118950 0.6766548 0.6957059 0.8591412 0.7160223 
end<- Sys.time()
print(end-start)
#> Time difference of 10.00831 secs
# }