Skip to content

Commit

Permalink
Improved data quality handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 19, 2024
1 parent 81af2e8 commit 7ddf763
Show file tree
Hide file tree
Showing 15 changed files with 309 additions and 126 deletions.
32 changes: 31 additions & 1 deletion aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,10 @@ def fill_na(self: T, value: FeatureFactory | Any) -> T:
from aligned.compiler.transformation_factory import FillMissingFactory

instance: FeatureFactory = self.copy_type() # type: ignore [attr-defined]

if instance.constraints:
instance.constraints.remove(Optional())

if not isinstance(value, FeatureFactory):
value = LiteralValue.from_value(value)

Expand Down Expand Up @@ -886,11 +890,15 @@ def dtype(self) -> FeatureType:
return FeatureType.bool()

def copy_type(self) -> Bool:
    """Return a fresh Bool factory, carrying over an Optional constraint if present."""
    fresh = Bool()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh


class Float(ArithmeticFeature, DecimalOperations):
def copy_type(self) -> Float:
    """Return a fresh Float factory, carrying over an Optional constraint if present."""
    fresh = Float()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand All @@ -903,6 +911,8 @@ def aggregate(self) -> ArithmeticAggregation:

class Int8(ArithmeticFeature, CouldBeEntityFeature, CouldBeModelVersion, CategoricalEncodableFeature):
def copy_type(self) -> Int8:
    """Return a fresh Int8 factory, carrying over an Optional constraint if present."""
    fresh = Int8()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand All @@ -915,6 +925,8 @@ def aggregate(self) -> ArithmeticAggregation:

class Int16(ArithmeticFeature, CouldBeEntityFeature, CouldBeModelVersion, CategoricalEncodableFeature):
def copy_type(self) -> Int16:
    """Return a fresh Int16 factory, carrying over an Optional constraint if present."""
    fresh = Int16()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand All @@ -927,6 +939,8 @@ def aggregate(self) -> ArithmeticAggregation:

class Int32(ArithmeticFeature, CouldBeEntityFeature, CouldBeModelVersion, CategoricalEncodableFeature):
def copy_type(self) -> Int32:
    """Return a fresh Int32 factory, carrying over an Optional constraint if present."""
    fresh = Int32()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand All @@ -939,6 +953,8 @@ def aggregate(self) -> ArithmeticAggregation:

class Int64(ArithmeticFeature, CouldBeEntityFeature, CouldBeModelVersion, CategoricalEncodableFeature):
def copy_type(self) -> Int64:
    """Return a fresh Int64 factory, carrying over an Optional constraint if present."""
    fresh = Int64()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand All @@ -951,6 +967,8 @@ def aggregate(self) -> ArithmeticAggregation:

class UUID(FeatureFactory, CouldBeEntityFeature):
def copy_type(self) -> UUID:
    """Return a fresh UUID factory, carrying over an Optional constraint if present."""
    fresh = UUID()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand Down Expand Up @@ -1000,6 +1018,8 @@ class String(
StringValidatable,
):
def copy_type(self) -> String:
    """Return a fresh String factory, carrying over an Optional constraint if present."""
    fresh = String()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand Down Expand Up @@ -1073,7 +1093,9 @@ def as_image_url(self) -> ImageUrl:

class Json(FeatureFactory):
def copy_type(self: Json) -> Json:
    """Return a fresh Json factory, carrying over an Optional constraint if present.

    The stale ``return super().copy_type()`` that previously preceded this
    logic made the constraint check unreachable; it has been removed so the
    method matches the other ``copy_type`` implementations in this module.
    """
    if self.constraints and Optional() in self.constraints:
        return Json().is_optional()
    return Json()

@property
def dtype(self) -> FeatureType:
Expand Down Expand Up @@ -1168,6 +1190,8 @@ class Embedding(FeatureFactory):
indexes: list[VectorIndexFactory] | None = None

def copy_type(self) -> Embedding:
    """Return a fresh Embedding factory, carrying over an Optional constraint if present."""
    fresh = Embedding()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand Down Expand Up @@ -1205,6 +1229,8 @@ class List(FeatureFactory, Generic[GenericFeature]):
sub_type: GenericFeature

def copy_type(self) -> List:
    """Return a fresh List with a copied sub type, carrying over an Optional constraint."""
    fresh = List(self.sub_type.copy_type())
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

@property
Expand Down Expand Up @@ -1232,6 +1258,8 @@ def dtype(self) -> FeatureType:
return FeatureType.string()

