Skip to content

Commit

Permalink
Fixing tests and addressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Clifton Tyler committed Nov 7, 2023
1 parent c9f62f6 commit 0778e49
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 85 deletions.
10 changes: 5 additions & 5 deletions wizard/core/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from etils import epath
import pandas as pd
import requests
from state import Distribution
from state import FileObject


@dataclasses.dataclass
Expand Down Expand Up @@ -90,15 +90,15 @@ def get_dataframe(file_type: FileType, file: io.BytesIO | epath.Path) -> pd.Data
raise NotImplementedError()


def file_from_url(file_type: FileType, url: str) -> Distribution:
def file_from_url(file_type: FileType, url: str) -> FileObject:
"""Downloads locally and extracts the file information."""
file_path = hash_file_path(url)
if not file_path.exists():
download_file(url, file_path)
with file_path.open("rb") as file:
sha256 = _sha256(file.read())
df = get_dataframe(file_type, file_path)
return Distribution(
return FileObject(
name=url.split("/")[-1],
description="",
content_url=url,
Expand All @@ -108,11 +108,11 @@ def file_from_url(file_type: FileType, url: str) -> Distribution:
)


def file_from_upload(file_type: FileType, file: io.BytesIO) -> Distribution:
def file_from_upload(file_type: FileType, file: io.BytesIO) -> FileObject:
"""Uploads locally and extracts the file information."""
sha256 = _sha256(file.getvalue())
df = get_dataframe(file_type, file)
return Distribution(
return FileObject(
name=file.name,
description="",
content_url=f"data/{file.name}",
Expand Down
133 changes: 75 additions & 58 deletions wizard/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
"""

import dataclasses
from typing import Type, TypeVar

import streamlit as st

import mlcroissant as mlc

# Type variable for Metadata.from_canonical's `cls`. The original bound
# referenced a nonexistent class 'Parent'; bind it to Metadata, the class
# that actually uses this TypeVar as its alternate-constructor return type.
T = TypeVar('T', bound='Metadata')

def init_state():

if Croissant not in st.session_state:
st.session_state[Croissant] = Croissant()
if Metadata not in st.session_state:
st.session_state[Metadata] = Metadata()

if mlc.Dataset not in st.session_state:
st.session_state[mlc.Dataset] = None
Expand All @@ -23,13 +25,15 @@ def init_state():


class CurrentStep:
    """Names of the wizard's navigation steps.

    Fix: the original docstring claimed this class "holds all major state
    variables for the application"; it only enumerates the step names used
    to drive navigation (e.g. `set_form_step("Jump", CurrentStep.editor)`).
    """
    start = "start"
    load = "load"
    editor = "editor"


@dataclasses.dataclass
class Distribution:
class FileObject:
"""FileObject analogue for editor"""
name: str | None = None
description: str | None = None
content_size: str | None = None
Expand All @@ -38,41 +42,56 @@ class Distribution:
sha256: str | None = None
df: str | None = None

@dataclasses.dataclass
class FileSet:
    """FileSet analogue for the editor.

    Holds only the fields the wizard edits; populated from an mlcroissant
    file node in ``Metadata.from_canonical`` when the node is not a FileObject.
    """
    # Names of parent distributions this set belongs to — TODO confirm semantics.
    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    encoding_format: str | None = ""  # empty string until the user sets one
    includes: str | None = ""  # presumably a glob/pattern of files — verify
    name: str = ""

@dataclasses.dataclass
class Field:
    """Field analogue for the editor (one column/field of a RecordSet)."""
    name: str | None = None
    description: str | None = None
    # Copied from mlcroissant's plural `field.data_types` in from_canonical.
    data_type: str | None = None

@dataclasses.dataclass
class RecordSet:
    """Record Set analogue for the editor (mirrors an mlcroissant RecordSet)."""
    name: str = ""
    description: str | None = None
    is_enumeration: bool | None = None
    # A record key may be one field name or a composite of several.
    key: str | list[str] | None = None
    fields: list[Field] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class Metadata:

    """Main Croissant metadata object edited by the wizard.

    Helper methods load/unload this from the mlcroissant representation
    (see ``from_canonical``). Stored in ``st.session_state`` as the single
    source of truth for the editor.
    """
    name: str = ""
    description: str | None = None
    citation: str | None = None
    license: str | None = ""
    url: str = ""
    # Mixed list: FileObject entries for single files, FileSet for patterns.
    distributions: list[FileObject | FileSet] = dataclasses.field(default_factory=list)
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)

    def __bool__(self) -> bool:
        # Truthy only once the two required fields (name, url) are non-empty.
        return self.name != "" and self.url != ""

@dataclasses.dataclass
class Croissant:
metadata: Metadata = Metadata()
distributions: list[Distribution] = dataclasses.field(default_factory=list)
record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
def update_metadata(self, metadata: Metadata) -> None:
self.metadata = Metadata
def add_distribution(self, distribution: Distribution) -> None:

def update_metadata(self, description: str , citation: str , license: license , url: str = "", name: str = "") -> None:
self.name = name
self.description = description
self.citation = citation
self.license = license
self.url = url
def add_distribution(self, distribution: FileSet | FileObject) -> None:
    """Append a new distribution (single file or file set) to the metadata."""
    self.distributions.append(distribution)
def update_distribution(self, key: int, distribution: Distribution) -> None:
def update_distribution(self, key: int, distribution: FileSet | FileObject) -> None:
    """Replace the distribution at index ``key`` with ``distribution``."""
    self.distributions[key] = distribution
def remove_distribution(self, key: int) -> None:
    """Drop the distribution stored at index ``key``."""
    self.distributions.pop(key)
Expand All @@ -83,52 +102,50 @@ def update_record_set(self, key: int, record_set: RecordSet) -> None:
def remove_record_set(self, key: int) -> None:
    """Drop the record set stored at index ``key``."""
    self.record_sets.pop(key)

def CanonicalToWizard(dataset: mlc.Dataset) -> Croissant:
canonical_metadata = dataset.metadata
metadata = Metadata(
name=canonical_metadata.name,
description=canonical_metadata.description,
citation=canonical_metadata.citation,
license=canonical_metadata.license,
url=canonical_metadata.url
)
distributions = []
for file in canonical_metadata.distribution:
if isinstance(file, mlc.nodes.FileObject):
distributions.append(Distribution(
name=file.name,
description=file.description,
content_size=file.content_size,
encoding_format=file.encoding_format,
sha256=file.sha256,
))
else:
distributions.append(Distribution(
name=file.name,
description=file.description,
encoding_format=file.encoding_format,
# NOTE(review): this span contained the new classmethod interleaved with
# deleted old-function diff residue (a duplicate record_sets loop and a
# duplicate return); below is the clean post-change implementation.
@classmethod
def from_canonical(cls: Type[T], dataset: mlc.Dataset) -> T:
    """Build an editor Metadata object from a canonical mlcroissant Dataset.

    Copies top-level metadata fields, converts each distribution node into
    its editor analogue (FileObject or FileSet), and converts every record
    set with its fields.
    """
    canonical_metadata = dataset.metadata
    # Convert distribution nodes: FileObject keeps size/hash; anything else
    # is treated as a FileSet with only the common fields.
    distributions = []
    for file in canonical_metadata.distribution:
        if isinstance(file, mlc.nodes.FileObject):
            distributions.append(FileObject(
                name=file.name,
                description=file.description,
                content_size=file.content_size,
                encoding_format=file.encoding_format,
                sha256=file.sha256,
            ))
        else:
            distributions.append(FileSet(
                name=file.name,
                description=file.description,
                encoding_format=file.encoding_format,
            ))
    # Convert record sets and their fields.
    record_sets = []
    for record_set in canonical_metadata.record_sets:
        fields = []
        for field in record_set.fields:
            fields.append(Field(
                name=field.name,
                description=field.description,
                # canonical attribute is the plural `data_types`
                data_type=field.data_types,
            ))
        record_sets.append(RecordSet(
            name=record_set.name,
            description=record_set.description,
            is_enumeration=record_set.is_enumeration,
            key=record_set.key,
            fields=fields,
        ))
    return cls(
        name=canonical_metadata.name,
        description=canonical_metadata.description,
        citation=canonical_metadata.citation,
        license=canonical_metadata.license,
        url=canonical_metadata.url,
        distributions=distributions,
        record_sets=record_sets,
    )

def set_form_step(action, step=None):
"""Maintains the user's location within the wizard."""
Expand Down
16 changes: 8 additions & 8 deletions wizard/views/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from core.files import file_from_url
from core.files import FILE_TYPES
import pandas as pd
from state import Croissant
from state import Distribution
from state import FileObject
from state import Metadata
import streamlit as st
from utils import DF_HEIGHT
from utils import needed_field
Expand All @@ -30,7 +30,7 @@ def render_files():
file = file_from_upload(file_type, uploaded_file)
else:
raise ValueError("should have either `url` or `uploaded_file`.")
st.session_state[Croissant].add_distribution(file)
st.session_state[Metadata].add_distribution(file)
dtypes = file.df.dtypes
fields = pd.DataFrame(
{
Expand All @@ -39,18 +39,18 @@ def render_files():
"description": "",
}
)
st.session_state[Croissant].add_record_set(
st.session_state[Metadata].add_record_set(
{
"fields": fields,
"name": file.name + "_record_set",
"description": "",
}
)
for key, file in enumerate(st.session_state[Croissant].distributions):
for key, file in enumerate(st.session_state[Metadata].distributions):
with st.container():

def delete_line():
st.session_state[Croissant].remove_distribution(key)
st.session_state[Metadata].remove_distribution(key)

name = st.text_input(
needed_field("Name"),
Expand Down Expand Up @@ -79,12 +79,12 @@ def delete_line():
st.dataframe(file.df, height=DF_HEIGHT)
_, col = st.columns([5, 1])
col.button("Remove", key=f'{key}_url', on_click=delete_line, type="primary")
file = Distribution(
file = FileObject(
name=name,
description=description,
content_url=file.content_url,
encoding_format=encoding_format,
sha256=sha256,
df=file.df,
)
st.session_state[Croissant].update_distribution(key, file)
st.session_state[Metadata].update_distribution(key, file)
6 changes: 3 additions & 3 deletions wizard/views/jsonld.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pandas as pd
from state import Croissant
from state import Metadata
import streamlit as st

import mlcroissant as mlc


def render_jsonld():
if not st.session_state[Croissant]:
if not st.session_state[Metadata]:
return st.code({}, language="json")
try:
croissant = st.session_state[Croissant]
croissant = st.session_state[Metadata]
distribution = []
for file in croissant.distributions:
distribution.append(
Expand Down
5 changes: 2 additions & 3 deletions wizard/views/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from pathlib import Path
import tempfile

from state import CanonicalToWizard
from state import Croissant
from state import CurrentStep
from state import Metadata
from state import set_form_step
import streamlit as st

Expand All @@ -23,7 +22,7 @@ def render_load():
with open(newfile_name, mode="wb+") as outfile:
outfile.write(file_cont)
dataset = mlc.Dataset(newfile_name)
st.session_state[Croissant] = CanonicalToWizard(dataset)
st.session_state[Metadata] = Metadata.from_canonical(dataset)
set_form_step("Jump", CurrentStep.editor)
st.rerun()
except mlc.ValidationError as e:
Expand Down
7 changes: 3 additions & 4 deletions wizard/views/metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from state import Croissant
from state import Metadata
import streamlit as st
from utils import needed_field
Expand All @@ -23,7 +22,7 @@


def render_metadata():
metadata = st.session_state[Croissant].metadata
metadata = st.session_state[Metadata]
name = st.text_input(
label=needed_field("Name"),
value=metadata.name,
Expand Down Expand Up @@ -54,10 +53,10 @@ def render_metadata():
placeholder="@book{\n title={Title}\n}",
)
# We fully recreate the session state in order to force the re-rendering.
st.session_state[Croissant].update_metadata(Metadata(
st.session_state[Metadata].update_metadata(
name=name,
description=description,
license=license,
url=url,
citation=citation
))
)
7 changes: 3 additions & 4 deletions wizard/views/record_sets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pandas as pd
from state import Croissant
from state import RecordSet
from state import Metadata
import streamlit as st
from utils import DF_HEIGHT

Expand All @@ -13,10 +12,10 @@


def render_record_sets():
if len(st.session_state[Croissant].record_sets) == 0:
if len(st.session_state[Metadata].record_sets) == 0:
st.markdown("Please add files first.")
else:
for record_set in st.session_state[Croissant].record_sets:
for record_set in st.session_state[Metadata].record_sets:
record_set_conv = pd.DataFrame(record_set.fields)
record_set_conv.drop(columns=["data_type"])
with st.container():
Expand Down

0 comments on commit 0778e49

Please sign in to comment.