import sys
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
import pandas as pd
if TYPE_CHECKING:
import narwhals as nw # noqa: F401
import pandas as pd # noqa: F401
import polars as pl # noqa: F401
from .backends.base import BackendAdapter
from .exceptions import SchemaError
from .typing import Col
# Type variable bound to ``BaseSchema``. The bound is given as a string
# (forward reference) because the class is defined further down this module.
TStructFrame = TypeVar("TStructFrame", bound="BaseSchema")
class FieldInfo:
    """Metadata holder for one schema column: its alias and validation rules.

    Attributes:
        alias: Name of the DataFrame column this attribute maps to, when it
            differs from the attribute name.
        ge: Inclusive lower bound for values.
        gt: Exclusive lower bound for values.
        le: Inclusive upper bound for values.
        lt: Exclusive upper bound for values.
        isin: Allowed values; every value must be one of these.
        regex: Regex pattern that string values must match.
        min_length: Minimum length of string values.
        max_length: Maximum length of string values.
        nullable: Whether NaN/None values are allowed (default True).
        unique: Whether all values must be unique (default False).
    """

    def __init__(
        self,
        alias: Optional[str] = None,
        ge: Optional[float] = None,
        gt: Optional[float] = None,
        le: Optional[float] = None,
        lt: Optional[float] = None,
        isin: Optional[List[Any]] = None,
        regex: Optional[str] = None,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        nullable: bool = True,
        unique: bool = False,
    ):
        # Column mapping
        self.alias = alias
        # Numeric bounds
        self.ge, self.gt = ge, gt
        self.le, self.lt = le, lt
        # Membership and string constraints
        self.isin = isin
        self.regex = regex
        self.min_length, self.max_length = min_length, max_length
        # Column-wide flags
        self.nullable = nullable
        self.unique = unique

    def __repr__(self) -> str:
        """Render as ``Field(...)``, listing only values that differ from defaults."""
        ordered_names = (
            "alias",
            "ge",
            "gt",
            "le",
            "lt",
            "isin",
            "regex",
            "min_length",
            "max_length",
            "nullable",
            "unique",
        )
        # Boolean flags are hidden only when they are *exactly* their default
        # singleton (identity check, mirroring ``is True`` / ``is False``).
        flag_defaults = {"nullable": True, "unique": False}

        rendered = []
        for name in ordered_names:
            value = getattr(self, name)
            if value is None:
                continue
            if name in flag_defaults and value is flag_defaults[name]:
                continue
            rendered.append(f"{name}={value!r}")

        if not rendered:
            return "Field()"
        return f"Field({', '.join(rendered)})"
def Field(  # noqa: N802 (function name should be lowercase)
    alias: Optional[str] = None,
    ge: Optional[float] = None,
    gt: Optional[float] = None,
    le: Optional[float] = None,
    lt: Optional[float] = None,
    isin: Optional[List[Any]] = None,
    regex: Optional[str] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    nullable: bool = True,
    unique: bool = False,
) -> Any:
    """Declare a schema field's column alias and validation constraints.

    Args:
        alias: Map this attribute to a differently-named DataFrame column.
        ge: Value must be >= this threshold.
        gt: Value must be > this threshold.
        le: Value must be <= this threshold.
        lt: Value must be < this threshold.
        isin: Value must be one of these allowed values.
        regex: String value must match this regex pattern.
        min_length: String value must be at least this long.
        max_length: String value must be at most this long.
        nullable: Whether NaN/None values are allowed (default True).
        unique: Whether all values must be unique (default False).

    Returns:
        A FieldInfo metadata object consumed by Schema during class creation.
    """
    # Gather the settings by name first so the forwarding call stays a one-liner.
    settings = {
        "alias": alias,
        "ge": ge,
        "gt": gt,
        "le": le,
        "lt": lt,
        "isin": isin,
        "regex": regex,
        "min_length": min_length,
        "max_length": max_length,
        "nullable": nullable,
        "unique": unique,
    }
    return FieldInfo(**settings)
