R/SEMdnn.R
getGradientWeight.Rd
The function computes the gradient weight matrix, i.e., the average conditional effects of the input variables with respect to the neural network model, as discussed by Amesöder et al. (2024).
getGradientWeight(object, thr = NULL, verbose = FALSE, ...)
A neural network object from the SEMdnn() function.
A threshold value applied to the gradient weights of input nodes (variables). If NULL (default), the threshold is set to thr = mean(abs(gradient weights)).
A logical value. If FALSE (default), the processed graph is not plotted to screen.
Currently ignored.
A list of two objects: (i) a data.frame including the connections together with their weights, and (ii) the DAG with colored edges. If abs(W) > thr and W < 0, the edge is inhibited and is highlighted in blue; otherwise, if abs(W) > thr and W > 0, the edge is activated and is highlighted in red.
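As an illustrative sketch of working with the returned list (the data.frame slot name est is an assumption here; the dag slot appears in the example below, and E() requires the igraph package):

```r
# Hypothetical usage sketch: 'res' is assumed to be the output of
# getGradientWeight(); the slot name 'est' is an assumption, not a
# documented API. E() is from the igraph package.
head(res$est)            # connections and their gradient weights
table(E(res$dag)$color)  # edge counts: red (activated), blue (inhibited)
```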
The partial derivatives method calculates the derivative (the gradient) of each output variable (y) with respect to each input variable (x), evaluated at each observation (i = 1,...,n) of the training data. The contribution of each input is evaluated in terms of magnitude, taking into account not only the connection weights and activation functions, but also the values of each observation of the input variables. Once the gradients have been computed for each variable and observation, a summary gradient is calculated by averaging over the observation units. Finally, the average weights are entered into a (p x p) matrix, W, and the element-wise product with the binary (1,0) adjacency matrix, A (p x p), of the input DAG, W*A, maps the weights onto the DAG edges. Note that computing partial derivatives is time consuming compared to other methods, such as Olden's connection weight method. The computational time increases with the size of the neural network and the size of the data; therefore, the function displays a progress bar to track the gradient evaluation per observation.
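The averaging and masking steps described above can be sketched as follows (a minimal illustration with simulated gradients, not the package internals; all object names are hypothetical):

```r
# Illustrative sketch: average per-observation gradient matrices and mask
# them with the DAG adjacency matrix (simulated data, hypothetical names).
set.seed(1)
p <- 4; n <- 100
# simulated gradients: for each observation i, a p x p matrix d y_j / d x_k
Gi <- array(rnorm(p * p * n), dim = c(p, p, n))
W  <- apply(Gi, c(1, 2), mean)        # average over the n observations
A  <- matrix(c(0, 1, 1, 0,            # binary (1,0) adjacency of the DAG
               0, 0, 0, 1,
               0, 0, 0, 1,
               0, 0, 0, 0), p, p, byrow = TRUE)
WA  <- W * A                          # element-wise product maps weights to edges
thr <- mean(abs(WA[A == 1]))          # default threshold: mean absolute weight
col <- ifelse(abs(WA) > thr & WA < 0, "blue",   # inhibited edges
       ifelse(abs(WA) > thr & WA > 0, "red",    # activated edges
              "gray"))                          # below-threshold edges
```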
Amesöder, C., Hartig, F. and Pichler, M. (2024), 'cito': an R package for training neural networks using 'torch'. Ecography, 2024: e07143. https://doi.org/10.1111/ecog.07143
# \donttest{
if (torch::torch_is_installed()){
# load ALS data
ig <- alsData$graph
data <- alsData$exprs
data <- transformData(data)$data
dnn0 <- SEMdnn(ig, data, train=1:nrow(data), cowt = FALSE,
#loss = "mse", hidden = 5*K, link = "selu",
loss = "mse", hidden = c(10, 10, 10), link = "selu",
validation = 0, bias = TRUE, lr = 0.01,
epochs = 32, device = "cpu", verbose = TRUE)
res <- getGradientWeight(dnn0, thr = NULL, verbose = TRUE)
table(E(res$dag)$color)
}
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
#> 1 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535
#> epoch train_l valid_l
#> 32 32 0.2608381 NA
#>
#> 2 : z842 z1432 z5600 z5603 z6300
#> epoch train_l valid_l
#> 32 32 0.3279338 NA
#>
#> 3 : z54205 z5606 z5608
#> epoch train_l valid_l
#> 32 32 0.3397509 NA
#>
#> 4 : z596 z4217
#> epoch train_l valid_l
#> 32 32 0.4199702 NA
#>
#> 5 : z1616
#> epoch train_l valid_l
#> 32 32 0.3241398 NA
#>
#> DNN solver ended normally after 736 iterations
#>
#> logL: -43.22795 srmr: 0.1049757
#>
#>
  |======================================================================| 100%
#>
#> gray50 red2 royalblue3
#> 30 6 9
# }