The function computes the gradient matrix, i.e., the average conditional effects of the input variables with respect to the neural network model, as discussed by Amesöder et al. (2024).

getGradientWeight(object, thr = NULL, verbose = FALSE, ...)

Arguments

object

A neural network object returned by the SEMdnn() function.

thr

A threshold value to apply to gradient weights of input nodes (variables). If NULL (default), the threshold is set to thr=mean(abs(gradient weights)).

verbose

A logical value. If FALSE (default), the processed graph will not be plotted to screen.

...

Currently ignored.

Value

A list of two objects: (i) a data.frame including the connections together with their weights, and (ii) the DAG with colored edges. If abs(W) > thr and W < 0, the edge is inhibited and is highlighted in blue; if abs(W) > thr and W > 0, the edge is activated and is highlighted in red.
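The coloring rule can be illustrated with a few lines of base R. This is a minimal sketch under stated assumptions, not the package's own code: the weights in `W` are made-up numbers, the helper `edge_color()` is hypothetical, and the color names are the ones that appear in the example output further down, with `gray50` assumed for edges that do not exceed the threshold.

```r
# Hypothetical edge weights (made-up numbers, for illustration only)
W <- c(0.40, -0.35, 0.05, -0.02, 0.18)

# Default threshold as described for the thr argument: mean absolute weight
thr <- mean(abs(W))

# Hypothetical helper (not part of the package): map a weight to an edge color
edge_color <- function(w, thr) {
  ifelse(abs(w) > thr & w > 0, "red2",               # activated edge
         ifelse(abs(w) > thr & w < 0, "royalblue3",  # inhibited edge
                "gray50"))                           # assumed: below threshold
}

col <- edge_color(W, thr)
table(col)
```

Passing a larger `thr` than the default marks fewer edges as activated or inhibited, so the threshold acts as a simple filter on the reported effects.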

Details

The partial derivatives method calculates the derivative (the gradient) of each output variable (y) with respect to each input variable (x), evaluated at each observation (i=1,...,n) of the training data. The contribution of each input is evaluated in terms of magnitude, taking into account not only the connection weights and activation functions but also the values of each observation of the input variables. Once the gradients have been computed for each variable and observation, a summary gradient is calculated by averaging over the observation units. Finally, the average weights are entered into a matrix, W(pxp), and the element-wise product with the binary (1,0) adjacency matrix of the input DAG, A(pxp), maps the weights onto the DAG edges (W*A).

Note that computing partial derivatives is time consuming compared with other methods such as Olden's (connection weight). The computational time increases with the size of the neural network or the size of the data. Therefore, the function displays a progress bar to track the gradient evaluation per observation.
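The averaging and masking steps described above can be sketched in base R on a toy problem. Everything here is an assumption made for illustration: `f()` is a hand-written stand-in for a fitted network, `grad_k()` is a hypothetical helper that approximates gradients by central finite differences (not necessarily how the package computes them), and the row = source, column = target orientation of `W` and `A` is a choice made for this sketch.

```r
set.seed(1)
n <- 50
X <- cbind(x1 = rnorm(n), x2 = rnorm(n))  # inputs: the parents of x3

# Stand-in for a fitted network predicting x3 from (x1, x2)
f <- function(X) 2 * X[, "x1"] - 0.5 * X[, "x2"]^2

# Hypothetical helper: central finite-difference gradient of f with respect
# to input k, evaluated at every observation (returns a vector of length n)
grad_k <- function(f, X, k, h = 1e-5) {
  Xp <- X
  Xm <- X
  Xp[, k] <- Xp[, k] + h
  Xm[, k] <- Xm[, k] - h
  (f(Xp) - f(Xm)) / (2 * h)
}

# n x 2 matrix of per-observation gradients, then average over observations
G <- sapply(colnames(X), function(k) grad_k(f, X, k))
avg <- colMeans(G)

# Enter the averages into W (p x p) and mask with the adjacency matrix A
vars <- c("x1", "x2", "x3")
W <- matrix(0, 3, 3, dimnames = list(vars, vars))
W[c("x1", "x2"), "x3"] <- avg
A <- matrix(0, 3, 3, dimnames = list(vars, vars))
A["x1", "x3"] <- 1
A["x2", "x3"] <- 1
WA <- W * A  # element-wise product: weights kept only on DAG edges
```

In this toy case the gradient with respect to `x1` is constant, while the gradient with respect to `x2` depends on the observed value of `x2`, which is why the summary is an average over observations rather than a single evaluation.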

References

Amesöder, C., Hartig, F. and Pichler, M. (2024), 'cito': an R package for training neural networks using 'torch'. Ecography, 2024: e07143. https://doi.org/10.1111/ecog.07143

Author

Mario Grassi mario.grassi@unipv.it

Examples


# \donttest{
if (torch::torch_is_installed()){

# load ALS data
ig <- alsData$graph
data <- alsData$exprs
data <- transformData(data)$data

dnn0 <- SEMdnn(ig, data, train=1:nrow(data), cowt = FALSE,
      #loss = "mse", hidden = 5*K, link = "selu",
      loss = "mse", hidden = c(10, 10, 10), link = "selu",
      validation = 0, bias = TRUE, lr = 0.01,
      epochs = 32, device = "cpu", verbose = TRUE)

res <- getGradientWeight(dnn0, thr = NULL, verbose = TRUE)
# count edges by color (E() is from the igraph package)
table(igraph::E(res$dag)$color)
}
#> Conducting the nonparanormal transformation via shrunkun ECDF...done.
#> 1 : z10452 z84134 z836 z4747 z4741 z4744 z79139 z5530 z5532 z5533 z5534 z5535 

#>    epoch   train_l valid_l
#> 32    32 0.2608381      NA
#> 
#> 2 : z842 z1432 z5600 z5603 z6300 

#>    epoch   train_l valid_l
#> 32    32 0.3279338      NA
#> 
#> 3 : z54205 z5606 z5608 

#>    epoch   train_l valid_l
#> 32    32 0.3397509      NA
#> 
#> 4 : z596 z4217 

#>    epoch   train_l valid_l
#> 32    32 0.4199702      NA
#> 
#> 5 : z1616 

#>    epoch   train_l valid_l
#> 32    32 0.3241398      NA
#> 
#> DNN solver ended normally after 736 iterations 
#> 
#>  logL: -43.22795  srmr: 0.1049757 
#> 
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |======================================================================| 100%

#> 
#>     gray50       red2 royalblue3 
#>         30          6          9 
# }