Source code for metatrain.utils.output_gradient
import warnings
from typing import List, Optional

import torch


def compute_gradient(
    target: torch.Tensor,
    inputs: List[torch.Tensor],
    is_training: bool,
    destroy_graph: bool,
) -> List[torch.Tensor]:
"""
Calculates the gradient of a target tensor with respect to a list of input tensors.
``target`` must be a single torch.Tensor object. If target contains multiple values,
the gradient will be calculated with respect to the sum of all values.
:param target: The tensor for which the gradient is to be computed.
:param inputs: A list of tensors with respect to which the gradient is computed.
:param is_training: A boolean indicating whether the model is in training mode.
If True, the computation graph is retained for further gradient computations
and the graph of the derivative will be constructed, allowing to compute
higher-order derivatives.
:param destroy_graph: A boolean indicating whether to destroy the computation graph
after computing the gradient. If True, the graph is destroyed (unless
``is_training`` is True).
:return: A list of tensors representing the gradients of the target with respect to
each input
"""
    grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)]
    try:
        gradient = torch.autograd.grad(
            outputs=[target],
            inputs=inputs,
            grad_outputs=grad_outputs,
            retain_graph=is_training or (not destroy_graph),
            create_graph=is_training,
        )
    except RuntimeError as e:
        # Torch raises an error if the target tensor does not require grad,
        # but this could simply mean that the target is a constant tensor, as
        # in the case of composition models. In that case, we can safely
        # ignore the error and raise a warning instead. The warning can be
        # caught and silenced in the appropriate places.
        if (
            "element 0 of tensors does not require grad and does not have a grad_fn"
            in str(e)
        ):
            warnings.warn(f"GRADIENT WARNING: {e}", RuntimeWarning, stacklevel=2)
            gradient = [torch.zeros_like(i) for i in inputs]
        else:
            # Re-raise the error if it's not the one above
            raise
    if gradient is None:
        raise ValueError(
            "Unexpected None value for computed gradient. "
            "One or more operations inside the model might "
            "not have a gradient implementation."
        )
    else:
        return gradient
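

# The block below is a small self-contained usage sketch, not part of the
# module itself. It assumes an energy-like scalar computed from a "positions"
# tensor; both names are illustrative choices, not metatrain API.
if __name__ == "__main__":
    # Differentiable target: the gradient of a scalar "energy" with respect
    # to "positions". With is_training=False and destroy_graph=True, the
    # graph is freed after the call, which is the typical evaluation setting.
    positions = torch.randn(5, 3, requires_grad=True)
    energy = (positions**2).sum()  # stand-in for a model's energy prediction
    (gradient,) = compute_gradient(
        energy, [positions], is_training=False, destroy_graph=True
    )
    forces = -gradient  # same shape as ``positions``, here (5, 3)

    # Constant target: torch.autograd.grad would raise a RuntimeError, but
    # compute_gradient downgrades it to a RuntimeWarning and returns zeros.
    # The warning can be silenced with warnings.catch_warnings, as done here.
    constant = torch.ones(5)  # does not require grad
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        (zeros,) = compute_gradient(
            constant, [positions], is_training=False, destroy_graph=True
        )
    assert torch.all(zeros == 0.0)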