def copy_type(self) -> ImageUrl:
    """Return a fresh ImageUrl factory, carrying over an Optional constraint if present."""
    fresh = ImageUrl()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

def load_image(self) -> Image:
Expand All @@ -1248,6 +1276,8 @@ def dtype(self) -> FeatureType:
return FeatureType.array()

def copy_type(self) -> Image:
    """Return a fresh Image factory, carrying over an Optional constraint if present."""
    fresh = Image()
    if self.constraints and Optional() in self.constraints:
        fresh = fresh.is_optional()
    return fresh

def to_grayscale(self) -> Image:
Expand Down
2 changes: 1 addition & 1 deletion aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def execute_sql(self, query: str) -> RetrivalJob:

raise ValueError(f"Unable to find table `{column.table}` for query `{query}`")

all_features = set()
all_features: set[str] = set()

for table, columns in table_columns.items():
all_features.update(f'{table}:{column}' for column in columns)
Expand Down
134 changes: 134 additions & 0 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ def join_asof(
timestamp_unit=timestamp_unit,
)

def return_invalid(self, should_return_validation: bool | None = None) -> RetrivalJob:
    """Wrap this job so only constraint-violating rows are returned.

    Args:
        should_return_validation: when True, also append the per-check boolean
            columns to the output. ``None`` is treated the same as ``False``.
    """
    return ReturnInvalidJob(self, bool(should_return_validation))

def join(
self, job: RetrivalJob, method: str, left_on: str | list[str], right_on: str | list[str]
) -> RetrivalJob:
Expand Down Expand Up @@ -767,6 +772,135 @@ def copy_with(self: JobType, job: RetrivalJob) -> JobType:
return self


def polars_filter_expressions_from(features: list[Feature]) -> list[tuple[pl.Expr, str]]:
    """Translate feature constraints into polars validation expressions.

    Returns one ``(expression, description)`` pair per check, where the
    expression evaluates to True for rows that satisfy the check.

    A feature with no ``Optional`` constraint (or no constraints at all) gets
    only a not-null check — its other constraints are not evaluated, mirroring
    the original control flow. Constraint types outside the supported set are
    silently skipped.
    """
    from aligned.schemas.constraints import (
        Optional,
        LowerBound,
        UpperBound,
        InDomain,
        MinLength,
        MaxLength,
        EndsWith,
        StartsWith,
        LowerBoundInclusive,
        UpperBoundInclusive,
        Regex,
    )

    # Ordered (constraint type, builder) pairs. The first isinstance match
    # wins, which reproduces the original if/elif chain exactly.
    builders = [
        (
            LowerBound,
            lambda name, con: (pl.col(name) > con.value, f"LowerBound {name} {con.value}"),
        ),
        (
            LowerBoundInclusive,
            lambda name, con: (pl.col(name) >= con.value, f"LowerBoundInclusive {name} {con.value}"),
        ),
        (
            UpperBound,
            lambda name, con: (pl.col(name) < con.value, f"UpperBound {name} {con.value}"),
        ),
        (
            UpperBoundInclusive,
            lambda name, con: (pl.col(name) <= con.value, f"UpperBoundInclusive {name} {con.value}"),
        ),
        (
            InDomain,
            lambda name, con: (pl.col(name).is_in(con.values), f"InDomain {name} {con.values}"),
        ),
        (
            MinLength,
            lambda name, con: (pl.col(name).str.lengths() > con.value, f"MinLength {name} {con.value}"),
        ),
        (
            MaxLength,
            lambda name, con: (pl.col(name).str.lengths() < con.value, f"MaxLength {name} {con.value}"),
        ),
        (
            EndsWith,
            lambda name, con: (pl.col(name).str.ends_with(con.value), f"EndsWith {name} {con.value}"),
        ),
        (
            StartsWith,
            lambda name, con: (pl.col(name).str.starts_with(con.value), f"StartsWith {name} {con.value}"),
        ),
        (
            Regex,
            lambda name, con: (pl.col(name).str.contains(con.value), f"Regex {name} {con.value}"),
        ),
    ]

    optional_constraint = Optional()
    checks: list[tuple[pl.Expr, str]] = []

    for feature in features:
        # Not optional => the value is simply required to be present.
        if not feature.constraints or optional_constraint not in feature.constraints:
            checks.append((pl.col(feature.name).is_not_null(), f"Required {feature.name}"))
            continue

        for constraint in feature.constraints:
            for constraint_type, build in builders:
                if isinstance(constraint, constraint_type):
                    checks.append(build(feature.name, constraint))
                    break

    return checks


