Filling missing columns in more sources
MatsMoll committed Mar 12, 2024
1 parent da9757c commit 15ef90d
Showing 17 changed files with 217 additions and 187 deletions.
12 changes: 9 additions & 3 deletions aligned/data_source/batch_data_source.py
@@ -301,7 +301,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
 
         return CustomLazyPolarsJob(
             request=request, method=lambda: dill.loads(self.all_data_method)(request, limit)
-        )
+        ).fill_missing_columns()
 
     def all_between_dates(
         self, request: RetrivalRequest, start_date: datetime, end_date: datetime
@@ -312,15 +312,15 @@ def all_between_dates(
         return CustomLazyPolarsJob(
             request=request,
             method=lambda: dill.loads(self.all_between_dates_method)(request, start_date, end_date),
-        )
+        ).fill_missing_columns()
 
     def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
         from aligned.retrival_job import CustomLazyPolarsJob
         import dill
 
         return CustomLazyPolarsJob(
             request=request, method=lambda: dill.loads(self.features_for_method)(facts, request)
-        )
+        ).fill_missing_columns()
 
     @classmethod
     def multi_source_features_for(
@@ -619,6 +619,7 @@ def all_with_limit(self, limit: int | None) -> RetrivalJob:
                 right_on=self.right_on,
                 timestamp_unit=self.timestamp_unit,
             )
+            .fill_missing_columns()
         )
 
     def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
@@ -639,6 +640,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
                 timestamp_unit=self.timestamp_unit,
             )
             .aggregate(request)
+            .fill_missing_columns()
             .derive_features([request])
         )
 
@@ -661,6 +663,7 @@ def all_between_dates(
                 right_on=self.right_on,
             )
             .aggregate(request)
+            .fill_missing_columns()
             .derive_features([request])
         )
 
@@ -729,6 +732,7 @@ def all_with_limit(self, limit: int | None) -> RetrivalJob:
             self.source.all_data(self.left_request, limit=limit)
             .derive_features([self.left_request])
             .join(right_job, method=self.method, left_on=self.left_on, right_on=self.right_on)
+            .fill_missing_columns()
         )
 
     def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
@@ -741,6 +745,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
             self.source.all_data(self.left_request, limit=limit)
             .derive_features([self.left_request])
             .join(right_job, method=self.method, left_on=self.left_on, right_on=self.right_on)
+            .fill_missing_columns()
             .aggregate(request)
             .derive_features([request])
         )
@@ -763,6 +768,7 @@ def all_between_dates(
                 left_on=self.left_on,
                 right_on=self.right_on,
             )
+            .fill_missing_columns()
             .aggregate(request)
             .derive_features([request])
         )
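
Every job returned above is now chained through fill_missing_columns(). As a minimal sketch of what that wrapper does, assuming (based on the FillMissingColumnsJob diff further down) that it adds a null column for every expected column the underlying source did not return — the Job class and its expected_columns field are hypothetical stand-ins for illustration, not aligned's actual API:

from dataclasses import dataclass, field

import polars as pl


@dataclass
class Job:
    # Hypothetical stand-in for aligned's RetrivalJob (illustration only).
    frame: pl.LazyFrame
    expected_columns: list[str] = field(default_factory=list)

    def fill_missing_columns(self) -> 'Job':
        # Add a null column for every expected column the source did not return.
        missing = [c for c in self.expected_columns if c not in self.frame.columns]
        if missing:
            self.frame = self.frame.with_columns([pl.lit(None).alias(c) for c in missing])
        return self


job = Job(pl.LazyFrame({'a': [1, 2]}), expected_columns=['a', 'b']).fill_missing_columns()
print(job.frame.collect())  # columns 'a' and 'b', with 'b' filled with nulls
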
28 changes: 0 additions & 28 deletions aligned/local/job.py
@@ -240,34 +240,6 @@ def request_result(self) -> RequestResult:
     def retrival_requests(self) -> list[RetrivalRequest]:
         return [self.request]
 
-    def file_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
-        from aligned.data_source.batch_data_source import ColumnFeatureMappable
-
-        entity_names = self.request.entity_names
-        all_names = list(self.request.all_required_feature_names.union(entity_names))
-
-        request_features = all_names
-        if isinstance(self.source, ColumnFeatureMappable):
-            request_features = self.source.feature_identifier_for(all_names)
-
-        df.rename(
-            columns=dict(zip(request_features, all_names)),
-            inplace=True,
-        )
-
-        if self.request.event_timestamp is None:
-            raise ValueError(f'Source {self.source} have no event timestamp to filter on')
-
-        event_timestamp_column = self.request.event_timestamp.name
-        # Making sure it is in the correct format
-        df[event_timestamp_column] = pd.to_datetime(
-            df[event_timestamp_column], infer_datetime_format=True, utc=True
-        )
-
-        start_date_ts = pd.to_datetime(self.start_date, utc=True)
-        end_date_ts = pd.to_datetime(self.end_date, utc=True)
-        return df.loc[df[event_timestamp_column].between(start_date_ts, end_date_ts)]
-
     def file_transform_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:
         from aligned.data_source.batch_data_source import ColumnFeatureMappable
 
58 changes: 36 additions & 22 deletions aligned/retrival_job.py
@@ -285,6 +285,17 @@ async def to_pandas(self) -> SupervisedDataSet[pd.DataFrame]:
             data, entities, features, self.target_columns, self.job.request_result.event_timestamp
         )
 
+    async def to_polars(self) -> SupervisedDataSet[pl.DataFrame]:
+        dataset = await self.to_lazy_polars()
+
+        return SupervisedDataSet(
+            data=dataset.data.collect(),
+            entity_columns=dataset.entity_columns,
+            features=dataset.feature_columns,
+            target=dataset.target_columns,
+            event_timestamp_column=dataset.event_timestamp_column,
+        )
+
     async def to_lazy_polars(self) -> SupervisedDataSet[pl.LazyFrame]:
         data = await self.job.to_lazy_polars()
         if self.should_filter_out_null_targets:
@@ -760,7 +771,7 @@ def copy_with(self: JobType, job: RetrivalJob) -> JobType:
 class CustomPolarsJob(RetrivalJob, ModificationJob):
 
     job: RetrivalJob
-    polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]
+    polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]  # type: ignore
 
     async def to_lazy_polars(self) -> pl.LazyFrame:
         df = await self.job.to_lazy_polars()
@@ -1389,38 +1400,41 @@ class FillMissingColumnsJob(RetrivalJob, ModificationJob):
     job: RetrivalJob
 
     async def to_pandas(self) -> pd.DataFrame:
+        from aligned.schemas.constraints import Optional
+
         data = await self.job.to_pandas()
         for request in self.retrival_requests:
-
-            missing = request.all_required_feature_names - set(data.columns)
-            if not missing:
-                continue
-
-            logger.warn(
-                f"""
-Some features is missing.
-Will fill values with None, but it could be a potential problem: {missing}
-"""
-            )
-            for feature in missing:
-                data[feature] = None
+            optional_constraint = Optional()
+            for feature in request.features:
+                if (
+                    feature.constraints
+                    and optional_constraint in feature.constraints
+                    and feature.name not in data.columns
+                ):
+                    data[feature] = None
+
         return data
 
     async def to_lazy_polars(self) -> pl.LazyFrame:
+        from aligned.schemas.constraints import Optional
+
         data = await self.job.to_lazy_polars()
+        optional_constraint = Optional()
 
         for request in self.retrival_requests:
-
-            missing = request.all_required_feature_names - set(data.columns)
-            if not missing:
-                continue
-
-            logger.warn(
-                f"""
-Some features is missing.
-Will fill values with None, but it could be a potential problem: {missing}
-"""
-            )
-            data = data.with_columns([pl.lit(None).alias(feature) for feature in missing])
+            missing_columns = [
+                feature.name
+                for feature in request.features
+                if feature.constraints
+                and optional_constraint in feature.constraints
+                and feature.name not in data.columns
+            ]
+
+            if missing_columns:
+                data = data.with_columns([pl.lit(None).alias(feature) for feature in missing_columns])
         return data


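The polars path above now fills only features that carry the Optional constraint, instead of null-filling every missing required column (and warning about it). A runnable sketch of the new behaviour, assuming a request where only 'b' is declared optional — the names here are illustrative:

import polars as pl

data = pl.LazyFrame({'a': [1, 2, 3], 'c': [1, 2, 3]})
optional_features = ['b']  # features whose constraints include Optional()

missing_columns = [name for name in optional_features if name not in data.columns]
if missing_columns:
    data = data.with_columns([pl.lit(None).alias(name) for name in missing_columns])

print(data.collect())  # 'b' is added as an all-null column; 'a' and 'c' are untouched
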
15 changes: 11 additions & 4 deletions aligned/schemas/feature_view.py
@@ -373,7 +373,12 @@ def multi_source_features_for(
         sub_source = source.view.materialized_source or source.view.source
         sub_request = source.sub_request(request)
 
-        sub_job = sub_source.all_data(sub_request, limit=None).derive_features()
+        sub_job = (
+            sub_source.all_data(sub_request, limit=None)
+            .derive_features()
+            .with_request([request])
+            .fill_missing_columns()
+        )
 
         if request.aggregated_features:
             available_features = sub_job.aggregate(request).derive_features()
@@ -388,10 +393,11 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
         sub_req = self.sub_request(request)
 
         core_job = sub_source.all_data(sub_req, limit=limit)
+
         if request.aggregated_features:
             job = core_job.aggregate(request)
         else:
-            job = core_job.derive_features()
+            job = core_job.derive_features().with_request([request]).fill_missing_columns()
 
         return job.derive_features([request]).rename(self.renames)
 
@@ -402,11 +408,12 @@ def all_between_dates(
 
         sub_req = self.sub_request(request)
 
-        core_job = sub_source.all_between_dates(sub_req, start_date, end_date)
+        core_job = sub_source.all_between_dates(sub_req, start_date, end_date).fill_missing_columns()
 
         if request.aggregated_features:
             job = core_job.aggregate(request)
         else:
-            job = core_job.derive_features()
+            job = core_job.derive_features().with_request([request]).fill_missing_columns()
         return job.derive_features([request]).rename(self.renames)
 
     def depends_on(self) -> set[FeatureLocation]:
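Note the ordering in these pipelines: with_request([request]) comes before fill_missing_columns(), presumably because the fill step reads the requests the job reports (see FillMissingColumnsJob above) and the sub-source's own request does not declare the outer view's optional features. A small runnable sketch of that dependency, where fill_missing is a hypothetical stand-in for the real job:

import polars as pl


def fill_missing(frame: pl.LazyFrame, reported_optional: list[str]) -> pl.LazyFrame:
    # Stand-in for fill_missing_columns: it can only fill columns that the
    # job's reported request declares as optional.
    missing = [c for c in reported_optional if c not in frame.columns]
    return frame.with_columns([pl.lit(None).alias(c) for c in missing]) if missing else frame


frame = pl.LazyFrame({'a': [1, 2, 3]})

# The sub-request knows nothing about the outer view's optional column 'b'.
print(fill_missing(frame, reported_optional=[]).collect().columns)     # ['a']

# After with_request([request]), the outer request (which declares 'b')
# is the one reported, so the fill step can add the column.
print(fill_missing(frame, reported_optional=['b']).collect().columns)  # ['a', 'b']
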
40 changes: 37 additions & 3 deletions aligned/sources/tests/test_parquet.py
@@ -126,7 +126,7 @@ async def test_read_csv(point_in_time_data_test: DataTest) -> None:
         batch_source=file_source,
     )
     compiled = view.compile_instance()
-    assert compiled.source.path == file_source.path  # type: ignore
+    assert compiled.source.path == file_source.path
 
     store.add_compiled_view(compiled)
 
@@ -170,9 +170,43 @@ class Test:
         filled = b.fill_na(0)
 
     expected_df = df.with_columns(pl.lit(None).alias('b'), pl.lit(0).alias('filled'))
-    loaded = await Test.query().all().to_polars()
+    loaded = await Test.query().all().to_polars()  # type: ignore
 
     assert loaded.equals(expected_df.select(loaded.columns))
 
-    facts = await Test.query().features_for({'a': [2]}).to_polars()
+    facts = await Test.query().features_for({'a': [2]}).to_polars()  # type: ignore
     assert expected_df.filter(pl.col('a') == 2).equals(facts.select(expected_df.columns))
+
+
+@pytest.mark.asyncio
+async def test_read_optional_view() -> None:
+
+    source = FileSource.csv_at('test_data/optional_test.csv')
+    df = pl.DataFrame(
+        {
+            'a': [1, 2, 3],
+            'c': [1, 2, 3],
+        }
+    )
+    await source.write_polars(df.lazy())
+
+    @feature_view(name='test_a', source=source)
+    class TestA:
+        a = Int32().as_entity()
+        c = Int32()
+
+    @feature_view(name='test', source=TestA)  # type: ignore
+    class Test:
+        a = Int32().as_entity()
+        b = Int32().is_optional()
+        c = Int32()
+
+        filled = b.fill_na(0)
+
+    expected_df = df.with_columns(pl.lit(None).alias('b'), pl.lit(0).alias('filled'))
+    loaded = await Test.query().all().to_polars()  # type: ignore
+
+    assert loaded.equals(expected_df.select(loaded.columns))
+
+    facts = await Test.query().features_for({'a': [2]}).to_polars()  # type: ignore
+    assert expected_df.filter(pl.col('a') == 2).equals(facts.select(expected_df.columns))
3 changes: 0 additions & 3 deletions conftest.py
@@ -159,7 +159,6 @@ class BreastDiagnoseFeatureView(FeatureView):
     metadata = FeatureViewMetadata(
         name='breast_features',
         description='Features defining a scan and diagnose of potential cancer cells',
-        tags={},
         source=scan_without_datetime,
     )
 
@@ -231,7 +230,6 @@ class BreastDiagnoseFeatureView(FeatureView):
     metadata = FeatureViewMetadata(
         name='breast_features',
         description='Features defining a scan and diagnose of potential cancer cells',
-        tags={},
         source=scan_with_datetime,
     )
 
@@ -292,7 +290,6 @@ class BreastDiagnoseFeatureView(FeatureView):
     metadata = FeatureViewMetadata(
         name='breast_features',
         description='Features defining a scan and diagnose of potential cancer cells',
-        tags={},
         source=scan_with_datetime,
    )
 
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.78"
+version = "0.0.79"
 description = "A data managment and lineage tool for ML applications."
 authors = ["Mats E. Mollestad <[email protected]>"]
 license = "Apache-2.0"
14 changes: 7 additions & 7 deletions test_data/credit_history.csv
@@ -1,7 +1,7 @@
-credit_card_due,dob_ssn,due_sum,bankruptcies,student_loan_due,event_timestamp
-8419,19530219_5179,30747,0,22328,1587924064746575
-2944,19520816_8737,5459,0,2515,1587924064746575
-833,19860413_2537,33833,0,33000,1587924064746575
-5936,19530219_5179,54891,0,48955,1588010464746575
-1575,19520816_8737,11076,0,9501,1588010464746575
-6263,19860413_2537,41773,0,35510,1588010464746575
+bankruptcies,event_timestamp,student_loan_due,credit_card_due,dob_ssn,due_sum
+0,1587924064746575,22328,8419,19530219_5179,30747
+0,1587924064746575,2515,2944,19520816_8737,5459
+0,1587924064746575,33000,833,19860413_2537,33833
+0,1588010464746575,48955,5936,19530219_5179,54891
+0,1588010464746575,9501,1575,19520816_8737,11076
+0,1588010464746575,35510,6263,19860413_2537,41773
Binary file modified test_data/credit_history_mater.parquet
(The remaining 8 changed files are not shown.)
