sparkdq.engine

The sparkdq.engine subpackage contains the execution logic for data quality validation.

engine

BatchDQEngine

Bases: BaseDQEngine

Engine for executing data quality checks on Spark DataFrames in batch mode.

This engine applies both row-level and aggregate-level checks using the BatchCheckRunner, and annotates the DataFrame with error metadata.

Source code in sparkdq/engine/batch/dq_engine.py
class BatchDQEngine(BaseDQEngine):
    """
    Engine for executing data quality checks on Spark DataFrames in batch mode.

    This engine applies both row-level and aggregate-level checks using the
    ``BatchCheckRunner``, and annotates the DataFrame with error metadata.
    """

    def run_batch(
        self, df: DataFrame, reference_datasets: Optional[ReferenceDatasetDict] = None
    ) -> BatchValidationResult:
        """
        Run all registered checks against the given DataFrame.

        This method applies both row-level and aggregate-level checks and
        returns a validation result containing the annotated DataFrame and
        the aggregated check results.

        Args:
            df (DataFrame): The input Spark DataFrame to validate.
            reference_datasets (ReferenceDatasetDict, optional):
                A dictionary of named reference DataFrames used by integrity checks.
                Required for checks that compare values against external datasets
                (e.g., foreign key validation). Each key should match the
                `reference_dataset` name expected by the check.

        Returns:
            BatchValidationResult: Object containing the validated DataFrame,
            aggregate check results, and the original input column names.
        """
        if self.check_set is None:
            raise MissingCheckSetError()
        input_columns = df.columns
        runner = BatchCheckRunner(self.fail_levels)
        validated_df, aggregate_results = runner.run(df, self.check_set.get_all(), reference_datasets)

        return BatchValidationResult(validated_df, aggregate_results, input_columns)

run_batch

run_batch(
    df: DataFrame,
    reference_datasets: Optional[
        ReferenceDatasetDict
    ] = None,
) -> BatchValidationResult

Run all registered checks against the given DataFrame.

This method applies both row-level and aggregate-level checks and returns a validation result containing the annotated DataFrame and the aggregated check results.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
df | DataFrame | The input Spark DataFrame to validate. | required
reference_datasets | ReferenceDatasetDict | A dictionary of named reference DataFrames used by integrity checks. Required for checks that compare values against external datasets (e.g., foreign key validation). Each key should match the reference_dataset name expected by the check. | None
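
The name-keyed lookup described for reference_datasets can be illustrated without Spark. This is a minimal sketch, not sparkdq's implementation: each reference dataset is reduced to a plain set of valid keys, and `find_orphans` is a hypothetical helper. Raising `KeyError` for a missing name is an assumption of this sketch; the source does not show how the library reacts.

```python
from typing import Dict, List, Set

# Hypothetical stand-in: each reference dataset is reduced to its set of valid keys.
ReferenceKeys = Dict[str, Set[int]]


def find_orphans(
    rows: List[dict],
    column: str,
    reference_dataset: str,
    reference_datasets: ReferenceKeys,
) -> List[dict]:
    """Return rows whose `column` value is absent from the named reference dataset."""
    if reference_dataset not in reference_datasets:
        # The dictionary key must match the name the check expects (sketch assumption).
        raise KeyError(f"reference dataset {reference_dataset!r} was not provided")
    valid_keys = reference_datasets[reference_dataset]
    return [row for row in rows if row[column] not in valid_keys]


orders = [{"order_id": 1, "customer_id": 10}, {"order_id": 2, "customer_id": 99}]
references = {"customers": {10, 11, 12}}
orphans = find_orphans(orders, "customer_id", "customers", references)
```

Here `orphans` holds only the order whose `customer_id` has no match in the `customers` reference set.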

Returns:

Type | Description
--- | ---
BatchValidationResult | Object containing the validated DataFrame, aggregate check results, and the original input column names.
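
The control flow of run_batch can be traced without Spark. The following is a minimal sketch under stated assumptions, not the library's code: rows are plain dicts, and `SketchEngine`, `SketchResult`, and `MissingCheckSet` are hypothetical stand-ins. It mirrors three steps: the guard for a missing check set, capturing the input column names before annotation, and returning them alongside the annotated rows. Unlike the real engine, it treats every error as failing and ignores severity levels and aggregate checks.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

Row = Dict[str, object]
# Hypothetical check type: returns an error message, or None if the row is valid.
RowCheck = Callable[[Row], Optional[str]]


class MissingCheckSet(Exception):
    """Hypothetical stand-in for the error raised when no checks are registered."""


@dataclass
class SketchResult:
    rows: List[Row]
    input_columns: List[str]


class SketchEngine:
    def __init__(self, checks: Optional[List[RowCheck]] = None) -> None:
        self.checks = checks

    def run_batch(self, rows: List[Row]) -> SketchResult:
        # Guard: refuse to run when no check set was configured.
        if self.checks is None:
            raise MissingCheckSet()
        # Capture the original column names before any metadata is added.
        input_columns = list(rows[0].keys()) if rows else []
        annotated = []
        for row in rows:
            messages = (check(row) for check in self.checks)
            errors = [m for m in messages if m is not None]
            annotated.append({**row, "_dq_errors": errors, "_dq_passed": not errors})
        return SketchResult(annotated, input_columns)


def name_not_null(row: Row) -> Optional[str]:
    return "name is null" if row["name"] is None else None


result = SketchEngine([name_not_null]).run_batch(
    [{"id": 1, "name": "a"}, {"id": 2, "name": None}]
)
```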

BatchValidationResult dataclass

Encapsulates the results of a batch data quality validation run.

Includes:

  • The validated Spark DataFrame, annotated with validation metadata.
  • A list of results from aggregate-level checks.
  • The original input column names, used to restore the pre-validation structure.

This class provides convenience methods to access only the passing or failing rows, making it easier to route or analyze validated data downstream.

Attributes:

Name | Type | Description
--- | --- | ---
df | DataFrame | The DataFrame after validation, including _dq_passed (bool): row-level pass/fail status; _dq_errors (array): structured errors from failed checks; _dq_aggregate_errors (array, optional): errors from failed aggregates.
aggregate_results | List[AggregateCheckResult] | Results of all aggregate checks.
input_columns | List[str] | Names of the original input columns.
timestamp | datetime | Timestamp of when the validation result was created.
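
The timestamp attribute relies on two standard dataclass features: a `default_factory` that runs once per instance, and `frozen=True`, which blocks later assignment. A minimal sketch of that behavior, using a hypothetical `ResultSketch` class that keeps only two of the attributes:

```python
from dataclasses import FrozenInstanceError, dataclass, field
from datetime import datetime
from typing import List


@dataclass(frozen=True)
class ResultSketch:
    input_columns: List[str]
    # default_factory is called once, when the instance is constructed.
    timestamp: datetime = field(default_factory=datetime.now)


result = ResultSketch(["id", "name"])
first_read = result.timestamp
second_read = result.timestamp  # same value: the factory is not re-run on access

try:
    result.timestamp = datetime.now()  # frozen dataclasses reject assignment
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Because the value is fixed at construction, every DataFrame derived from one result that adds a `_dq_validation_ts` column carries the same timestamp.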

Source code in sparkdq/engine/batch/validation_result.py
@dataclass(frozen=True)
class BatchValidationResult:
    """
    Encapsulates the results of a batch data quality validation run.

    Includes:
    - The validated Spark DataFrame, annotated with validation metadata.
    - A list of results from aggregate-level checks.
    - The original input column names, used to restore the pre-validation structure.

    This class provides convenience methods to access only the passing or failing rows,
    making it easier to route or analyze validated data downstream.

    Attributes:
        df (DataFrame): The DataFrame after validation, including:

            - _dq_passed (bool): Row-level pass/fail status.
            - _dq_errors (array): Structured errors from failed checks.
            - _dq_aggregate_errors (array, optional): Errors from failed aggregates.

        aggregate_results (List[AggregateCheckResult]): Results of all aggregate checks.
        input_columns (List[str]): Names of the original input columns.
        timestamp (datetime): Timestamp of when the validation result was created.
    """

    df: DataFrame
    aggregate_results: List[AggregateCheckResult]
    input_columns: List[str]
    timestamp: datetime = field(default_factory=datetime.now)

    def pass_df(self) -> DataFrame:
        """
        Return only the rows that passed all critical checks.

        This method filters for rows where `_dq_passed` is true and restores
        the original column structure from the input DataFrame.

        Returns:
            DataFrame: DataFrame containing only valid rows, with original schema.
        """
        return self.df.filter("_dq_passed").select(*self.input_columns)

    def fail_df(self) -> DataFrame:
        """
        Return only the rows that failed one or more critical checks.

        This includes validation metadata such as ``_dq_errors``, ``_dq_passed``, and,
        if present, ``_dq_aggregate_errors``. Additionally, a ``_dq_validation_ts`` column
        is added for downstream auditing or tracking.

        Returns:
            DataFrame: DataFrame containing invalid rows and relevant error metadata.
        """
        cols = self.input_columns + ["_dq_errors", "_dq_passed"]
        if "_dq_aggregate_errors" in self.df.columns:
            cols.append("_dq_aggregate_errors")
        df = self.df.filter("NOT _dq_passed").select(*cols)
        df = df.withColumn("_dq_validation_ts", lit(self.timestamp))
        return df

    def warn_df(self) -> DataFrame:
        """
        Return rows that passed all critical checks but contain warning-level violations.

        These are rows where ``_dq_passed`` is True, but the ``_dq_errors`` array contains
        at least one entry with severity == **WARNING**.

        Returns:
            DataFrame: Filtered DataFrame of rows with warnings.
        """
        df = self.df.filter(
            col("_dq_passed")
            & (array_size(col("_dq_errors")) > 0)
            & array_contains(expr("transform(_dq_errors, x -> x.severity)"), Severity.WARNING.value)
        ).select(*self.input_columns + ["_dq_errors"])
        df = df.withColumn("_dq_validation_ts", lit(self.timestamp))
        return df

    def summary(self) -> ValidationSummary:
        """
        Create a summary of the validation results.

        This includes total record count, pass/fail statistics,
        and number of rows with warning-level errors.

        Returns:
            ValidationSummary: Structured summary of the validation outcome.
        """
        total = self.df.count()
        passed = self.df.filter("_dq_passed").count()
        failed = total - passed
        warnings = self.df.filter(
            col("_dq_passed")
            & (array_size(col("_dq_errors")) > 0)
            & array_contains(expr("transform(_dq_errors, x -> x.severity)"), Severity.WARNING.value)
        ).count()
        rate = round(passed / total if total else 0.0, 2)
        return ValidationSummary(
            total_records=total,
            passed_records=passed,
            failed_records=failed,
            warning_records=warnings,
            pass_rate=rate,
            timestamp=self.timestamp,
        )

pass_df

pass_df() -> DataFrame

Return only the rows that passed all critical checks.

This method filters for rows where _dq_passed is true and restores the original column structure from the input DataFrame.

Returns:

Type | Description
--- | ---
DataFrame | DataFrame containing only valid rows, with original schema.
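
The filter-then-project behavior of pass_df can be shown on plain dicts. This is a sketch only: `pass_rows` is a hypothetical helper that stands in for the Spark `filter` and `select` calls.

```python
from typing import Dict, List


def pass_rows(rows: List[Dict[str, object]], input_columns: List[str]) -> List[Dict[str, object]]:
    """Keep rows flagged as passed and drop every column that is not an input column."""
    return [
        {column: row[column] for column in input_columns}
        for row in rows
        if row["_dq_passed"]
    ]


annotated = [
    {"id": 1, "name": "a", "_dq_passed": True, "_dq_errors": []},
    {"id": 2, "name": None, "_dq_passed": False, "_dq_errors": [{"check": "null-check"}]},
]
valid = pass_rows(annotated, ["id", "name"])
```

The result contains only the first row, without the `_dq_passed` and `_dq_errors` metadata.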

fail_df

fail_df() -> DataFrame

Return only the rows that failed one or more critical checks.

This includes validation metadata such as _dq_errors, _dq_passed, and, if present, _dq_aggregate_errors. Additionally, a _dq_validation_ts column is added for downstream auditing or tracking.

Returns:

Type | Description
--- | ---
DataFrame | DataFrame containing invalid rows and relevant error metadata.
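
The column selection in fail_df, including the optional `_dq_aggregate_errors` column, can be sketched on plain dicts. `fail_rows` is a hypothetical helper, and checking the first row for the optional column is a simplification of inspecting the DataFrame's column list.

```python
from datetime import datetime
from typing import Dict, List


def fail_rows(
    rows: List[Dict[str, object]], input_columns: List[str], timestamp: datetime
) -> List[Dict[str, object]]:
    """Keep failed rows with their error metadata and stamp them with the validation time."""
    # Concatenation builds a new list, so input_columns itself is never mutated.
    cols = input_columns + ["_dq_errors", "_dq_passed"]
    # Include aggregate errors only if the annotated data carries that column.
    if rows and "_dq_aggregate_errors" in rows[0]:
        cols.append("_dq_aggregate_errors")
    return [
        {**{c: row[c] for c in cols}, "_dq_validation_ts": timestamp}
        for row in rows
        if not row["_dq_passed"]
    ]


columns = ["id"]
ts = datetime(2024, 1, 1)
annotated = [
    {"id": 1, "_dq_errors": [], "_dq_passed": True, "_dq_aggregate_errors": []},
    {"id": 2, "_dq_errors": [{"check": "null-check"}], "_dq_passed": False, "_dq_aggregate_errors": []},
]
failed = fail_rows(annotated, columns, ts)
```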

warn_df

warn_df() -> DataFrame

Return rows that passed all critical checks but contain warning-level violations.

These are rows where _dq_passed is True, but the _dq_errors array contains at least one entry with severity == WARNING.

Returns:

Type | Description
--- | ---
DataFrame | Filtered DataFrame of rows with warnings.
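
The warning filter combines two conditions: the row passed, and at least one of its errors has warning severity. A plain-Python sketch of that predicate follows. `warn_rows` is a hypothetical helper, and the string `"warning"` is an assumed value for `Severity.WARNING.value`; the `_dq_validation_ts` column is omitted for brevity.

```python
from typing import Dict, List

WARNING = "warning"  # assumption: stands in for Severity.WARNING.value


def warn_rows(rows: List[Dict[str, object]], input_columns: List[str]) -> List[Dict[str, object]]:
    """Keep passed rows that carry at least one warning-level error."""
    selected = input_columns + ["_dq_errors"]
    return [
        {c: row[c] for c in selected}
        for row in rows
        # any() is False for an empty list, which covers the non-empty-array condition.
        if row["_dq_passed"] and any(e["severity"] == WARNING for e in row["_dq_errors"])
    ]


annotated = [
    {"id": 1, "_dq_passed": True, "_dq_errors": []},
    {"id": 2, "_dq_passed": True, "_dq_errors": [{"check": "range", "severity": "warning"}]},
    {"id": 3, "_dq_passed": False, "_dq_errors": [{"check": "null", "severity": "critical"},
                                                  {"check": "range", "severity": "warning"}]},
]
warned = warn_rows(annotated, ["id"])
```

Row 3 is excluded even though it has a warning, because it failed a critical check and therefore belongs to the failing rows.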

summary

summary() -> ValidationSummary

Create a summary of the validation results.

This includes total record count, pass/fail statistics, and number of rows with warning-level errors.

Returns:

Type | Description
--- | ---
ValidationSummary | Structured summary of the validation outcome.
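
The arithmetic behind the summary is small enough to show in isolation. In this sketch, `summarize` is a hypothetical helper that takes counts directly instead of counting DataFrame rows. It derives the failed count by subtraction and guards the pass rate against an empty input.

```python
from typing import Dict, Union

Number = Union[int, float]


def summarize(total: int, passed: int, warnings: int) -> Dict[str, Number]:
    """Derive failed count and pass rate from raw counts."""
    failed = total - passed
    # An empty input yields a pass rate of 0.0 instead of a ZeroDivisionError.
    rate = round(passed / total if total else 0.0, 2)
    return {
        "total_records": total,
        "passed_records": passed,
        "failed_records": failed,
        "warning_records": warnings,
        "pass_rate": rate,
    }


two_of_three = summarize(total=3, passed=2, warnings=1)
empty = summarize(total=0, passed=0, warnings=0)
```

Warning rows are counted among the passed rows, so `warning_records` never adds to `failed_records`.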
