Predict method for DNN objects.
Usage
# S3 method for class 'DNN'
predict(object, newdata, newoutcome = NULL, verbose = FALSE, ...)
Arguments
- object
A model fitting object from
the SEMdnn() function.
- newdata
A matrix containing new data with rows corresponding to subjects, and columns to variables. If newdata = NULL, the train data are used.
- newoutcome
A new character vector (as.factor) of labels for a categorical output (target) (default = NULL).
- verbose
Print predicted out-of-sample MSE values (default = FALSE).
- ...
Currently ignored.
Value
A list of three objects:
"PE", vector of the amse = average MSE over all (sink and mediators) graph nodes; r2 = 1 - amse; and srmr = Standardized Root Mean Square Residual between the out-of-bag correlation matrix and the model correlation matrix.
"mse", vector of the Mean Squared Error (MSE) for each out-of-bag prediction of the sink and mediators graph nodes.
"Yhat", the matrix of continuous predicted values of graph nodes (excluding source nodes) based on out-of-bag samples.
Author
Mario Grassi mario.grassi@unipv.it
Examples
# \donttest{
# Example: fit SEMdnn models on training data and evaluate out-of-sample
# predictions with predict.DNN (torch backend required).
if (torch::torch_is_installed()){
# Load Amyotrophic Lateral Sclerosis (ALS)
ig<- alsData$graph
data<- alsData$exprs
# Nonparanormal transformation of the expression data (see the
# "nonparanormal transformation" message in the output below).
data<- transformData(data)$data
group<- alsData$group
#...with train-test (0.5-0.5) samples
set.seed(123)
train<- sample(1:nrow(data), 0.5*nrow(data))
#ncores<- parallel::detectCores(logical = FALSE)
start<- Sys.time()
# Fit a layerwise DNN-based SEM on the training split.
dnn0 <- SEMdnn(ig, data[train, ], algo ="layerwise",
hidden = c(10,10,10), link = "selu", bias = TRUE,
epochs = 32, patience = 10, verbose = TRUE)
end<- Sys.time()
print(end-start)
# Out-of-sample prediction on the held-out half; verbose=TRUE prints
# the PE vector (amse, r2, srmr).
pred.dnn <- predict(dnn0, data[-train, ], verbose=TRUE)
# SEMrun vs. SEMdnn MSE comparison
sem0 <- SEMrun(ig, data[train, ], algo="ricf", n_rep=0)
pred.sem <- predict(sem0, data[-train,], verbose=TRUE)
#...with a categorical (as.factor) outcome
outcome <- factor(ifelse(group == 0, "control", "case")); table(outcome)
start<- Sys.time()
# Refit the DNN-SEM adding the categorical outcome on the training split.
dnn1 <- SEMdnn(ig, data[train, ], outcome[train], algo ="layerwise",
hidden = c(10,10,10), link = "selu", bias = TRUE,
epochs = 32, patience = 10, verbose = TRUE)
end<- Sys.time()
print(end-start)
# Predict with the new outcome labels for the test split (newoutcome arg).
pred <- predict(dnn1, data[-train, ], outcome[-train], verbose=TRUE)
# Extract predicted scores for the outcome-level columns of Yhat and
# compare against the observed test labels.
yhat <- pred$Yhat[ ,levels(outcome)]; head(yhat)
yobs <- outcome[-train]; head(yobs)
classificationReport(yobs, yhat, verbose=TRUE)$stats
}
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
#> Running SEM model via DNN...
#>
#> layer 1 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 ...
#> train val base
#> 0.4104707 Inf 0.9875000
#>
#> layer 2 : z842 z1432 z5600 z5603 z6300
#> train val base
#> 0.4693277 Inf 0.9875000
#>
#> layer 3 : z54205 z5606 z5608
#> train val base
#> 0.5376277 Inf 0.9875000
#>
#> layer 4 : z596 z4217
#> train val base
#> 0.8641988 Inf 0.9875001
#>
#> layer 5 : z1616
#> train val base
#> 0.8268933 Inf 0.9875001
#> done.
#>
#> DNN solver ended normally after 160 iterations
#>
#> logL:-40.969572 srmr:0.182023
#> Time difference of 15.07302 secs
#> amse r2 srmr
#> 0.7186168 0.2813832 0.2397997
#> RICF solver ended normally after 2 iterations
#>
#> deviance/df: 6.262846 srmr: 0.3040025
#>
#> amse r2 srmr
#> 0.8571886 0.1428114 0.2948502
#> Running SEM model via DNN...
#>
#> layer 1 : zcase zcontrol
#> train val base
#> 0.003214309 Inf 0.987500012
#>
#> layer 2 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 ...
#> train val base
#> 0.4285718 Inf 0.9875000
#>
#> layer 3 : z842 z1432 z5600 z5603 z6300
#> train val base
#> 0.4605743 Inf 0.9875000
#>
#> layer 4 : z54205 z5606 z5608
#> train val base
#> 0.5376339 Inf 0.9875000
#>
#> layer 5 : z596 z4217
#> train val base
#> 0.8655762 Inf 0.9875001
#>
#> layer 6 : z1616
#> train val base
#> 0.8682160 Inf 0.9875001
#> done.
#>
#> DNN solver ended normally after 192 iterations
#>
#> logL:-39.856559 srmr:0.177495
#> Time difference of 27.0571 secs
#> amse r2 srmr
#> 0.6967868 0.3032132 0.2357844
#> pred
#> yobs case control
#> case 62 12
#> control 1 5
#>
#> precision recall f1 accuracy mcc support
#> case 0.9841270 0.8378378 0.9051095 0.8375 0.4321455 74
#> control 0.2941176 0.8333333 0.4347826 0.8375 0.4321455 6
#> macro avg 0.6391223 0.8355856 0.6699460 0.8375 0.4321455 80
#> weighted avg 0.9323763 0.8375000 0.8698350 0.8375 0.4321455 80
#> support_prop
#> case 0.925
#> control 0.075
#> macro avg 1.000
#> weighted avg 1.000
# }