Loss

class metatrain.utils.loss.LossParams

Bases: TypedDict

type: NotRequired[str] = 'mse'
weight: NotRequired[float] = 1.0
reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'
class metatrain.utils.loss.LossSpecification

Bases: TypedDict

type: NotRequired[str] = 'mse'
weight: NotRequired[float] = 1.0
reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'
gradients: NotRequired[dict[str, LossParams]] = {}
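As an illustration, a per-target loss configuration using the LossSpecification fields could look like the dictionary below. This is a sketch: the target name "energy", the gradient name "positions", and the helper fill_defaults are hypothetical. Only the defaults listed above are assumed.

```python
from typing import Any, Dict

# Defaults copied from the LossParams / LossSpecification fields above.
DEFAULTS: Dict[str, Any] = {"type": "mse", "weight": 1.0, "reduction": "mean"}

# Hypothetical configuration: "energy" and "positions" are illustrative names.
config: Dict[str, Dict[str, Any]] = {
    "energy": {
        "type": "mae",
        "gradients": {"positions": {"type": "mse", "weight": 10.0}},
    },
}


def fill_defaults(spec: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: apply the documented defaults to one target spec."""
    filled = {**DEFAULTS, **spec}
    # Each gradient entry is a LossParams dict with the same defaults.
    filled["gradients"] = {
        grad: {**DEFAULTS, **params}
        for grad, params in spec.get("gradients", {}).items()
    }
    return filled


print(fill_defaults(config["energy"]))
```

Any field omitted in the configuration falls back to the default, including the gradient entries, which share the LossParams defaults.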
class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)

Bases: ABC

Abstract base for all loss functions.

Subclasses must implement the compute method.

Parameters:
  • name (str) – key in the predictions/targets dict to select the TensorMap.

  • gradient (str | None) – optional name of a gradient field to extract.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch losses (“mean”, “sum”, etc.).

target: str
gradient: str | None
weight: float
reduction: str
loss_kwargs: Dict[str, Any]
abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the loss value.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to the predictions for those targets.

  • targets (Dict[str, TensorMap]) – mapping from target names to the reference targets.

  • extra_data (Any | None) – Any extra data needed for the loss computation.

Returns:

Value of the loss.

Return type:

Tensor

classmethod from_config(cfg: Dict[str, Any]) → LossInterface

Instantiate a loss from a config dict.

Parameters:

cfg (Dict[str, Any]) – keyword args matching the loss constructor.

Returns:

instance of a LossInterface subclass.

Return type:

LossInterface
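The compute/from_config contract can be sketched with a stand-in class. SketchLoss and SketchMSE are hypothetical and operate on plain Python lists instead of TensorMaps. The assumption that from_config simply forwards cfg as keyword arguments to the constructor follows the description of cfg above, not the library source.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class SketchLoss(ABC):
    """Hypothetical stand-in for LossInterface, using lists instead of TensorMaps."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.target = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction

    @abstractmethod
    def compute(
        self,
        predictions: Dict[str, List[float]],
        targets: Dict[str, List[float]],
        extra_data: Optional[Any] = None,
    ) -> float:
        """Subclasses return the loss value for self.target."""

    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "SketchLoss":
        # Assumes the config keys match the constructor's keyword arguments.
        return cls(**cfg)


class SketchMSE(SketchLoss):
    def compute(self, predictions, targets, extra_data=None) -> float:
        pred, ref = predictions[self.target], targets[self.target]
        squared = [(p - r) ** 2 for p, r in zip(pred, ref)]
        if self.reduction == "sum":
            return sum(squared)
        return sum(squared) / len(squared)  # "mean"; "none" is not handled here


loss = SketchMSE.from_config({"name": "energy", "gradient": None, "weight": 1.0, "reduction": "mean"})
print(loss.compute({"energy": [1.0, 2.0]}, {"energy": [0.0, 2.0]}))  # 0.5
```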

class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: LossInterface

Backbone for pointwise losses on TensorMap entries.

Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – not applied by this class; the actual weighting happens in ScheduledLoss.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).

compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) → Tensor

Flatten prediction and target blocks (and optional mask), then apply the torch loss.

Parameters:
  • tensor_map_predictions_for_target (TensorMap) – predicted TensorMap.

  • tensor_map_targets_for_target (TensorMap) – target TensorMap.

  • tensor_map_mask_for_target (TensorMap | None) – optional mask TensorMap.

Returns:

scalar torch.Tensor of the computed loss.

Return type:

Tensor
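A minimal sketch of what compute_flattened does, assuming each block's values are a 2D list, that a mask entry is truthy where the element should count, and that the torch loss is MSE. The function names are hypothetical; the real helper works on TensorMap blocks and torch tensors.

```python
from typing import List, Optional, Sequence, Union

Block = Sequence[Sequence[float]]  # one block's values as rows of numbers


def flatten(blocks: Sequence[Block]) -> List[float]:
    """Concatenate the values of all blocks into a single flat list."""
    return [x for block in blocks for row in block for x in row]


def compute_flattened_mse(
    pred_blocks: Sequence[Block],
    target_blocks: Sequence[Block],
    mask_blocks: Optional[Sequence[Block]] = None,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    pred = flatten(pred_blocks)
    target = flatten(target_blocks)
    if mask_blocks is None:
        pairs = list(zip(pred, target))
    else:
        # Keep only the entries whose mask value is truthy.
        pairs = [(p, t) for p, t, m in zip(pred, target, flatten(mask_blocks)) if m]
    errors = [(p - t) ** 2 for p, t in pairs]
    if reduction == "none":
        return errors
    if reduction == "sum":
        return sum(errors)
    return sum(errors) / len(errors)


pred = [[[1.0, 2.0]], [[3.0]]]
target = [[[1.0, 0.0]], [[0.0]]]
mask = [[[1, 0]], [[1]]]
print(compute_flattened_mse(pred, target))        # mean over all three entries
print(compute_flattened_mse(pred, target, mask))  # 4.5
```

Flattening first means every entry of every block contributes equally to the reduction, regardless of which block it came from.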

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Compute the unmasked pointwise loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping of names to TensorMap.

  • extra_data (Any | None) – ignored for unmasked losses.

Returns:

scalar torch.Tensor loss.

Return type:

Tensor

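class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)

Bases: BaseTensorMapLoss

Pointwise masked loss on TensorMap entries.

Inherits the flattening and torch-loss logic from BaseTensorMapLoss.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (_Loss) – pre-instantiated torch.nn loss (e.g. MSELoss).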

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor

Gather and flatten target and prediction blocks, then compute the masked loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • extra_data (Dict[str, TensorMap] | None) – Additional data for the loss computation. For the target name passed to the constructor, extra_data must contain an entry with key name + "_mask" holding the mask TensorMap. The mask must have the same metadata as the target and prediction TensorMaps.

Returns:

Scalar loss tensor.

Return type:

Tensor
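The "_mask" naming convention for extra_data can be illustrated with plain lists. In this hypothetical sketch, masked_mae looks up name + "_mask", keeps only the entries where the mask is True, and averages the absolute errors. Raising KeyError when the mask is missing is this sketch's choice, not documented behavior.

```python
from typing import Dict, List, Optional


def masked_mae(
    name: str,
    predictions: Dict[str, List[float]],
    targets: Dict[str, List[float]],
    extra_data: Optional[Dict[str, List[bool]]] = None,
) -> float:
    mask_key = name + "_mask"  # naming convention described above
    if extra_data is None or mask_key not in extra_data:
        raise KeyError(f"extra_data must contain {mask_key!r}")
    mask = extra_data[mask_key]
    # Absolute errors of the entries selected by the mask.
    kept = [abs(p - t) for p, t, m in zip(predictions[name], targets[name], mask) if m]
    return sum(kept) / len(kept)


predictions = {"energy": [1.0, 5.0, 2.0]}
targets = {"energy": [0.0, 0.0, 2.5]}
extra_data = {"energy_mask": [True, False, True]}
print(masked_mae("energy", predictions, targets, extra_data))  # 0.75
```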

class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: BaseTensorMapLoss

Unmasked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: BaseTensorMapLoss

Unmasked Huber loss on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • delta (float) – threshold parameter for HuberLoss.
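For reference, the piecewise definition used by torch.nn.HuberLoss is sketched below, assuming these wrappers delegate to it with the given delta. Residuals smaller than delta are penalized quadratically and larger ones linearly, so delta sets the crossover point. The helper names are hypothetical.

```python
from typing import Sequence


def huber(residual: float, delta: float) -> float:
    """Piecewise Huber penalty as defined for torch.nn.HuberLoss."""
    a = abs(residual)
    if a < delta:
        return 0.5 * a * a  # quadratic region
    return delta * (a - 0.5 * delta)  # linear region


def mean_huber(pred: Sequence[float], target: Sequence[float], delta: float) -> float:
    values = [huber(p - t, delta) for p, t in zip(pred, target)]
    return sum(values) / len(values)


print(huber(0.5, delta=1.0))  # 0.125
print(huber(3.0, delta=1.0))  # 2.5
```

With a large delta every residual falls in the quadratic region and the loss approaches half the MSE; with a small delta most residuals fall in the linear region and the loss behaves like a scaled MAE.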

class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-squared error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: MaskedTensorMapLoss

Masked mean-absolute error on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)

Bases: MaskedTensorMapLoss

Masked Huber loss on TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • delta (float) – threshold parameter for HuberLoss.

class metatrain.utils.loss.ShiftAgnosticMSE(name: str, gradient: str | None, weight: float, int_weight: float, grad_penalty_weight: float, reduction: str)

Bases: LossInterface

Shift-agnostic MSE loss on TensorMap entries.

This loss assumes that the target is a profile sampled along the properties dimension of the TensorBlock. It finds the rigid shift between predictions and targets that minimizes the MSE and returns that minimal MSE.

Parameters:
  • name (str) – key for the target in the prediction/target dictionary.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • int_weight (float) – weight of an optional extra term: the MSE between the cumulative predicted and target profiles. If 0, no cumulative term is added.

  • grad_penalty_weight (float) – weight of a penalty on the gradients of the predicted profile in regions where the target is NaN. If 0, no penalty is applied and predictions in those regions are unconstrained.

  • reduction (str) – reduction mode for torch loss.
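The core idea of the shift-agnostic MSE can be sketched with a brute-force search over integer shifts on plain lists. This illustrates the concept only: the shift range max_shift is a hypothetical parameter, and the cumulative (int_weight) and gradient-penalty (grad_penalty_weight) terms are not modelled.

```python
import math
from typing import Sequence


def shift_agnostic_mse(pred: Sequence[float], target: Sequence[float], max_shift: int) -> float:
    """Smallest MSE between pred and target over integer shifts in [-max_shift, max_shift]."""
    n = len(pred)
    best = math.inf
    for shift in range(-max_shift, max_shift + 1):
        # Compare pred[i + shift] with target[i] where both indices are valid.
        pairs = [(pred[i + shift], target[i]) for i in range(len(target)) if 0 <= i + shift < n]
        if not pairs:
            continue
        mse = sum((p - t) ** 2 for p, t in pairs) / len(pairs)
        best = min(best, mse)
    return best


pred = [0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0]
target = [0.0, 1.0, 2.0, 1.0, 0.0, 0.0, 0.0]  # same peak, one bin earlier
print(shift_agnostic_mse(pred, target, max_shift=0))  # plain MSE, no shift allowed
print(shift_agnostic_mse(pred, target, max_shift=2))  # 0.0
```

Larger shifts compare fewer points, so any implementation has to decide how to treat the non-overlapping edges; this sketch simply averages over the overlap.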

compute(model_predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Gather and flatten target and prediction blocks, then compute the shift-agnostic loss.

Parameters:
  • model_predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.

  • extra_data (Any | None) – ignored by this loss.

Returns:

Scalar loss tensor.

Return type:

Tensor

class metatrain.utils.loss.TensorMapEnsembleLoss(name: str, gradient: str | None, weight: float, reduction: str, loss_fn: Module)

Bases: BaseTensorMapLoss

Loss for ensemble predictions on TensorMap entries. Assumes that the ensemble index is the outermost dimension of the TensorBlock properties.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

  • loss_fn (Module) – pre-instantiated torch.nn loss.

compute_flattened(pred_mean: TensorMap, target: TensorMap, pred_var: TensorMap) → Tensor

Flatten the predicted-mean, target, and predicted-variance blocks, then apply the torch loss.

Parameters:
  • pred_mean (TensorMap) – mean of ensemble predictions TensorMap.

  • target (TensorMap) – target TensorMap.

  • pred_var (TensorMap) – variance of ensemble predictions TensorMap.

Returns:

scalar torch.Tensor of the computed loss.

Return type:

Tensor
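How a predicted mean and variance might be obtained from an ensemble stored along the properties is sketched below. It assumes one reading of "outermost": a sample's properties row holds all values of member 0, then all values of member 1, and so on. It also assumes the population variance (dividing by the number of members). Both are assumptions, and ensemble_mean_var is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def ensemble_mean_var(row: Sequence[float], n_members: int) -> Tuple[List[float], List[float]]:
    """Mean and population variance per property for one sample's properties row."""
    if len(row) % n_members != 0:
        raise ValueError("row length must be a multiple of n_members")
    n_props = len(row) // n_members
    # Ensemble outermost: member m owns row[m * n_props:(m + 1) * n_props].
    members = [row[m * n_props:(m + 1) * n_props] for m in range(n_members)]
    columns = list(zip(*members))  # one tuple of M values per property
    mean = [sum(col) / n_members for col in columns]
    var = [sum((x - mu) ** 2 for x in col) / n_members for col, mu in zip(columns, mean)]
    return mean, var


mean, var = ensemble_mean_var([1.0, 2.0, 3.0, 4.0], n_members=2)
print(mean, var)  # [2.0, 3.0] [1.0, 1.0]
```

The mean, target, and variance would then presumably be passed to loss_fn in the same order as GaussianCRPSLoss.forward(input, target, var).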

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor

Gather and flatten target and prediction blocks, then compute the loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps; the ensemble must be the outermost property dimension.

  • targets (Dict[str, TensorMap]) – Mapping from target names to their reference TensorMaps.

  • extra_data (Dict[str, TensorMap] | None) – Ignored for this loss.

Returns:

Scalar loss tensor.

Return type:

Tensor

class metatrain.utils.loss.GaussianCRPSLoss(reduction: str = 'mean', eps: float = 1e-12)

Bases: Module

Gaussian CRPS loss.

This implements the closed-form expression for the CRPS of a Gaussian predictive distribution \(\mathcal{N}(\mu, \sigma^2)\) evaluated at a target value \(x\):

\[\text{CRPS}(x; \mu, \sigma) = \sigma \left[ z(2\Phi(z) - 1) + 2\phi(z) - \frac{1}{\sqrt{\pi}} \right]\]

where \(z = \frac{x - \mu}{\sigma}\), \(\Phi\) is the standard normal CDF, and \(\phi\) is the standard normal PDF.

Parameters:
  • reduction (str) – ‘none’, ‘mean’, or ‘sum’.

  • eps (float) – small constant for numerical stability of the variance.

forward(input: Tensor, target: Tensor, var: Tensor) → Tensor

Compute the Gaussian CRPS loss.

Parameters:
  • input (Tensor) – Mean predictions.

  • target (Tensor) – Target values.

  • var (Tensor) – Variance of the predictions.

Returns:

Value of the loss.

Return type:

Tensor
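The closed-form Gaussian CRPS above translates directly into scalar Python, using math.erf for \(\Phi\). Adding eps to the variance before taking the square root is an assumption about how eps is used; gaussian_crps and gaussian_crps_loss are hypothetical names.

```python
import math
from typing import List, Sequence, Union


def gaussian_crps(mu: float, x: float, var: float, eps: float = 1e-12) -> float:
    """CRPS of N(mu, var) at x, following the closed form above."""
    sigma = math.sqrt(var + eps)  # assumption: eps is added to the variance
    z = (x - mu) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard normal CDF
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF
    return sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


def gaussian_crps_loss(
    means: Sequence[float], targets: Sequence[float], variances: Sequence[float], reduction: str = "mean"
) -> Union[float, List[float]]:
    values = [gaussian_crps(m, x, v) for m, x, v in zip(means, targets, variances)]
    if reduction == "none":
        return values
    if reduction == "sum":
        return sum(values)
    return sum(values) / len(values)


print(gaussian_crps(0.0, 0.0, 1.0))  # about 0.2337
```

Far from the mean (large \(|z|\)), the expression approaches \(|x - \mu| - \sigma/\sqrt{\pi}\), so the loss behaves like an absolute error with a small bonus for a narrow distribution.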

class metatrain.utils.loss.EmpiricalCRPSLoss(reduction: str = 'mean')

Bases: Module

Empirical CRPS loss for ensemble predictions.

The ensemble predictions \(\{Y_i\}_{i=1}^M\) for each data point define an empirical predictive distribution:

\[F_M(y) = \frac{1}{M} \sum_{i=1}^M \mathbb{1}_{Y_i \le y}\]

The CRPS of this empirical distribution at observation \(z\) has the closed form:

\[\text{CRPS}(F_M, z) = \frac{1}{M} \sum_{i=1}^M |Y_i - z| - \frac{1}{2 M^2} \sum_{i,j} |Y_i - Y_j|\]

Parameters:

reduction (str) – ‘none’, ‘mean’, or ‘sum’.

forward(ensemble: Tensor, target: Tensor) → Tensor

Compute the Empirical CRPS loss.

Parameters:
  • ensemble (Tensor) – Ensemble predictions, shape (B, M).

  • target (Tensor) – Target values, shape (B,).

Returns:

Value of the loss.

Return type:

Tensor
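A direct, quadratic-time sketch of the empirical CRPS formula for a batch of ensembles of shape (B, M), using nested Python lists in place of tensors. The function names are hypothetical.

```python
from typing import List, Sequence, Union


def empirical_crps(ensemble: Sequence[float], z: float) -> float:
    """CRPS of the empirical distribution of `ensemble` at observation z."""
    m = len(ensemble)
    mean_abs_error = sum(abs(y - z) for y in ensemble) / m
    # Pairwise term: O(M^2), acceptable for small ensembles.
    pairwise = sum(abs(yi - yj) for yi in ensemble for yj in ensemble) / (2 * m * m)
    return mean_abs_error - pairwise


def empirical_crps_loss(
    ensembles: Sequence[Sequence[float]], targets: Sequence[float], reduction: str = "mean"
) -> Union[float, List[float]]:
    """ensembles has shape (B, M) and targets has shape (B,), as in forward()."""
    values = [empirical_crps(row, z) for row, z in zip(ensembles, targets)]
    if reduction == "none":
        return values
    if reduction == "sum":
        return sum(values)
    return sum(values) / len(values)


print(empirical_crps([0.0, 2.0], 1.0))  # 0.5
```

With a single member (M = 1) the pairwise term vanishes and the CRPS reduces to the absolute error \(|Y_1 - z|\).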

class metatrain.utils.loss.TensorMapGaussianNLLLoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: TensorMapEnsembleLoss

Gaussian negative log-likelihood loss for TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapGaussianCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: TensorMapEnsembleLoss

Gaussian CRPS loss for TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

class metatrain.utils.loss.TensorMapEmpiricalCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)

Bases: TensorMapEnsembleLoss

Empirical CRPS loss for TensorMap entries.

Parameters:
  • name (str) – key in the predictions/targets dict.

  • gradient (str | None) – optional gradient field name.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for torch loss.

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) → Tensor

Gather and flatten target and prediction blocks, then compute the loss.

Parameters:
  • predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps; the ensemble must be the outermost property dimension.

  • targets (Dict[str, TensorMap]) – Mapping from target names to their reference TensorMaps.

  • extra_data (Dict[str, TensorMap] | None) – Ignored for this loss.

Returns:

Scalar loss tensor.

Return type:

Tensor

class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, LossSpecification])

Bases: LossInterface

Aggregate multiple LossInterface terms with scheduled weights and metadata.

Parameters:
  • targets (Dict[str, TargetInfo]) – mapping from target names to their TargetInfo.

  • config (Dict[str, LossSpecification]) – mapping from target names to their loss specification.

compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) → Tensor

Sum all scheduled losses whose targets are present in the predictions.

Parameters:
  • predictions (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • targets (Dict[str, TensorMap]) – mapping from target names to TensorMap.

  • extra_data (Any | None) – Any extra data needed for the loss computation.

Returns:

scalar torch.Tensor with the total loss.

Return type:

Tensor
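The aggregation step can be sketched as a weighted sum that skips targets missing from the predictions. In this sketch, fixed weights stand in for the scheduled weights, terms are (target name, weight, loss function) tuples, and aggregate, mse, and the target names "energy" and "forces" are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

LossFn = Callable[[List[float], List[float]], float]


def mse(pred: List[float], ref: List[float]) -> float:
    return sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(pred)


def aggregate(
    terms: List[Tuple[str, float, LossFn]],
    predictions: Dict[str, List[float]],
    targets: Dict[str, List[float]],
) -> float:
    """Weighted sum of the loss terms whose target appears in the predictions."""
    total = 0.0
    for name, weight, loss_fn in terms:
        if name not in predictions:
            continue  # this target is not present in the predictions
        total += weight * loss_fn(predictions[name], targets[name])
    return total


terms = [("energy", 1.0, mse), ("forces", 10.0, mse)]
targets = {"energy": [0.0, 3.0], "forces": [0.0]}
print(aggregate(terms, {"energy": [1.0, 3.0]}, targets))  # 0.5 (forces skipped)
print(aggregate(terms, {"energy": [1.0, 3.0], "forces": [1.0]}, targets))  # 10.5
```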

class metatrain.utils.loss.LossType(*values)

Bases: Enum

Enumeration of available loss types and their implementing classes.

Parameters:
  • key – string key for the loss type.

  • cls – class implementing the loss type.

MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)
MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)
HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)
MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)
MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)
MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)
POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)
MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)
SHIFT_AGNOSTIC_MSE = ('shift_agnostic_mse', <class 'metatrain.utils.loss.ShiftAgnosticMSE'>)
GAUSSIAN_NLL = ('gaussian_nll_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianNLLLoss'>)
GAUSSIAN_CRPS = ('gaussian_crps_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianCRPSLoss'>)
EMPIRICAL_CRPS = ('empirical_crps_ensemble', <class 'metatrain.utils.loss.TensorMapEmpiricalCRPSLoss'>)
property key: str

String key for this loss type.

property cls: Type[LossInterface]

Class implementing this loss type.

classmethod from_key(key: str) → LossType

Look up a LossType by its string key.

Parameters:

key (str) – key that identifies the loss type.

Raises:

ValueError – if the key is not valid.

Returns:

the matching LossType enum member.

Return type:

LossType

metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) → LossInterface

Factory to instantiate a concrete LossInterface given its string key.

Parameters:
  • loss_type (str) – string key matching one of the members of LossType.

  • name (str) – target name for the loss.

  • gradient (str | None) – gradient name, if present.

  • weight (float) – weight of the loss contribution in the final aggregation.

  • reduction (str) – reduction mode for the torch loss.

  • **extra_kwargs (Any) – additional hyperparameters specific to the loss type.

Returns:

instance of the selected loss.

Return type:

LossInterface
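The LossType lookup and the create_loss factory can be sketched together with placeholder classes. SketchLossType, PlaceholderMSE, PlaceholderHuber, and sketch_create_loss are hypothetical stand-ins with only two entries; the keys "mse" and "huber" match the LossType members above. Extra keyword arguments such as delta are forwarded unchanged to the selected class, as the description of **extra_kwargs suggests.

```python
from enum import Enum
from typing import Any, Optional


class PlaceholderMSE:
    """Hypothetical stand-in for TensorMapMSELoss."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str):
        self.name = name
        self.gradient = gradient
        self.weight = weight
        self.reduction = reduction


class PlaceholderHuber(PlaceholderMSE):
    """Hypothetical stand-in for TensorMapHuberLoss, which also takes delta."""

    def __init__(self, name: str, gradient: Optional[str], weight: float, reduction: str, delta: float):
        super().__init__(name, gradient, weight, reduction)
        self.delta = delta


class SketchLossType(Enum):
    # Each value is a (key, class) pair, mirroring the LossType members above.
    MSE = ("mse", PlaceholderMSE)
    HUBER = ("huber", PlaceholderHuber)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def cls(self) -> type:
        return self.value[1]

    @classmethod
    def from_key(cls, key: str) -> "SketchLossType":
        for member in cls:
            if member.key == key:
                return member
        raise ValueError(f"unknown loss type: {key!r}")


def sketch_create_loss(
    loss_type: str, *, name: str, gradient: Optional[str], weight: float, reduction: str, **extra_kwargs: Any
):
    loss_cls = SketchLossType.from_key(loss_type).cls
    # Type-specific hyperparameters (e.g. delta) are passed through unchanged.
    return loss_cls(name=name, gradient=gradient, weight=weight, reduction=reduction, **extra_kwargs)


huber = sketch_create_loss("huber", name="energy", gradient=None, weight=1.0, reduction="mean", delta=0.1)
print(type(huber).__name__, huber.delta)  # PlaceholderHuber 0.1
```

Keeping the key and the class together in each enum value means the string used in a configuration and the implementing class cannot drift apart.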