Output gradient
- metatrain.utils.output_gradient.compute_gradient(target: Tensor, inputs: List[Tensor], is_training: bool, destroy_graph: bool) → List[Tensor]
Calculates the gradient of a target tensor with respect to a list of input tensors.
target must be a single torch.Tensor object. If target contains multiple values, the gradient will be calculated with respect to the sum of all values.
- Parameters:
target (Tensor) – The tensor for which the gradient is to be computed.
inputs (List[Tensor]) – A list of tensors with respect to which the gradient is computed.
is_training (bool) – A boolean indicating whether the model is in training mode. If True, the computation graph is retained for further gradient computations and a graph of the derivative itself is constructed, so that higher-order derivatives can be computed.
destroy_graph (bool) – A boolean indicating whether to destroy the computation graph after computing the gradient. If True, the graph is destroyed (unless is_training is True).
- Returns:
A list of tensors representing the gradients of the target with respect to each input.
- Return type:
List[Tensor]
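The parameter descriptions above suggest how the two flags could map onto PyTorch's autograd options. The following is a minimal sketch, assuming the function wraps torch.autograd.grad; compute_gradient_sketch is a hypothetical name and this is not metatrain's actual source. It seeds the backward pass with torch.ones_like(target), a standard way to differentiate the sum of a multi-valued target, and assumes create_graph=is_training and retain_graph=(is_training or not destroy_graph).

```python
from typing import List

import torch


def compute_gradient_sketch(
    target: torch.Tensor,
    inputs: List[torch.Tensor],
    is_training: bool,
    destroy_graph: bool,
) -> List[torch.Tensor]:
    # Hypothetical re-implementation for illustration only.
    # Seeding the backward pass with ones computes the gradient of
    # target.sum(), which is how a multi-valued target is handled.
    grad_outputs = [torch.ones_like(target)]
    gradients = torch.autograd.grad(
        outputs=[target],
        inputs=inputs,
        grad_outputs=grad_outputs,
        # Assumed mapping: keep the graph in training mode, or whenever
        # the caller asks not to destroy it.
        retain_graph=is_training or not destroy_graph,
        # Assumed mapping: build a differentiable graph of the gradient
        # only in training mode, enabling higher-order derivatives.
        create_graph=is_training,
    )
    return list(gradients)


# Multi-valued target: x**3 has two entries, so the gradient is taken
# with respect to their sum, giving 3 * x**2 = [3., 12.].
x = torch.tensor([1.0, 2.0], requires_grad=True)
target = x**3
first = compute_gradient_sketch(target, [x], is_training=True, destroy_graph=False)

# Since is_training=True built a graph for the gradient, it can be
# differentiated again: d/dx sum(3 * x**2) = 6 * x = [6., 12.].
second = compute_gradient_sketch(first[0], [x], is_training=False, destroy_graph=True)
```

Passing is_training=True on the first call is what makes the second call possible: with create_graph=False, the returned gradient carries no graph and cannot be differentiated again.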