Predict method for DNN objects.

# S3 method for class 'DNN'
predict(object, newdata, verbose = FALSE, ...)

Arguments

object

A fitted model object returned by the SEMdnn() function.

newdata

A matrix containing new data with rows corresponding to subjects, and columns to variables.

verbose

Print predicted out-of-sample MSE values (default = FALSE).

...

Currently ignored.

Value

A list of 2 objects:

  1. "PE", a vector of prediction errors, given as the Mean Squared Error (MSE) of each out-of-bag prediction. The first value of PE is the AMSE, i.e., the MSE averaged over all (sink and mediator) graph nodes.

  2. "Yhat", the matrix of continuous predicted values of graph nodes (excluding source nodes) based on out-of-bag samples.

Author

Mario Grassi mario.grassi@unipv.it

Examples


# \donttest{
if (torch::torch_is_installed()) {

  # Load the Amyotrophic Lateral Sclerosis (ALS) example data.
  # NOTE: inside this `if` block only the value of the last expression is
  # auto-printed, so bare calls such as dim(data) display nothing here.
  data <- alsData$exprs; dim(data)
  data <- transformData(data)$data
  ig <- alsData$graph; gplot(ig)
  group <- alsData$group

  #...with train-test (0.5-0.5) samples
  set.seed(123)
  # seq_len() avoids the 1:0 pitfall of 1:nrow(data) and draws the same sample.
  train <- sample(seq_len(nrow(data)), 0.5 * nrow(data))

  # Fit the DNN on the training rows and time the fit.
  start <- Sys.time()
  dnn0 <- SEMdnn(ig, data, train, cowt = FALSE, thr = NULL,
                 loss = "mse", hidden = c(10, 10, 10), link = "selu",
                 validation = 0, bias = TRUE, lr = 0.01,
                 epochs = 32, device = "cpu", verbose = TRUE)
  end <- Sys.time()
  print(end - start)

  # Prediction on the held-out rows; verbose = TRUE prints the MSE values.
  mse0 <- predict(dnn0, data[-train, ], verbose = TRUE)

  # SEMrun vs. SEMdnn MSE comparison
  # The SEM result gets its own name so the DNN result in `mse0` is kept
  # and both predictions remain available for comparison.
  sem0 <- SEMrun(ig, data[train, ], SE = "none", limit = 1000)
  mse0_sem <- predict(sem0, data[-train, ], verbose = TRUE)

  #...with a binary outcome: `group` (1=case, 0=control) is recoded to
  # +1/-1 and prepended to the data as the column "outcome".

  ig1 <- mapGraph(ig, type = "outcome"); gplot(ig1)
  outcome <- ifelse(group == 0, -1, 1); table(outcome)
  data1 <- cbind(outcome, data); data1[1:5, 1:5]

  start <- Sys.time()
  dnn1 <- SEMdnn(ig1, data1, train, cowt = TRUE, thr = NULL,
                 loss = "mse", hidden = c(10, 10, 10), link = "selu",
                 validation = 0, bias = TRUE, lr = 0.01,
                 epochs = 32, device = "cpu", verbose = TRUE)
  end <- Sys.time()
  print(end - start)

  # Compare the observed held-out group labels with the predicted "outcome"
  # column. This call must stay last: its value is what the block prints.
  mse1 <- predict(dnn1, data1[-train, ])
  yobs <- group[-train]
  yhat <- mse1$Yhat[, "outcome"]
  benchmark(yobs, yhat, thr = 0, F1 = FALSE)
}
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.

#> 1 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535 

#>    epoch  train_l valid_l
#> 32    32 0.233789      NA
#> 
#> 2 : z842 z1432 z5600 z5603 z6300 

#>    epoch   train_l valid_l
#> 32    32 0.2671098      NA
#> 
#> 3 : z54205 z5606 z5608 

#>    epoch   train_l valid_l
#> 32    32 0.2902266      NA
#> 
#> 4 : z596 z4217 

#>    epoch   train_l valid_l
#> 32    32 0.3085203      NA
#> 
#> 5 : z1616 

#>    epoch   train_l valid_l
#> 32    32 0.3130042      NA
#> 
#> DNN solver ended normally after 736 iterations 
#> 
#>  logL: -32.97723  srmr: 0.0935154 
#> 
#> Time difference of 11.33274 secs
#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.5757168 0.4026496 0.4884897 0.4263059 0.4562469 0.7344475 0.7633227 0.5492031 
#>      5530      5532      5533      5534      5535       842      1432      5600 
#> 0.2999405 0.2509800 1.1086082 0.2349699 0.6357745 0.7024854 0.6775760 0.4973224 
#>      5603      6300     54205      5606      5608       596      4217      1616 
#> 0.8554188 0.4312981 0.8111550 0.4845471 0.5130992 0.7592643 0.5571802 0.6012024 
#> NLMINB solver ended normally after 1 iterations 
#> 
#> deviance/df: 6.263349  srmr: 0.3041398 
#> 

#>      amse     10452     84134       836      4747      4741      4744     79139 
#> 0.8973153 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166 0.6766548 
#>      5530      5532      5533      5534      5535       842      1432      5600 
#> 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749 0.5717884 
#>      5603      6300     54205      5606      5608       596      4217      1616 
#> 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833 1.1257768 

#> 1 : zoutcome 

#>    epoch     train_l valid_l
#> 32    32 0.004096301      NA
#> 
#> 2 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535 

#>    epoch   train_l valid_l
#> 32    32 0.2232471      NA
#> 
#> 3 : z842 z1432 z5600 z5603 z6300 

#>    epoch   train_l valid_l
#> 32    32 0.2565469      NA
#> 
#> 4 : z54205 z5606 z5608 

#>    epoch   train_l valid_l
#> 32    32 0.2627199      NA
#> 
#> 5 : z596 z4217 

#>    epoch   train_l valid_l
#> 32    32 0.2888601      NA
#> 
#> 6 : z1616 

#>    epoch   train_l valid_l
#> 32    32 0.2608405      NA
#> 
#> DNN solver ended normally after 768 iterations 
#> 
#>  logL: -31.61147  srmr: 0.0817129 
#> 
#> Time difference of 13.60366 secs
#>     ypred
#> yobs  0  1
#>    0  5  1
#>    1  3 71
#> 
#>          sp        se  acc       mcc
#> 1 0.8333333 0.9594595 0.95 0.6960492
# }