Predict method for DNN objects.
# S3 method for class 'DNN'
predict(object, newdata, verbose = FALSE, ...)
object: A model fitting object returned by the SEMdnn() function.
newdata: A matrix containing new data, with rows corresponding to subjects and columns to variables.
verbose: If TRUE, print the predicted out-of-sample MSE values (default = FALSE).
...: Currently ignored.
A list with 2 components:
"PE", a vector of prediction errors, i.e. the mean squared error (MSE) of the out-of-sample predictions for each graph node. The first value of PE is the AMSE, the MSE averaged over all predicted graph nodes (sinks and mediators).
"Yhat", the matrix of continuous predicted values of the graph nodes (excluding source nodes), computed on the out-of-sample data in newdata.
# \donttest{
if (torch::torch_is_installed()) {
  # Load Amyotrophic Lateral Sclerosis (ALS) data.
  # NOTE: inside this `{ }` block only the last value is auto-printed, so
  # intermediate inspection calls (e.g. dim(), table()) would show nothing;
  # wrap them in print() if you want to see them.
  data <- alsData$exprs
  data <- transformData(data)$data
  ig <- alsData$graph
  gplot(ig)
  group <- alsData$group

  #...with train-test (0.5-0.5) samples
  set.seed(123)
  train <- sample(seq_len(nrow(data)), 0.5 * nrow(data))

  t_start <- Sys.time()
  dnn0 <- SEMdnn(ig, data, train, cowt = FALSE, thr = NULL,
                 loss = "mse", hidden = c(10, 10, 10), link = "selu",
                 validation = 0, bias = TRUE, lr = 0.01,
                 epochs = 32, device = "cpu", verbose = TRUE)
  t_end <- Sys.time()
  print(t_end - t_start)
  mse0 <- predict(dnn0, data[-train, ], verbose = TRUE)

  # SEMrun vs. SEMdnn MSE comparison: store the SEM prediction in its own
  # object so the SEMdnn result (mse0) is kept for the comparison.
  sem0 <- SEMrun(ig, data[train, ], SE = "none", limit = 1000)
  mse_sem0 <- predict(sem0, data[-train, ], verbose = TRUE)

  #...with a binary outcome: group is 1 = case, 0 = control; it is recoded
  # to +1/-1, so thr = 0 in benchmark() is the midpoint between the codes.
  ig1 <- mapGraph(ig, type = "outcome")
  gplot(ig1)
  outcome <- ifelse(group == 0, -1, 1)
  data1 <- cbind(outcome, data)

  t_start <- Sys.time()
  dnn1 <- SEMdnn(ig1, data1, train, cowt = TRUE, thr = NULL,
                 loss = "mse", hidden = c(10, 10, 10), link = "selu",
                 validation = 0, bias = TRUE, lr = 0.01,
                 epochs = 32, device = "cpu", verbose = TRUE)
  t_end <- Sys.time()
  print(t_end - t_start)
  mse1 <- predict(dnn1, data1[-train, ])

  # Classification performance of the predicted outcome on the test set
  yobs <- group[-train]
  yhat <- mse1$Yhat[, "outcome"]
  benchmark(yobs, yhat, thr = 0, F1 = FALSE)
}
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
#> 1 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535
#> epoch train_l valid_l
#> 32 32 0.233789 NA
#>
#> 2 : z842 z1432 z5600 z5603 z6300
#> epoch train_l valid_l
#> 32 32 0.2671098 NA
#>
#> 3 : z54205 z5606 z5608
#> epoch train_l valid_l
#> 32 32 0.2902266 NA
#>
#> 4 : z596 z4217
#> epoch train_l valid_l
#> 32 32 0.3085203 NA
#>
#> 5 : z1616
#> epoch train_l valid_l
#> 32 32 0.3130042 NA
#>
#> DNN solver ended normally after 736 iterations
#>
#> logL: -32.97723 srmr: 0.0935154
#>
#> Time difference of 11.33274 secs
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.5757168 0.4026496 0.4884897 0.4263059 0.4562469 0.7344475 0.7633227 0.5492031
#> 5530 5532 5533 5534 5535 842 1432 5600
#> 0.2999405 0.2509800 1.1086082 0.2349699 0.6357745 0.7024854 0.6775760 0.4973224
#> 5603 6300 54205 5606 5608 596 4217 1616
#> 0.8554188 0.4312981 0.8111550 0.4845471 0.5130992 0.7592643 0.5571802 0.6012024
#> NLMINB solver ended normally after 1 iterations
#>
#> deviance/df: 6.263349 srmr: 0.3041398
#>
#> amse 10452 84134 836 4747 4741 4744 79139
#> 0.8973153 1.1160997 0.8591412 0.6957059 0.7435081 1.0773819 1.0733166 0.6766548
#> 5530 5532 5533 5534 5535 842 1432 5600
#> 0.9041601 0.8519291 1.2303948 0.8115211 0.6638488 0.7160223 0.8122749 0.5717884
#> 5603 6300 54205 5606 5608 596 4217 1616
#> 1.0343893 0.6118950 1.1161717 0.8886350 0.9082540 1.0675990 1.0817833 1.1257768
#> 1 : zoutcome
#> epoch train_l valid_l
#> 32 32 0.004096301 NA
#>
#> 2 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535
#> epoch train_l valid_l
#> 32 32 0.2232471 NA
#>
#> 3 : z842 z1432 z5600 z5603 z6300
#> epoch train_l valid_l
#> 32 32 0.2565469 NA
#>
#> 4 : z54205 z5606 z5608
#> epoch train_l valid_l
#> 32 32 0.2627199 NA
#>
#> 5 : z596 z4217
#> epoch train_l valid_l
#> 32 32 0.2888601 NA
#>
#> 6 : z1616
#> epoch train_l valid_l
#> 32 32 0.2608405 NA
#>
#> DNN solver ended normally after 768 iterations
#>
#> logL: -31.61147 srmr: 0.0817129
#>
#> Time difference of 13.60366 secs
#> ypred
#> yobs 0 1
#> 0 5 1
#> 1 3 71
#>
#> sp se acc mcc
#> 1 0.8333333 0.9594595 0.95 0.6960492
# }