class BaseSchema:
    """Base class for the Object-DataFrame Mapper (ODM).

    Define your DataFrame schema as a Python class with typed attributes.
    Schema validates column existence, runtime dtypes, and field-level
    constraints, while providing IDE-friendly autocomplete and type safety.

    **Do not instantiate BaseSchema directly.** Use backend-specific Schema
    classes::

        from frameright.pandas import Schema  # For pandas
        from frameright.polars.eager import Schema  # For polars eager
        from frameright.polars.lazy import Schema  # For polars lazy

    Example with pandas::

        import pandas as pd
        from frameright.pandas import Schema, Col

        class Orders(Schema):
            order_id: Col[int]
            revenue: Col[float]

        orders = Orders(pd.DataFrame(...))
        orders.revenue  # Returns pd.Series
        orders.revenue.sum()  # Use pandas methods

    Example with polars::

        import polars as pl
        from frameright.polars.eager import Schema, Col

        class Orders(Schema):
            order_id: Col[int]
            revenue: Col[float]

        orders = Orders(pl.DataFrame(...))
        orders.revenue  # Returns pl.Series
        orders.revenue.sum()  # Use polars methods

    Use ``fr_data`` to access the underlying DataFrame.
    """

    # Stores the parsed schema for the specific child class
    _fr_schema: Dict[str, Dict[str, Any]]
    _fr_backend: BackendAdapter  # Set by concrete subclasses (must be non-None)

    def __init__(
        self,
        df: Any,
        copy: bool = False,
        validate: bool = True,
        validate_types: bool = True,
        coerce: bool = False,
        coerce_errors: str = "raise",
        strict: bool = False,
    ):
        """Initialise the Schema wrapper.

        Args:
            df: The DataFrame to wrap.
            copy: If True, copy the DataFrame. Defaults to False to save memory.
            validate: If True, run schema validation on construction.
                Defaults to True.
            validate_types: If True, also check runtime dtypes during validation.
                Only used when ``validate`` is True.
            coerce: If True, attempt to convert DataFrame columns to match the
                schema's type annotations before validation. Useful for data
                from sources that don't preserve dtypes (e.g., CSV files).
                Defaults to False.
            coerce_errors: How to handle coercion errors when ``coerce`` is True.
                'raise' (default), 'coerce' (set failures to NaN), or 'ignore'.
                The value is forwarded unchanged to the backend.
            strict: If True, reject DataFrames with columns not defined in the
                schema. Defaults to False (extra columns are allowed).

        Raises:
            RuntimeError: If the concrete class has not set ``_fr_backend``.
        """
        # Concrete subclasses must set _fr_backend at class level
        if not hasattr(self.__class__, "_fr_backend") or self._fr_backend is None:
            raise RuntimeError(
                f"{self.__class__.__name__} must set _fr_backend. "
                "Use a backend-specific class like frameright.pandas.Schema"
            )

        self._fr_df = self._fr_backend.copy(df) if copy else df

        # Apply type coercion if requested
        if coerce:
            for meta in self.__class__._fr_schema.values():
                col = meta["df_col"]
                inner_type = meta["inner_type"]
                # Nothing to coerce when the column is absent or untyped.
                if inner_type is None or not self._fr_backend.has_column(
                    self._fr_df, col
                ):
                    continue
                self._fr_df = self._fr_backend.coerce_column(
                    self._fr_df,
                    col,
                    inner_type,
                    errors=coerce_errors,
                    nullable=meta["field_info"].nullable,
                )

        if validate:
            self.fr_validate(check_types=validate_types, strict=strict)

    @staticmethod
    def _fr_is_union(tp: Any) -> bool:
        """Return True if ``tp`` is a ``typing.Union`` or a PEP 604 ``X | Y`` union."""
        import types  # function-scope import: only needed for UnionType lookup

        origin = get_origin(tp)
        if origin is Union:
            return True
        # ``types.UnionType`` only exists on Python 3.10+; on those versions
        # ``get_origin(str | None)`` may report it rather than ``typing.Union``.
        pep604_union = getattr(types, "UnionType", None)
        return pep604_union is not None and origin is pep604_union

    @staticmethod
    def _fr_make_property(col_name: str, optional_flag: bool) -> property:
        """Build the property that reads/writes column ``col_name`` via the backend."""

        def getter(self: "BaseSchema") -> Any:
            backend = self._fr_backend
            frame = self._fr_df
            # Optional columns that are absent read as None.
            if optional_flag and not backend.has_column(frame, col_name):
                return None
            # For LazyFrames, use get_column_ref() to return expressions (pl.Expr)
            # For eager DataFrames, use get_column() to return Series
            if (
                hasattr(backend, "get_column_ref")
                and type(frame).__name__ == "LazyFrame"
            ):
                return backend.get_column_ref(frame, col_name)
            # Return native Series directly (pd.Series, pl.Series, or nw.Series)
            return backend.get_column(frame, col_name)

        def setter(self: "BaseSchema", value: Any) -> None:
            # The backend returns the (possibly new) frame; keep that one.
            self._fr_df = self._fr_backend.set_column(self._fr_df, col_name, value)

        return property(getter, setter)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Parse the subclass's annotations into ``_fr_schema`` and inject properties.

        Runs once per subclass at class-creation time. Every public annotation
        must be ``Col[T]`` or ``Optional[Col[T]]``; each one is recorded in
        ``cls._fr_schema`` and replaced on the class by a property that
        delegates to the backend.

        Raises:
            SchemaError: If a public annotation is neither ``Col[T]`` nor
                ``Optional[Col[T]]``.
        """
        super().__init_subclass__(**kwargs)
        cls._fr_schema = {}

        # Resolve type hints with Col injected into the namespace so
        # that ``from __future__ import annotations`` and TYPE_CHECKING-guarded
        # imports both work without NameError at runtime.
        module = sys.modules.get(cls.__module__)
        globalns = dict(vars(module)) if module else {}
        localns: Dict[str, Any] = {"Col": Col}
        hints = get_type_hints(cls, globalns=globalns, localns=localns)

        for attr_name, attr_type in hints.items():
            if attr_name.startswith("_"):
                continue

            # 1. Parse Type Hints (Handle Col[T] and Optional[Col[T]])
            args = get_args(attr_type)
            is_optional = cls._fr_is_union(attr_type) and type(None) in args

            # Extract the actual Col[T] if it was wrapped in Optional
            col_type = (
                next((a for a in args if get_origin(a) is Col), attr_type)
                if is_optional
                else attr_type
            )

            # Validate that the annotation is Col[T] or Optional[Col[T]]
            if get_origin(col_type) is not Col:
                raise SchemaError(
                    f"Attribute '{attr_name}' in {cls.__name__} must be annotated as "
                    f"Col[T] or Optional[Col[T]], got {attr_type}"
                )

            # Extract the inner primitive type (e.g., float from Col[float])
            col_args = get_args(col_type)
            inner_type = col_args[0] if col_args else None

            # Handle Union types inside Col (e.g., Col[str | None] -> str).
            # Both typing.Union and PEP 604 unions are unwrapped to the first
            # non-None member.
            if inner_type is not None and cls._fr_is_union(inner_type):
                inner_type = next(
                    (t for t in get_args(inner_type) if t is not type(None)),
                    None,
                )

            # 2. Parse Field Metadata (Alias and Validation constraints)
            class_var = getattr(cls, attr_name, None)
            if isinstance(class_var, FieldInfo):
                field_info = class_var
            else:
                # Check parent classes for inherited FieldInfo. On a parent the
                # class attribute has already been replaced by a property, so
                # the parent's parsed schema is the place to look.
                field_info = None
                for base in cls.__mro__[1:]:
                    base_schema = getattr(base, "_fr_schema", {})
                    if attr_name in base_schema:
                        field_info = base_schema[attr_name]["field_info"]
                        break
                if field_info is None:
                    field_info = FieldInfo()
            actual_df_col = field_info.alias or attr_name

            # Store the parsed schema for validation later
            cls._fr_schema[attr_name] = {
                "df_col": actual_df_col,
                "inner_type": inner_type,
                "field_info": field_info,
                "is_optional": is_optional,
            }

            # 3. Inject the safe Property wrapper
            setattr(
                cls, attr_name, cls._fr_make_property(actual_df_col, is_optional)
            )

    # ------------------------------------------------------------------
    # Core Methods (Prefixed with fr_ to avoid namespace collisions)
    # ------------------------------------------------------------------
    def fr_validate(self, check_types: bool = True, strict: bool = False) -> Self:
        """Validate column existence, runtime dtypes, and field-level constraints.

        Uses Pandera for validation, with errors translated into Schema
        exception types (MissingColumnError, TypeMismatchError,
        ConstraintViolationError).

        Args:
            check_types: If True, also validate that column dtypes match the
                type annotations. Defaults to True.
            strict: If True, reject DataFrames with columns not defined in the
                schema. Defaults to False (extra columns are allowed).

        Returns:
            self, for method chaining.

        Raises:
            MissingColumnError: If a required column is not present.
            TypeMismatchError: If a column's dtype doesn't match the annotation.
            ConstraintViolationError: If a field-level constraint is violated.
        """
        schema = self._fr_backend.build_pandera_schema(
            self.__class__._fr_schema,
            self._fr_df,
            check_types=check_types,
            strict=strict,
        )
        self._fr_backend.validate_with_pandera(self._fr_df, schema, lazy=True)
        return self

    @property
    def fr_data(self) -> Any:
        """Return the underlying DataFrame.

        For pandas backend, returns ``pd.DataFrame``.
        For polars backend, returns ``pl.DataFrame`` or ``pl.LazyFrame``.
        For narwhals backend, returns ``nw.DataFrame``.

        This property gives direct access to the DataFrame for performing
        operations using the backend's native API::

            # Pandas operations
            df.fr_data.groupby('column').sum()

            # Polars operations
            df.fr_data.filter(pl.col('x') > 5)

            # LazyFrame operations
            lazy_df.fr_data.collect()
        """
        return self._fr_df

    def __len__(self) -> int:
        """Return the number of rows in the DataFrame, as reported by the backend.

        For LazyFrames, this will trigger execution (collect) to get the row
        count. To avoid executing the query repeatedly, collect once via
        ``fr_data.collect()`` and reuse the resulting DataFrame.
        """
        return self._fr_backend.num_rows(self._fr_df)

    # ------------------------------------------------------------------
    # Schema Introspection
    # ------------------------------------------------------------------
    @classmethod
    def fr_schema_info(cls) -> List[Dict[str, Any]]:
        """Return the schema definition as a list of dictionaries.

        Returns:
            A list of dicts, one per column, with keys: attribute, column,
            type, required, nullable, unique, constraints. ``type`` is the
            inner type's ``__name__`` (or its ``str()`` form when it has no
            ``__name__``), and ``"Any"`` when no inner type was parsed.
            ``constraints`` is None when no constraint is set.
        """
        constraint_keys = (
            "ge",
            "gt",
            "le",
            "lt",
            "isin",
            "regex",
            "min_length",
            "max_length",
        )
        rows: List[Dict[str, Any]] = []
        for attr_name, meta in cls._fr_schema.items():
            fi: FieldInfo = meta["field_info"]
            inner = meta["inner_type"]

            constraints: Dict[str, Any] = {}
            for key in constraint_keys:
                val = getattr(fi, key, None)
                if val is not None:
                    constraints[key] = val

            # Not every type object exposes ``__name__`` (e.g. some typing
            # constructs), so fall back to its string form instead of raising.
            type_name = (
                "Any" if inner is None else getattr(inner, "__name__", str(inner))
            )

            rows.append(
                {
                    "attribute": attr_name,
                    "column": meta["df_col"],
                    "type": type_name,
                    "required": not meta["is_optional"],
                    "nullable": fi.nullable,
                    "unique": fi.unique,
                    "constraints": constraints or None,
                }
            )
        return rows

    # ------------------------------------------------------------------
    # Python Protocols
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        """Summarise class, backend, shape and required/optional counts, then a preview."""
        schema = self.__class__._fr_schema
        opt = sum(1 for m in schema.values() if m["is_optional"])
        req = len(schema) - opt
        head = self._fr_backend.head(self._fr_df)
        return (
            f"<{self.__class__.__name__} [{self._fr_backend.name}]: "
            f"{len(self)} rows x "
            f"{self._fr_backend.num_cols(self._fr_df)} cols "
            f"({req} required, {opt} optional)>\n"
            f"{head}"
        )

    # NOTE: defining __eq__ without __hash__ makes instances unhashable
    # (Python sets __hash__ to None implicitly).
    def __eq__(self, other: object) -> bool:
        """Check equality with another Schema of the same type."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._fr_backend.equals(self._fr_df, other._fr_df)