@dataclass
class ReturnInvalidJob(RetrivalJob, ModificationJob):
    """Job wrapper that keeps only the rows violating at least one constraint.

    When ``should_return_validation`` is True, one boolean column per check
    (named ``"not <check>"``) is appended to the output so callers can see
    which constraint each row failed.
    """

    job: RetrivalJob
    should_return_validation: bool

    def _negated_expressions(self) -> list[tuple[pl.Expr, str]]:
        # One (expression, column name) pair per constraint check; the
        # expression is True when the row VIOLATES the check. Shared by
        # describe() and to_lazy_polars() so the two cannot drift apart.
        return [
            (expr.is_not().alias(f"not {name}"), f"not {name}")
            for expr, name in polars_filter_expressions_from(list(self.request_result.features))
        ]

    def describe(self) -> str:
        expressions = [expr for expr, _ in self._negated_expressions()]

        return 'ReturnInvalidJob ' + self.job.describe() + ' with filter expressions ' + str(expressions)

    async def to_lazy_polars(self) -> pl.LazyFrame:
        pairs = self._negated_expressions()
        expressions = [expr for expr, _ in pairs]

        if self.should_return_validation:
            # Materialise the per-check boolean columns, then keep any row
            # where at least one check failed.
            condition_cols = [name for _, name in pairs]
            return (
                (await self.job.to_lazy_polars())
                .with_columns(expressions)
                .filter(pl.any_horizontal(*condition_cols))
            )
        else:
            return (await self.job.to_lazy_polars()).filter(pl.any_horizontal(expressions))

    async def to_pandas(self) -> pd.DataFrame:
        return (await self.to_lazy_polars()).collect().to_pandas()


@dataclass
class CustomPolarsJob(RetrivalJob, ModificationJob):

Expand Down
19 changes: 19 additions & 0 deletions aligned/validation/tests/test_pandera_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,22 @@ async def test_validate_invalid_feature_view(titanic_feature_store: FeatureStore
)

assert validated_df.shape[0] == 16


@pytest.mark.asyncio
async def test_return_invalid_rows(titanic_feature_store: FeatureStore) -> None:
    """return_invalid keeps only constraint-violating rows.

    Of the first 20 titanic rows, 4 are invalid. Without validation columns
    the original 11 columns are kept; with ``should_return_validation=True``
    extra check columns bring the width to 20.
    """
    invalid_df = await (
        titanic_feature_store.feature_view('titanic').all(limit=20).return_invalid().to_pandas()
    )
    assert invalid_df.shape == (4, 11)

    with_checks = await (
        titanic_feature_store.feature_view('titanic')
        .all(limit=20)
        .return_invalid(should_return_validation=True)
        .to_polars()
    )
    assert with_checks.shape == (4, 20)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.82"
version = "0.0.83"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <[email protected]>"]
license = "Apache-2.0"
Expand Down
14 changes: 7 additions & 7 deletions test_data/credit_history.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
student_loan_due,credit_card_due,bankruptcies,event_timestamp,due_sum,dob_ssn
22328,8419,0,2020-04-26 18:01:04.746575+00:00,30747,19530219_5179
2515,2944,0,2020-04-26 18:01:04.746575+00:00,5459,19520816_8737
33000,833,0,2020-04-26 18:01:04.746575+00:00,33833,19860413_2537
48955,5936,0,2020-04-27 18:01:04.746575+00:00,54891,19530219_5179
9501,1575,0,2020-04-27 18:01:04.746575+00:00,11076,19520816_8737
35510,6263,0,2020-04-27 18:01:04.746575+00:00,41773,19860413_2537
dob_ssn,event_timestamp,due_sum,credit_card_due,student_loan_due,bankruptcies
19530219_5179,2020-04-26 18:01:04.746575+00:00,30747,8419,22328,0
19520816_8737,2020-04-26 18:01:04.746575+00:00,5459,2944,2515,0
19860413_2537,2020-04-26 18:01:04.746575+00:00,33833,833,33000,0
19530219_5179,2020-04-27 18:01:04.746575+00:00,54891,5936,48955,0
19520816_8737,2020-04-27 18:01:04.746575+00:00,11076,1575,9501,0
19860413_2537,2020-04-27 18:01:04.746575+00:00,41773,6263,35510,0
Binary file modified test_data/credit_history_mater.parquet
Binary file not shown.
Loading

0 comments on commit 7ddf763

Please sign in to comment.