# Source code for sparkdq.checks.aggregate.count_checks.count_exact_check

from pydantic import Field, model_validator
from pyspark.sql import DataFrame

from sparkdq.core.base_check import BaseAggregateCheck
from sparkdq.core.base_config import BaseAggregateCheckConfig
from sparkdq.core.check_results import AggregateEvaluationResult
from sparkdq.core.severity import Severity
from sparkdq.exceptions import InvalidCheckConfigurationError
from sparkdq.plugin.check_config_registry import register_check_config


class RowCountExactCheck(BaseAggregateCheck):
    """
    Aggregate-level check asserting that a DataFrame holds exactly a given number of rows.

    Intended for strict data-contract validation where the row count is known in
    advance, e.g. reference datasets, test snapshots, or fixed static inputs.

    Attributes:
        expected_count (int): The exact row count the DataFrame must have.
    """

    def __init__(self, check_id: str, expected_count: int, severity: Severity = Severity.CRITICAL):
        """
        Create a RowCountExactCheck.

        Args:
            check_id (str): Unique identifier of this check instance.
            expected_count (int): Row count the evaluated DataFrame must match exactly.
            severity (Severity, optional): Severity attached to a failing result.
                Defaults to Severity.CRITICAL.
        """
        super().__init__(check_id=check_id, severity=severity)
        self.expected_count = expected_count

    def _evaluate_logic(self, df: DataFrame) -> AggregateEvaluationResult:
        """
        Count the rows of ``df`` and compare the result with ``expected_count``.

        The check passes only when the two values are equal. The observed and
        expected counts are both reported in the result metrics, whatever the outcome.

        Args:
            df (DataFrame): Spark DataFrame to inspect.

        Returns:
            AggregateEvaluationResult: Pass/fail outcome plus the
            ``actual_row_count`` and ``expected_row_count`` metrics.
        """
        # Count once and reuse the value for both the verdict and the metrics.
        row_count = df.count()
        target = self.expected_count

        metrics = {
            "actual_row_count": row_count,
            "expected_row_count": target,
        }
        return AggregateEvaluationResult(passed=(row_count == target), metrics=metrics)


@register_check_config(check_name="row-count-exact-check")
class RowCountExactCheckConfig(BaseAggregateCheckConfig):
    """
    Declarative configuration model for the RowCountExactCheck.

    This configuration defines an exact row count requirement for a dataset.
    It ensures that the ``expected_count`` parameter is provided and is non-negative.

    Attributes:
        expected_count (int): The exact number of rows expected in the dataset.
    """

    check_class = RowCountExactCheck
    # Required field; supplied in declarative configs under the key "expected-count".
    expected_count: int = Field(
        ...,
        description="Exact number of rows required",
        alias="expected-count",
    )

    @model_validator(mode="after")
    def validate_expected(self) -> "RowCountExactCheckConfig":
        """
        Validate that the configured ``expected_count`` is zero or positive.

        Zero is deliberately accepted: it lets a configuration require that the
        dataset be empty. Only negative values are rejected.

        Returns:
            RowCountExactCheckConfig: The validated configuration object.

        Raises:
            InvalidCheckConfigurationError: If ``expected_count`` is negative.
        """
        if self.expected_count < 0:
            raise InvalidCheckConfigurationError(
                f"expected-count ({self.expected_count}) must be zero or positive"
            )
        return self