Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(types): assign name f{} to unnamed fields of struct (record) #20490

Merged
merged 7 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion e2e_test/batch/basic/dml_update.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ db error: ERROR: Failed to run the query

Caused by these errors (recent errors listed first):
1: cannot cast type "record" to "struct<v1 integer, v2 integer>"
2: cannot cast to struct field "v1"
2: cannot cast struct field "f1" to struct field "v1"
3: cannot cast type "character varying" to "integer" in Assign context


Expand Down
2 changes: 1 addition & 1 deletion e2e_test/batch/types/map.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ db error: ERROR: Failed to run the query
Caused by these errors (recent errors listed first):
1: Failed to bind expression: map_from_entries(ARRAY[ROW('a', 1, 2)])
2: Expr error
3: the underlying struct for map must have exactly two fields, got: StructType { field_names: [], field_types: [Varchar, Int32, Int32] }
3: the underlying struct for map must have exactly two fields, got: StructType { fields: [("f1", Varchar), ("f2", Int32), ("f3", Int32)] }


query error
Expand Down
13 changes: 13 additions & 0 deletions e2e_test/batch/types/struct/struct.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ select Row('foo', 'bar', null);
----
(foo,bar,)

query T
select (Row('foo', 'bar')).f1;
----
foo

query T
select (Row('foo', 'bar')).f2;
----
bar

query error column "f3" not found in struct type
select (Row('foo', 'bar')).f3;

