Source code for sparkdq.checks.aggregate.uniqueness_checks.unique_rows_check
from typing import List, Optional

from pydantic import Field
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

from sparkdq.core.base_check import BaseAggregateCheck
from sparkdq.core.base_config import BaseAggregateCheckConfig
from sparkdq.core.check_results import AggregateEvaluationResult
from sparkdq.core.severity import Severity
from sparkdq.exceptions import MissingColumnError
from sparkdq.plugin.check_config_registry import register_check_config


class UniqueRowsCheck(BaseAggregateCheck):
    """
    Aggregate-level data quality check that ensures row-level uniqueness in the DataFrame.

    A row is considered non-unique if the same combination of values appears more than once,
    either across all columns or a user-defined subset.
    """
    def __init__(
        self,
        check_id: str,
        subset_columns: Optional[List[str]] = None,
        severity: Severity = Severity.CRITICAL,
    ):
        """
        Initialize a UniqueRowsCheck instance.

        Args:
            check_id (str): Unique identifier for the check.
            subset_columns (Optional[List[str]]): List of columns to define uniqueness.
                If None, all columns are used.
            severity (Severity): Severity level assigned if the check fails.
        """
        super().__init__(check_id=check_id, severity=severity)
        self.subset_columns = subset_columns

    def _evaluate_logic(self, df: DataFrame) -> AggregateEvaluationResult:
        """
        Evaluate the uniqueness of rows in the DataFrame.

        Args:
            df (DataFrame): The DataFrame to validate.

        Returns:
            AggregateEvaluationResult: Indicates whether any duplicated row combinations exist,
            and reports how many were found.
        """
        cols = self.subset_columns or df.columns
        for col in cols:
            if col not in df.columns:
                raise MissingColumnError(col, df.columns)
        duplicate_groups = df.groupBy(cols).count().filter(F.col("count") > 1)
        duplicate_count = duplicate_groups.count()
        passed = duplicate_count == 0
        return AggregateEvaluationResult(
            passed=passed,
            metrics={
                "duplicate_row_groups": duplicate_count,
                "checked_columns": cols,
            },
        )
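
# Usage sketch (illustrative; not part of the module). The DataFrame contents
# and the check_id below are made up for demonstration, and _evaluate_logic is
# called directly only to show the result shape; in normal use the framework
# drives evaluation through the check's public lifecycle.
#
#     from pyspark.sql import SparkSession
#
#     spark = SparkSession.builder.master("local[1]").getOrCreate()
#     df = spark.createDataFrame([(1, "a"), (1, "a"), (2, "b")], ["id", "value"])
#     check = UniqueRowsCheck(check_id="demo-unique-rows", subset_columns=["id"])
#     result = check._evaluate_logic(df)
#     result.passed   # False: the value id=1 occurs in two rows
#     result.metrics  # {"duplicate_row_groups": 1, "checked_columns": ["id"]}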


@register_check_config(check_name="unique-rows-check")
class UniqueRowsCheckConfig(BaseAggregateCheckConfig):
    """
    Declarative configuration for the UniqueRowsCheck.

    This check verifies that no duplicate row combinations exist in the dataset.
    Uniqueness can be enforced across all columns or a selected subset.

    Attributes:
        subset_columns (Optional[List[str]]): List of columns to define uniqueness.
            If not provided, all columns are used.
    """
    check_class = UniqueRowsCheck
    subset_columns: Optional[List[str]] = Field(
        default=None,
        alias="subset-columns",
        description="List of columns used to determine row uniqueness. Defaults to all columns.",
    )
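
# Declarative configuration sketch (illustrative; not part of the module).
# Only the registered name "unique-rows-check" and the "subset-columns" alias
# come from this file; the surrounding keys (check-id, severity) are assumed
# to be provided by BaseAggregateCheckConfig, mirroring the UniqueRowsCheck
# constructor. A YAML entry for a check suite might then look like:
#
#     - check: unique-rows-check
#       check-id: orders-unique
#       severity: critical
#       subset-columns:
#         - order_id
#         - customer_id
#
# If subset-columns is omitted, uniqueness is enforced across all columns.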