Loss¶
- class metatrain.utils.loss.LossParams[source]¶
Bases: TypedDict
- type: NotRequired[str] = 'mse'¶
- weight: NotRequired[float] = 1.0¶
- reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'¶
- class metatrain.utils.loss.LossSpecification[source]¶
Bases: TypedDict
- type: NotRequired[str] = 'mse'¶
- weight: NotRequired[float] = 1.0¶
- reduction: NotRequired[Literal['none', 'mean', 'sum']] = 'mean'¶
- gradients: NotRequired[dict[str, LossParams]] = {}¶
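As an illustration of how these two typed dictionaries nest, the following minimal sketch builds a per-target configuration by hand. The `LossParams` and `LossSpecification` classes here are local stand-ins that mirror the documented fields (they are not imported from the library), and the target name `"energy"` and gradient name `"positions"` are hypothetical.

```python
from typing import Dict, Literal, TypedDict


class LossParams(TypedDict, total=False):
    # Local stand-in mirroring the documented fields; every key is optional.
    type: str
    weight: float
    reduction: Literal["none", "mean", "sum"]


class LossSpecification(LossParams, total=False):
    # Same fields as LossParams, plus per-gradient settings.
    gradients: Dict[str, LossParams]


# Hypothetical target ("energy") with one hypothetical gradient ("positions").
loss_config: Dict[str, LossSpecification] = {
    "energy": {
        "type": "mse",
        "weight": 1.0,
        "reduction": "mean",
        "gradients": {"positions": {"type": "mae", "weight": 0.5}},
    }
}
```

Keys left out of a dictionary are expected to fall back to the documented defaults ('mse', 1.0, 'mean', and an empty gradients mapping).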
- class metatrain.utils.loss.LossInterface(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: ABC
Abstract base for all loss functions.
Subclasses must implement the compute method.
- Parameters:
- abstractmethod compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Any | None = None) Tensor[source]¶
Compute the loss value.
- Parameters:
- Returns:
Value of the loss.
- Return type:
- class metatrain.utils.loss.BaseTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)[source]¶
Bases: LossInterface
Backbone for pointwise losses on TensorMap entries.
Provides a compute_flattened() helper that extracts values or gradients, flattens them, applies an optional mask, and computes the torch loss.
- Parameters:
- compute_flattened(tensor_map_predictions_for_target: TensorMap, tensor_map_targets_for_target: TensorMap, tensor_map_mask_for_target: TensorMap | None = None) Tensor[source]¶
Flatten prediction and target blocks (and optional mask), then apply the torch loss.
- class metatrain.utils.loss.MaskedTensorMapLoss(name: str, gradient: str | None, weight: float, reduction: str, *, loss_fn: _Loss)[source]¶
Bases: BaseTensorMapLoss
Pointwise masked loss on TensorMap entries.
Inherits flattening and torch-loss logic from BaseTensorMapLoss.
- compute(predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap], extra_data: Dict[str, TensorMap] | None = None) Tensor[source]¶
Gather and flatten target and prediction blocks, then compute loss.
- Parameters:
predictions (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
targets (Dict[str, TensorMap]) – Mapping from target names to TensorMaps.
extra_data (Dict[str, TensorMap] | None) – Additional data for loss computation. Assumes that, for the target name used in the constructor, there is a corresponding data field name + "_mask" that contains the tensor to be used for masking. It should have the same metadata as the target and prediction tensors.
- Returns:
Scalar loss tensor.
- Return type:
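The masking convention described for extra_data can be sketched with plain Python lists standing in for flattened TensorMap values. This is only an illustration of the name + "_mask" lookup followed by a masked mean-squared error; the function `masked_mse` and the target name `"mtt::dos"` are hypothetical and not part of the library.

```python
from typing import Dict, List


def masked_mse(
    name: str,
    predictions: Dict[str, List[float]],
    targets: Dict[str, List[float]],
    extra_data: Dict[str, List[bool]],
) -> float:
    """Mean-squared error over the entries selected by the mask."""
    # The mask is looked up under the target name with a "_mask" suffix.
    mask = extra_data[name + "_mask"]
    squared_errors = [
        (p - t) ** 2
        for p, t, keep in zip(predictions[name], targets[name], mask)
        if keep
    ]
    if not squared_errors:
        return 0.0  # nothing selected by the mask
    return sum(squared_errors) / len(squared_errors)


# Hypothetical target name; the last entry is masked out.
value = masked_mse(
    "mtt::dos",
    predictions={"mtt::dos": [1.0, 2.0, 10.0]},
    targets={"mtt::dos": [1.0, 4.0, 0.0]},
    extra_data={"mtt::dos_mask": [True, True, False]},
)
```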
- class metatrain.utils.loss.TensorMapMSELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: BaseTensorMapLoss
Unmasked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMAELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: BaseTensorMapLoss
Unmasked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)[source]¶
Bases: BaseTensorMapLoss
Unmasked Huber loss on TensorMap entries.
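For reference, the pointwise Huber loss can be written out as below. This sketch assumes the delta parameter has the same meaning as in torch.nn.HuberLoss (quadratic for residuals up to delta, linear beyond it); the function name `huber` is illustrative only.

```python
def huber(residual: float, delta: float) -> float:
    """Pointwise Huber loss for a single prediction-target residual."""
    abs_residual = abs(residual)
    if abs_residual <= delta:
        # Quadratic region: behaves like half the squared error.
        return 0.5 * abs_residual * abs_residual
    # Linear region: grows like the absolute error, offset for continuity.
    return delta * (abs_residual - 0.5 * delta)
```

Small residuals are penalized like MSE and large ones like MAE, which makes the loss less sensitive to outliers than a pure squared error.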
- class metatrain.utils.loss.TensorMapMaskedMSELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: MaskedTensorMapLoss
Masked mean-squared error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedMAELoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: MaskedTensorMapLoss
Masked mean-absolute error on TensorMap entries.
- class metatrain.utils.loss.TensorMapMaskedHuberLoss(name: str, gradient: str | None, weight: float, reduction: str, delta: float)[source]¶
Bases: MaskedTensorMapLoss
Masked Huber loss on TensorMap entries.
- class metatrain.utils.loss.ShiftAgnosticMSE(name: str, gradient: str | None, weight: float, int_weight: float, grad_penalty_weight: float, reduction: str)[source]¶
Bases: LossInterface
Shift-agnostic MSE loss on TensorMap entries.
This loss assumes that the target is some kind of profile along the properties of the TensorBlock. It finds the rigid shift between the predictions and targets that minimizes the MSE, and returns that minimal MSE.
- Parameters:
name (str) – key for the target in the prediction/target dictionary.
gradient (str | None) – optional gradient field name.
weight (float) – weight of the loss contribution in the final aggregation.
int_weight (float) – The loss function can also contain the MSE on the cumulative profile. This number weights the contribution of the cumulative term in the final loss. If 0, no cumulative term is added.
grad_penalty_weight (float) – The loss function penalizes gradients of the predicted profiles in the regions where the target is NaN. This number weights the contribution of the penalty term in the final loss. If 0, the predictions in those regions are unconstrained.
reduction (str) – reduction mode for torch loss.
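The idea of minimizing the MSE over rigid shifts can be sketched as a brute-force search. This is a simplified illustration under stated assumptions, not the library's implementation: it assumes the shift is an integer offset along the profile axis and that the prediction is at least as long as the target, and it ignores the cumulative term (int_weight) and the gradient penalty (grad_penalty_weight). The function name is hypothetical.

```python
from typing import List, Tuple


def shift_agnostic_mse(pred: List[float], target: List[float]) -> Tuple[float, int]:
    """Return (minimal MSE, best shift) over all integer offsets of target in pred."""
    n_shifts = len(pred) - len(target) + 1
    if n_shifts < 1:
        raise ValueError("prediction must be at least as long as the target")

    best_mse, best_shift = float("inf"), 0
    for shift in range(n_shifts):
        # Compare the target against the prediction window starting at `shift`.
        window = pred[shift : shift + len(target)]
        mse = sum((p - t) ** 2 for p, t in zip(window, target)) / len(target)
        if mse < best_mse:
            best_mse, best_shift = mse, shift
    return best_mse, best_shift


# The target profile appears in the prediction two entries to the right.
result = shift_agnostic_mse([0.0, 0.0, 1.0, 2.0, 1.0, 0.0], [1.0, 2.0, 1.0])
```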
- class metatrain.utils.loss.TensorMapEnsembleLoss(name: str, gradient: str | None, weight: float, reduction: str, loss_fn: Module)[source]¶
Bases: BaseTensorMapLoss
Loss for ensembles based on TensorMap entries. Assumes that the ensemble is the outermost dimension of the TensorBlock properties.
- Parameters:
- compute_flattened(pred_mean: TensorMap, target: TensorMap, pred_var: TensorMap) Tensor[source]¶
Flatten the predicted mean, target, and predicted variance blocks, then apply the torch loss.
- class metatrain.utils.loss.GaussianCRPSLoss(reduction: str = 'mean', eps: float = 1e-12)[source]¶
Bases: Module
Gaussian CRPS loss.
This implements the closed-form expression for the CRPS of a Gaussian predictive distribution \(\mathcal{N}(\mu, \sigma^2)\) evaluated at a target value \(x\):
\[\text{CRPS}(x; \mu, \sigma) = \sigma \left[ z(2\Phi(z) - 1) + 2\phi(z) - \frac{1}{\sqrt{\pi}} \right]\]
where \(z = \frac{x - \mu}{\sigma}\), \(\Phi\) is the standard normal CDF, and \(\phi\) is the standard normal PDF.
- Parameters:
Initialize internal Module state, shared by both nn.Module and ScriptModule.
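The closed-form Gaussian CRPS expression given for this class can be checked numerically with a scalar sketch in plain Python. The function name `gaussian_crps` is illustrative, and using eps to clamp sigma away from zero is an assumption about how such a parameter might be applied, not a statement about the library.

```python
import math


def gaussian_crps(x: float, mu: float, sigma: float, eps: float = 1e-12) -> float:
    """CRPS of N(mu, sigma^2) evaluated at the target value x."""
    sigma = max(sigma, eps)  # assumption: eps guards against division by zero
    z = (x - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # standard normal CDF
    return sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


# For x == mu the score reduces to sigma * (2 / sqrt(2 * pi) - 1 / sqrt(pi)).
score_at_mean = gaussian_crps(0.0, 0.0, 1.0)
```

The score is symmetric in the sign of the residual and grows as the target moves away from the predicted mean.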
- class metatrain.utils.loss.EmpiricalCRPSLoss(reduction: str = 'mean')[source]¶
Bases: Module
Empirical CRPS loss for ensemble predictions.
The ensemble predictions \(\{Y_i\}_{i=1}^M\) for each data point define an empirical predictive distribution:
\[F_M(y) = \frac{1}{M} \sum_{i=1}^M \mathbb{1}_{Y_i \le y}\]
The CRPS of this empirical distribution at observation \(z\) has the closed form:
\[\text{CRPS}(F_M, z) = \frac{1}{M} \sum_{i=1}^M |Y_i - z| - \frac{1}{2 M^2} \sum_{i,j} |Y_i - Y_j|\]
- Parameters:
reduction (str) – ‘none’, ‘mean’, or ‘sum’.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
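The empirical CRPS closed form given for this class translates directly into a few lines of plain Python for a single data point. The function name `empirical_crps` is illustrative; the sketch operates on a list of ensemble members rather than on tensors.

```python
from typing import Sequence


def empirical_crps(ensemble: Sequence[float], z: float) -> float:
    """CRPS of the empirical distribution of `ensemble` at observation z."""
    m = len(ensemble)
    # First term: mean absolute error of the members against the observation.
    accuracy = sum(abs(y - z) for y in ensemble) / m
    # Second term: half the mean absolute difference between all member pairs.
    spread = sum(abs(yi - yj) for yi in ensemble for yj in ensemble) / (2 * m * m)
    return accuracy - spread


two_member_score = empirical_crps([0.0, 2.0], 1.0)
```

With a single ensemble member the spread term vanishes and the score reduces to the absolute error.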
- class metatrain.utils.loss.TensorMapGaussianNLLLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: TensorMapEnsembleLoss
Gaussian negative log-likelihood loss for TensorMap entries.
- class metatrain.utils.loss.TensorMapGaussianCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: TensorMapEnsembleLoss
Gaussian CRPS loss for TensorMap entries.
- class metatrain.utils.loss.TensorMapEmpiricalCRPSLoss(name: str, gradient: str | None, weight: float, reduction: str)[source]¶
Bases: TensorMapEnsembleLoss
Empirical CRPS loss for TensorMap entries.
- Parameters:
- class metatrain.utils.loss.LossAggregator(targets: Dict[str, TargetInfo], config: Dict[str, LossSpecification])[source]¶
Bases: LossInterface
Aggregate multiple LossInterface terms with scheduled weights and metadata.
- Parameters:
targets (Dict[str, TargetInfo]) – mapping from target names to TargetInfo.
config (Dict[str, LossSpecification]) – per-target configuration dict.
- class metatrain.utils.loss.LossType(*values)[source]¶
Bases: Enum
Enumeration of available loss types and their implementing classes.
- Parameters:
key – string key for the loss type.
cls – class implementing the loss type.
- MSE = ('mse', <class 'metatrain.utils.loss.TensorMapMSELoss'>)¶
- MAE = ('mae', <class 'metatrain.utils.loss.TensorMapMAELoss'>)¶
- HUBER = ('huber', <class 'metatrain.utils.loss.TensorMapHuberLoss'>)¶
- MASKED_MSE = ('masked_mse', <class 'metatrain.utils.loss.TensorMapMaskedMSELoss'>)¶
- MASKED_MAE = ('masked_mae', <class 'metatrain.utils.loss.TensorMapMaskedMAELoss'>)¶
- MASKED_HUBER = ('masked_huber', <class 'metatrain.utils.loss.TensorMapMaskedHuberLoss'>)¶
- POINTWISE = ('pointwise', <class 'metatrain.utils.loss.BaseTensorMapLoss'>)¶
- MASKED_POINTWISE = ('masked_pointwise', <class 'metatrain.utils.loss.MaskedTensorMapLoss'>)¶
- SHIFT_AGNOSTIC_MSE = ('shift_agnostic_mse', <class 'metatrain.utils.loss.ShiftAgnosticMSE'>)¶
- GAUSSIAN_NLL = ('gaussian_nll_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianNLLLoss'>)¶
- GAUSSIAN_CRPS = ('gaussian_crps_ensemble', <class 'metatrain.utils.loss.TensorMapGaussianCRPSLoss'>)¶
- EMPIRICAL_CRPS = ('empirical_crps_ensemble', <class 'metatrain.utils.loss.TensorMapEmpiricalCRPSLoss'>)¶
- property cls: Type[LossInterface]¶
Class implementing this loss type.
- metatrain.utils.loss.create_loss(loss_type: str, *, name: str, gradient: str | None, weight: float, reduction: str, **extra_kwargs: Any) LossInterface[source]¶
Factory to instantiate a concrete LossInterface given its string key.
- Parameters:
loss_type (str) – string key matching one of the members of LossType.
name (str) – target name for the loss.
gradient (str | None) – gradient name, if present.
weight (float) – weight of the loss contribution in the final aggregation.
reduction (str) – reduction mode for the torch loss.
**extra_kwargs (Any) – additional hyperparameters specific to the loss type.
- Returns:
instance of the selected loss.
- Return type:
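The pattern of an enumeration that pairs a string key with an implementing class, together with a factory that looks the class up by key, can be sketched in standalone Python. Everything below is a toy stand-in for illustration (`ToyLoss`, `ToyLossType`, and `toy_create_loss` are hypothetical names); it mirrors the documented signature of create_loss but makes no claim about how the library implements the lookup.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Type


@dataclass
class ToyLoss:
    # Hypothetical stand-in for a loss class with the documented arguments.
    name: str
    gradient: Optional[str]
    weight: float
    reduction: str


class ToyMSE(ToyLoss):
    pass


class ToyMAE(ToyLoss):
    pass


class ToyLossType(Enum):
    # Each member stores a (key, class) pair, as in the LossType listing.
    MSE = ("mse", ToyMSE)
    MAE = ("mae", ToyMAE)

    @property
    def key(self) -> str:
        return self.value[0]

    @property
    def cls(self) -> Type[ToyLoss]:
        return self.value[1]


def toy_create_loss(
    loss_type: str,
    *,
    name: str,
    gradient: Optional[str],
    weight: float,
    reduction: str,
    **extra_kwargs: Any,
) -> ToyLoss:
    """Instantiate the class whose key matches `loss_type`."""
    for member in ToyLossType:
        if member.key == loss_type:
            return member.cls(
                name=name,
                gradient=gradient,
                weight=weight,
                reduction=reduction,
                **extra_kwargs,
            )
    raise ValueError(f"unknown loss type: {loss_type!r}")


loss = toy_create_loss(
    "mae", name="energy", gradient=None, weight=1.0, reduction="mean"
)
```

Keeping the key and the class together in one enumeration means a new loss type only needs a new member, and the factory itself stays unchanged.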