query T
select Row();
----
Expand Down
6 changes: 5 additions & 1 deletion src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,12 @@ impl DataType {
};
match self {
DataType::Struct(t) => {
if !t.is_unnamed() {
// To be consistent with `From<&PbDataType>`,
// we only set field names when it's a named struct.
pb.field_names = t.names().map(|s| s.into()).collect();
}
pb.field_type = t.types().map(|f| f.to_protobuf()).collect();
pb.field_names = t.names().map(|s| s.into()).collect();
}
DataType::List(datatype) => {
pb.field_type = vec![datatype.to_protobuf()];
Expand Down
109 changes: 52 additions & 57 deletions src/common/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use anyhow::anyhow;
use itertools::Itertools;

use super::DataType;
use crate::util::iter_util::{ZipEqDebug, ZipEqFast};
use crate::util::iter_util::ZipEqFast;

/// A cheaply cloneable struct type.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand All @@ -29,112 +29,112 @@ pub struct StructType(Arc<StructTypeInner>);
impl Debug for StructType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StructType")
.field("field_names", &self.0.field_names)
.field("field_types", &self.0.field_types)
.field("fields", &self.0.fields)
.finish()
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct StructTypeInner {
/// Details about a struct type. There are 2 cases for a struct:
/// 1. `field_names.len() == field_types.len()`: it represents a struct with named fields,
/// e.g. `STRUCT<i INT, j VARCHAR>`.
/// 2. `field_names.len() == 0`: it represents a struct with unnamed fields,
/// e.g. `ROW(1, 2)`.
field_names: Box<[String]>,
field_types: Box<[DataType]>,
/// The name and data type of each field.
///
/// If fields are unnamed, the names will be `f1`, `f2`, etc.
fields: Box<[(String, DataType)]>,
/// Whether the fields are unnamed.
is_unnamed: bool,
}

impl StructType {
/// Creates a struct type with named fields.
pub fn new(named_fields: impl IntoIterator<Item = (impl Into<String>, DataType)>) -> Self {
let iter = named_fields.into_iter();
let mut field_types = Vec::with_capacity(iter.size_hint().0);
let mut field_names = Vec::with_capacity(iter.size_hint().0);
for (name, ty) in iter {
field_names.push(name.into());
field_types.push(ty);
}
let fields = named_fields
.into_iter()
.map(|(name, ty)| (name.into(), ty))
.collect();

Self(Arc::new(StructTypeInner {
field_types: field_types.into(),
field_names: field_names.into(),
fields,
is_unnamed: false,
}))
}

/// Creates a struct type with no fields.
/// Creates a struct type with no fields. This makes no sense in practice.
#[cfg(test)]
pub fn empty() -> Self {
Self(Arc::new(StructTypeInner {
field_types: Box::new([]),
field_names: Box::new([]),
}))
Self::unnamed(Vec::new())
}

/// Creates a struct type with unnamed fields.
/// Creates a struct type with unnamed fields. The names will be assigned `f1`, `f2`, etc.
pub fn unnamed(fields: Vec<DataType>) -> Self {
let fields = fields
.into_iter()
.enumerate()
.map(|(i, ty)| (format!("f{}", i + 1), ty))
.collect();

Self(Arc::new(StructTypeInner {
field_types: fields.into(),
field_names: Box::new([]),
fields,
is_unnamed: true,
}))
}

/// Whether the fields are unnamed.
pub fn is_unnamed(&self) -> bool {
self.0.is_unnamed
}

/// Returns the number of fields.
pub fn len(&self) -> usize {
self.0.field_types.len()
self.0.fields.len()
}

/// Returns `true` if there are no fields.
pub fn is_empty(&self) -> bool {
self.0.field_types.is_empty()
self.0.fields.is_empty()
}

/// Gets an iterator over the names of the fields.
///
/// If the struct field is unnamed, the iterator returns **no names**.
/// If fields are unnamed, the field names will be `f1`, `f2`, etc.
pub fn names(&self) -> impl ExactSizeIterator<Item = &str> {
self.0.field_names.iter().map(|s| s.as_str())
self.0.fields.iter().map(|(name, _)| name.as_str())
}

/// Gets an iterator over the types of the fields.
pub fn types(&self) -> impl ExactSizeIterator<Item = &DataType> {
self.0.field_types.iter()
self.0.fields.iter().map(|(_, ty)| ty)
}

/// Gets an iterator over the fields.
///
/// If the struct field is unnamed, the iterator returns **empty strings**.
pub fn iter(&self) -> impl Iterator<Item = (&str, &DataType)> {
self.0
.field_names
.iter()
.map(|s| s.as_str())
.chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len()))
.zip_eq_debug(self.0.field_types.iter())
/// If fields are unnamed, the field names will be `f1`, `f2`, etc.
pub fn iter(&self) -> impl ExactSizeIterator<Item = (&str, &DataType)> {
self.0.fields.iter().map(|(name, ty)| (name.as_str(), ty))
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &StructType) -> bool {
if self.0.field_types.len() != other.0.field_types.len() {
if self.len() != other.len() {
return false;
}
(self.0.field_types.iter())
.zip_eq_fast(other.0.field_types.iter())

(self.types())
.zip_eq_fast(other.types())
.all(|(a, b)| a.equals_datatype(b))
}
}

impl Display for StructType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if self.0.field_names.is_empty() {
if self.is_unnamed() {
// To be consistent with the return type of `ROW` in Postgres.
write!(f, "record")
} else {
write!(
f,
"struct<{}>",
(self.0.field_types.iter())
.zip_eq_fast(self.0.field_names.iter())
.map(|(d, s)| format!("{} {}", s, d))
self.iter()
.map(|(name, ty)| format!("{} {}", name, ty))
.join(", ")
)
}
Expand All @@ -151,19 +151,14 @@ impl FromStr for StructType {
if !(s.starts_with("struct<") && s.ends_with('>')) {
return Err(anyhow!("expect struct<...>"));
};
let mut field_types = Vec::new();
let mut field_names = Vec::new();
let mut fields = Vec::new();
for field in s[7..s.len() - 1].split(',') {
let field = field.trim();
let mut iter = field.split_whitespace();
let field_name = iter.next().unwrap();
let field_type = iter.next().unwrap();
field_names.push(field_name.to_owned());
field_types.push(DataType::from_str(field_type)?);
let field_name = iter.next().unwrap().to_owned();
let field_type = DataType::from_str(iter.next().unwrap())?;
fields.push((field_name, field_type));
}
Ok(Self(Arc::new(StructTypeInner {
field_types: field_types.into(),
field_names: field_names.into(),
})))
Ok(Self::new(fields))
}
}
4 changes: 1 addition & 3 deletions src/connector/codec/src/decoder/avro/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use risingwave_common::types::{
DataType, Date, DatumCow, Interval, JsonbVal, MapValue, ScalarImpl, Time, Timestamp,
Timestamptz, ToOwnedDatum,
};
use risingwave_common::util::iter_util::ZipEqFast;

pub use self::schema::{avro_schema_to_column_descs, MapHandling, ResolvedAvroSchema};
use super::utils::scaled_bigint_to_rust_decimal;
Expand Down Expand Up @@ -263,8 +262,7 @@ impl<'a> AvroParseOptionsInner<'a> {
return Err(create_error());
};
struct_type_info
.names()
.zip_eq_fast(struct_type_info.types())
.iter()
.map(|(field_name, field_type)| {
if let Some(idx) = record_schema.lookup.get(field_name) {
let value = &descs[*idx].1;
Expand Down
6 changes: 1 addition & 5 deletions src/connector/src/parser/unified/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use risingwave_common::types::{
DataType, Date, Decimal, Int256, Interval, JsonbVal, ScalarImpl, Time, Timestamp, Timestamptz,
ToOwnedDatum,
};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_connector_codec::decoder::utils::scaled_bigint_to_rust_decimal;
use simd_json::base::ValueAsObject;
use simd_json::prelude::{
Expand Down Expand Up @@ -532,10 +531,7 @@ impl JsonParseOptions {
// Collecting into a Result<Vec<_>> doesn't reserve the capacity in advance, so we `Vec::with_capacity` instead.
// https://github.com/rust-lang/rust/issues/48994
let mut fields = Vec::with_capacity(struct_type_info.len());
for (field_name, field_type) in struct_type_info
.names()
.zip_eq_fast(struct_type_info.types())
{
for (field_name, field_type) in struct_type_info.iter() {
let field_value = json_object_get_case_insensitive(value, field_name)
.unwrap_or_else(|| {
let error = AccessError::Undefined {
Expand Down
8 changes: 3 additions & 5 deletions src/connector/src/sink/deltalake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use risingwave_common::bail;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_pb::connector_service::sink_metadata::Metadata::Serialized;
use risingwave_pb::connector_service::sink_metadata::SerializedMetadata;
use risingwave_pb::connector_service::SinkMetadata;
Expand Down Expand Up @@ -260,10 +260,8 @@ fn check_field_type(rw_data_type: &DataType, dl_data_type: &DeltaLakeDataType) -
DataType::Struct(rw_struct) => {
if let DeltaLakeDataType::Struct(dl_struct) = dl_data_type {
let mut result = true;
for ((rw_name, rw_type), dl_field) in rw_struct
.names()
.zip_eq_fast(rw_struct.types())
.zip_eq_debug(dl_struct.fields())
for ((rw_name, rw_type), dl_field) in
rw_struct.iter().zip_eq_debug(dl_struct.fields())
{
result = check_field_type(rw_type, dl_field.data_type())?
&& result
Expand Down
9 changes: 2 additions & 7 deletions src/expr/impl/src/scalar/to_jsonb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,11 @@ impl ToJsonb for MapRef<'_> {
impl ToJsonb for StructRef<'_> {
fn add_to(self, data_type: &DataType, builder: &mut Builder) -> Result<()> {
builder.begin_object();
for (i, (value, (field_name, field_type))) in self
for (value, (field_name, field_type)) in self
.iter_fields_ref()
.zip_eq_debug(data_type.as_struct().iter())
.enumerate()
{
if field_name.is_empty() {
builder.display(format_args!("f{}", i + 1));
} else {
builder.add_string(field_name);
};
builder.add_string(field_name);
value.add_to(field_type, builder)?;
}
builder.end_object();
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/planner_test/tests/testdata/output/array.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
logical_plan: |-
LogicalProject { exprs: [Array as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
batch_plan: 'BatchValues { rows: [[ARRAY[]:List(Struct(StructType { field_names: ["f1"], field_types: [Int32] }))]] }'
batch_plan: 'BatchValues { rows: [[ARRAY[]:List(Struct(StructType { fields: [("f1", Int32)] }))]] }'
stream_plan: |-
StreamMaterialize { columns: [array, _row_id(hidden)], stream_key: [_row_id], pk_columns: [_row_id], pk_conflict: NoCheck }
└─StreamValues { rows: [[ARRAY[]:List(Struct(StructType { field_names: ["f1"], field_types: [Int32] })), 0:Int64]] }
└─StreamValues { rows: [[ARRAY[]:List(Struct(StructType { fields: [("f1", Int32)] })), 0:Int64]] }
- sql: |
select array_cat(array[66], array[123]);
logical_plan: |-
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/planner_test/tests/testdata/output/expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@
├─BatchExchange { order: [], dist: HashShard(t.j) }
│ └─BatchScan { table: t, columns: [t.k, t.j], distribution: SomeShard }
└─BatchExchange { order: [], dist: HashShard(t.j) }
└─BatchProject { exprs: [t.j, JsonbPopulateRecord(null:Struct(StructType { field_names: ["a", "b"], field_types: [Int32, Int32] }), t.j) as $expr1] }
└─BatchProject { exprs: [t.j, JsonbPopulateRecord(null:Struct(StructType { fields: [("a", Int32), ("b", Int32)] }), t.j) as $expr1] }
└─BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
├─BatchExchange { order: [], dist: Single }
│ └─BatchHashAgg { group_key: [t.j], aggs: [] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,27 @@
select * from t where (v1,v3) > (2,3);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchFilter { predicate: (Row(t.v1, t.v3) > '(2,3)':Struct(StructType { field_names: [], field_types: [Int32, Int32] })) }
└─BatchFilter { predicate: (Row(t.v1, t.v3) > '(2,3)':Struct(StructType { fields: [("f1", Int32), ("f2", Int32)] })) }
└─BatchScan { table: t, columns: [t.v1, t.v2, t.v3], scan_ranges: [t.v1 >= Int32(2)], distribution: UpstreamHashShard(t.v1, t.v2, t.v3) }
- sql: |
create table t(v1 int, v2 int, v3 int, primary key(v1,v2,v3));
select * from t where (v3,v2,v1) > (1,2,3);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchFilter { predicate: (Row(t.v3, t.v2, t.v1) > '(1,2,3)':Struct(StructType { field_names: [], field_types: [Int32, Int32, Int32] })) }
└─BatchFilter { predicate: (Row(t.v3, t.v2, t.v1) > '(1,2,3)':Struct(StructType { fields: [("f1", Int32), ("f2", Int32), ("f3", Int32)] })) }
└─BatchScan { table: t, columns: [t.v1, t.v2, t.v3], distribution: UpstreamHashShard(t.v1, t.v2, t.v3) }
- sql: |
create table t(v1 int, v2 int, v3 int, primary key(v1,v2,v3));
select * from t where (v1,v2,v1) > (1,2,3);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchFilter { predicate: (Row(t.v1, t.v2, t.v1) > '(1,2,3)':Struct(StructType { field_names: [], field_types: [Int32, Int32, Int32] })) }
└─BatchFilter { predicate: (Row(t.v1, t.v2, t.v1) > '(1,2,3)':Struct(StructType { fields: [("f1", Int32), ("f2", Int32), ("f3", Int32)] })) }
└─BatchScan { table: t, columns: [t.v1, t.v2, t.v3], scan_ranges: [(t.v1, t.v2) >= (Int32(1), Int32(2))], distribution: UpstreamHashShard(t.v1, t.v2, t.v3) }
- sql: |
create table t1(v1 int, v2 int, v3 int);
create materialized view mv1 as select * from t1 order by v1 asc, v2 asc, v3 desc;
select * from mv1 where (v1,v2,v3) > (1,3,1);
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchFilter { predicate: (Row(mv1.v1, mv1.v2, mv1.v3) > '(1,3,1)':Struct(StructType { field_names: [], field_types: [Int32, Int32, Int32] })) }
└─BatchFilter { predicate: (Row(mv1.v1, mv1.v2, mv1.v3) > '(1,3,1)':Struct(StructType { fields: [("f1", Int32), ("f2", Int32), ("f3", Int32)] })) }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could further simplify `Struct(StructType { ... })` to just `Struct`?

└─BatchScan { table: mv1, columns: [mv1.v1, mv1.v2, mv1.v3], scan_ranges: [(mv1.v1, mv1.v2) >= (Int32(1), Int32(3))], distribution: UpstreamHashShard(mv1.v1, mv1.v2, mv1.v3) }
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@
insert into s values (1,2,(1,2,(1,2,null)));
logical_plan: |-
LogicalInsert { table: s, mapping: [0:0, 1:1, 2:2] }
└─LogicalValues { rows: [[1:Int32, 2:Int32, Row(1:Int32, 2:Int32, Row(1:Int32, 2:Int32, null:Int32))]], schema: Schema { fields: [*VALUES*_0.column_0:Int32, *VALUES*_0.column_1:Int32, *VALUES*_0.column_2:Struct(StructType { field_names: ["v1", "v2", "v3"], field_types: [Int32, Int32, Struct(StructType { field_names: ["v1", "v2", "v3"], field_types: [Int32, Int32, Int32] })] })] } }
└─LogicalValues { rows: [[1:Int32, 2:Int32, Row(1:Int32, 2:Int32, Row(1:Int32, 2:Int32, null:Int32))]], schema: Schema { fields: [*VALUES*_0.column_0:Int32, *VALUES*_0.column_1:Int32, *VALUES*_0.column_2:Struct(StructType { fields: [("v1", Int32), ("v2", Int32), ("v3", Struct(StructType { fields: [("v1", Int32), ("v2", Int32), ("v3", Int32)] }))] })] } }
create_table_with_connector:
format: plain
encode: protobuf
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/planner_test/tests/testdata/output/update.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchUpdate { table: t, exprs: [Field($4, 0:Int32), Field($4, 1:Int32), $2] }
└─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, $expr10011::Struct(StructType { field_names: ["v1", "v2"], field_types: [Int32, Int32] }) as $expr1] }
└─BatchProject { exprs: [t.v1, t.v2, t._row_id, t._rw_timestamp, $expr10011::Struct(StructType { fields: [("v1", Int32), ("v2", Int32)] }) as $expr1] }
└─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all }
├─BatchExchange { order: [], dist: Single }
│ └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id, t._rw_timestamp], distribution: UpstreamHashShard(t._row_id) }
└─BatchValues { rows: [['(666.66,777)':Struct(StructType { field_names: [], field_types: [Decimal, Int32] })]] }
└─BatchValues { rows: [['(666.66,777)':Struct(StructType { fields: [("f1", Decimal), ("f2", Int32)] })]] }
Loading
Loading