Output gradient

metatrain.utils.output_gradient.compute_gradient(target: Tensor, inputs: List[Tensor], is_training: bool, destroy_graph: bool) -> List[Tensor]

Calculates the gradient of a target tensor with respect to a list of input tensors.

target must be a single torch.Tensor object. If target contains multiple values, the gradient is computed for the sum of all of them.
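A multi-valued target therefore behaves as if it had been summed first. The minimal sketch below checks this with plain `torch.autograd.grad`, seeding the backward pass with a tensor of ones as `grad_outputs`. It only illustrates the stated semantics and assumes nothing about how `compute_gradient` itself is written.

```python
import torch

# One input vector; the target has one value per element.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = x**2  # non-scalar target: [1, 4, 9]

# Seeding the backward pass with ones is equivalent to differentiating target.sum().
(grad_seeded,) = torch.autograd.grad(
    outputs=[target],
    inputs=[x],
    grad_outputs=[torch.ones_like(target)],
    retain_graph=True,  # keep the graph so the next call can reuse it
)
(grad_summed,) = torch.autograd.grad(target.sum(), x)

print(grad_seeded)  # tensor([2., 4., 6.])
print(torch.equal(grad_seeded, grad_summed))  # True
```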

Parameters:
  • target (Tensor) – The tensor for which the gradient is to be computed.

  • inputs (List[Tensor]) – A list of tensors with respect to which the gradient is computed.

  • is_training (bool) – A boolean indicating whether the model is in training mode. If True, the computation graph is retained for further gradient computations and the graph of the derivative is constructed, allowing higher-order derivatives to be computed.

  • destroy_graph (bool) – A boolean indicating whether to destroy the computation graph after computing the gradient. If True, the graph is destroyed (unless is_training is True).
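One plausible way to map these two flags onto `torch.autograd.grad` is sketched below. This is an assumption, not the metatrain source: `compute_gradient_sketch` is a hypothetical name, and the sketch assumes the graph is kept (`retain_graph=True`) whenever is_training is True or destroy_graph is False, while the graph of the derivative is built (`create_graph=True`) only when is_training is True.

```python
from typing import List

import torch
from torch import Tensor


def compute_gradient_sketch(
    target: Tensor, inputs: List[Tensor], is_training: bool, destroy_graph: bool
) -> List[Tensor]:
    """Hypothetical re-implementation of the documented behaviour."""
    gradients = torch.autograd.grad(
        outputs=[target],
        inputs=inputs,
        # Seed with ones so a multi-valued target is differentiated as its sum.
        grad_outputs=[torch.ones_like(target)],
        # Keep the graph in training mode, or when the caller asks not to destroy it.
        retain_graph=is_training or not destroy_graph,
        # Build the graph of the derivative so it can itself be differentiated.
        create_graph=is_training,
    )
    return list(gradients)


# Usage: first and second derivative of y = x**3 at x = 3 (3x^2 = 27, 6x = 18).
x = torch.tensor(3.0, requires_grad=True)
y = x**3
(dy_dx,) = compute_gradient_sketch(y, [x], is_training=True, destroy_graph=True)
(d2y_dx2,) = compute_gradient_sketch(dy_dx, [x], is_training=False, destroy_graph=True)
print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0
```

The second call only works because the first one ran with is_training=True; with is_training=False, `dy_dx` would carry no graph and could not be differentiated again.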

Returns:

A list of tensors representing the gradients of the target with respect to each input.

Return type:

List[Tensor]
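Since the return value holds one gradient per input, a caller can unpack it directly. The toy example below is hypothetical: it uses a made-up harmonic "energy" of some positions and a scalar parameter `k`, and calls `torch.autograd.grad` directly instead of metatrain, to show that each returned gradient has the shape of its input and that, in plain PyTorch, gradients come back in the same order as `inputs`.

```python
import torch

positions = torch.tensor([[0.0, 0.0, 1.0], [1.0, 2.0, 0.0]], requires_grad=True)
k = torch.tensor(2.0, requires_grad=True)

# Hypothetical scalar target: E = 0.5 * k * sum_i |r_i|^2
energy = 0.5 * k * (positions**2).sum()

# One gradient per input, in the same order as the inputs list.
dE_dpositions, dE_dk = torch.autograd.grad(
    outputs=[energy],
    inputs=[positions, k],
    grad_outputs=[torch.ones_like(energy)],
)

forces = -dE_dpositions  # force-like quantity: -k * r
print(dE_dpositions.shape)  # torch.Size([2, 3])
print(dE_dk.item())  # 3.0, i.e. 0.5 * (1 + 1 + 4)
```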