Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: improve generic decoders and type generation #164

Merged
merged 3 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 121 additions & 31 deletions cmd/hasura-ndc-go/command/internal/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"go/token"
"go/types"
"io"
"os"
"path"
Expand Down Expand Up @@ -61,7 +62,6 @@ func (ctb connectorTypeBuilder) String() string {
bs.WriteString(")\n")
}

bs.WriteString("var connector_Decoder = utils.NewDecoder()\n")
bs.WriteString(ctb.builder.String())
return bs.String()
}
Expand Down Expand Up @@ -642,51 +642,61 @@ func (cg *connectorGenerator) writeGetTypeValueDecoder(sb *connectorTypeBuilder,
case "*[]*any", "*[]*interface{}":
cg.writeScalarDecodeValue(sb, fieldName, "GetNullableArbitraryJSONPtrSlice", "", key, objectField, true)
default:
sb.builder.WriteString(" j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(", err = utils.")
switch t := ty.(type) {
case *NullableType:
packagePaths := getTypePackagePaths(t.UnderlyingType, sb.packagePath)
tyName := getTypeArgumentName(t.UnderlyingType, sb.packagePath, false)

var tyName string
var packagePaths []string
if t.IsAnonymous() {
tyName, packagePaths = cg.getAnonymousObjectTypeName(sb, field.TypeAST, true)
} else {
packagePaths = getTypePackagePaths(t.UnderlyingType, sb.packagePath)
tyName = getTypeArgumentName(t.UnderlyingType, sb.packagePath, false)
}
for _, pkgPath := range packagePaths {
sb.imports[pkgPath] = ""
}
sb.builder.WriteString(" j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(" = new(")
sb.builder.WriteString(tyName)
sb.builder.WriteString(")\n err = connector_Decoder.")
if field.Embedded {
sb.builder.WriteString("DecodeObject(j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(", input)")
sb.builder.WriteString("DecodeNullableObject[")
sb.builder.WriteString(tyName)
sb.builder.WriteString("](input)")
} else {
sb.builder.WriteString("DecodeNullableObjectValue(j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(`, input, "`)
sb.builder.WriteString("DecodeNullableObjectValue[")
sb.builder.WriteString(tyName)
sb.builder.WriteString(`](input, "`)
sb.builder.WriteString(key)
sb.builder.WriteString(`")`)
}
default:
var canEmpty bool
if len(objectField.Type) > 0 {
if typeEnum, err := objectField.Type.Type(); err == nil && typeEnum == schema.TypeNullable {
canEmpty = true
}
var tyName string
var packagePaths []string
if t.IsAnonymous() {
tyName, packagePaths = cg.getAnonymousObjectTypeName(sb, field.TypeAST, true)
} else {
packagePaths = getTypePackagePaths(ty, sb.packagePath)
tyName = getTypeArgumentName(ty, sb.packagePath, false)
}
for _, pkgPath := range packagePaths {
sb.imports[pkgPath] = ""
}

if field.Embedded {
sb.builder.WriteString(" err = connector_Decoder.DecodeObject(&j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(", input)")
sb.builder.WriteString("DecodeObject")
sb.builder.WriteRune('[')
sb.builder.WriteString(tyName)
sb.builder.WriteString("](input)")
} else {
sb.builder.WriteString(" err = connector_Decoder.")
if canEmpty {
sb.builder.WriteString("DecodeNullableObjectValue")
} else {
sb.builder.WriteString("DecodeObjectValue")
sb.builder.WriteString("DecodeObjectValue")
if len(objectField.Type) > 0 {
if typeEnum, err := objectField.Type.Type(); err == nil && typeEnum == schema.TypeNullable {
sb.builder.WriteString("Default")
}
}
sb.builder.WriteString("(&j.")
sb.builder.WriteString(fieldName)
sb.builder.WriteString(`, input, "`)
sb.builder.WriteRune('[')
sb.builder.WriteString(tyName)
sb.builder.WriteString(`](input, "`)
sb.builder.WriteString(key)
sb.builder.WriteString(`")`)
}
Expand Down Expand Up @@ -715,6 +725,86 @@ func (cg *connectorGenerator) writeScalarDecodeValue(sb *connectorTypeBuilder, f
sb.builder.WriteString(`")`)
}

// generate anonymous object type name with absolute package paths removed
func (cg *connectorGenerator) getAnonymousObjectTypeName(sb *connectorTypeBuilder, goType types.Type, skipNullable bool) (string, []string) {
switch inferredType := goType.(type) {
case *types.Pointer:
var result string
if !skipNullable {
result += "*"
}
underlyingName, packagePaths := cg.getAnonymousObjectTypeName(sb, inferredType.Elem(), false)
return result + underlyingName, packagePaths
case *types.Struct:
packagePaths := []string{}
result := "struct{"
for i := 0; i < inferredType.NumFields(); i++ {
fieldVar := inferredType.Field(i)
fieldTag := inferredType.Tag(i)
if i > 0 {
result += "; "
}
result += fieldVar.Name() + " "
underlyingName, pkgPaths := cg.getAnonymousObjectTypeName(sb, fieldVar.Type(), false)
result += underlyingName
packagePaths = append(packagePaths, pkgPaths...)
if fieldTag != "" {
result += " `" + fieldTag + "`"
}
}
result += "}"
return result, packagePaths
case *types.Named:
packagePaths := []string{}
innerType := inferredType.Obj()
if innerType == nil {
return "", packagePaths
}

var result string
typeInfo := &TypeInfo{
Name: innerType.Name(),
}

innerPkg := innerType.Pkg()
if innerPkg != nil && innerPkg.Name() != "" && innerPkg.Path() != sb.packagePath {
packagePaths = append(packagePaths, innerPkg.Path())
result += innerPkg.Name() + "."
typeInfo.PackageName = innerPkg.Name()
typeInfo.PackagePath = innerPkg.Path()
}

result += innerType.Name()
typeParams := inferredType.TypeParams()
if typeParams != nil && typeParams.Len() > 0 {
// unwrap the generic type parameters such as Foo[T]
if err := parseTypeParameters(typeInfo, inferredType.String()); err == nil {
result += "["
for i, typeParam := range typeInfo.TypeParameters {
if i > 0 {
result += ", "
}
packagePaths = append(packagePaths, getTypePackagePaths(typeParam, sb.packagePath)...)
result += getTypeArgumentName(typeParam, sb.packagePath, false)
}
result += "]"
}
}

return result, packagePaths
case *types.Basic:
return inferredType.Name(), []string{}
case *types.Array:
result, packagePaths := cg.getAnonymousObjectTypeName(sb, inferredType.Elem(), false)
return "[]" + result, packagePaths
case *types.Slice:
result, packagePaths := cg.getAnonymousObjectTypeName(sb, inferredType.Elem(), false)
return "[]" + result, packagePaths
default:
return inferredType.String(), []string{}
}
}

func formatLocalFieldName(input string, others ...string) string {
name := fieldNameRegex.ReplaceAllString(input, "_")
return strings.Trim(strings.Join(append([]string{name}, others...), "_"), "_")
Expand Down
14 changes: 8 additions & 6 deletions cmd/hasura-ndc-go/command/internal/connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,18 @@ func (dch DataConnectorHandler) execQuery(ctx context.Context, state *`)
chb.Builder.imports[fn.ArgumentsType.PackagePath] = ""
}

sb.WriteString(`
var args `)
sb.WriteString(argName)
sb.WriteString("\n if parseErr := ")
if fn.ArgumentsType.CanMethod() {
sb.WriteString("\n var args ")
sb.WriteString(argName)
sb.WriteString("\n parseErr := ")
sb.WriteString("args.FromValue(rawArgs)")
} else {
sb.WriteString("connector_Decoder.DecodeObject(&args, rawArgs)")
sb.WriteString("\n args, parseErr := utils.DecodeObject[")
sb.WriteString(argName)
sb.WriteString("](rawArgs)")
}
sb.WriteString(`; parseErr != nil {
sb.WriteString(`
if parseErr != nil {
return nil, schema.UnprocessableContentError("failed to resolve arguments", map[string]any{
"cause": parseErr.Error(),
})
Expand Down
20 changes: 20 additions & 0 deletions cmd/hasura-ndc-go/command/internal/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,31 @@ import (
"github.com/hasura/ndc-sdk-go/schema"
)

// OperationKind the operation kind of connectors
type OperationKind string

const (
OperationFunction OperationKind = "Function"
OperationProcedure OperationKind = "Procedure"
)

// Scalar the structured information of the scalar
type Scalar struct {
Schema schema.ScalarType
NativeType *TypeInfo
}

// Type the interface of a type schema
type Type interface {
Kind() schema.TypeEnum
Schema() schema.TypeEncoder
SchemaName(isAbsolute bool) string
FullName() string
String() string
IsAnonymous() bool
}

// NullableType the information of the nullable type
type NullableType struct {
UnderlyingType Type
}
Expand All @@ -41,6 +46,10 @@ func (t *NullableType) Kind() schema.TypeEnum {
return schema.TypeNullable
}

func (t *NullableType) IsAnonymous() bool {
return t.UnderlyingType.IsAnonymous()
}

func (t NullableType) SchemaName(isAbsolute bool) string {
var result string
if isAbsolute {
Expand All @@ -65,6 +74,7 @@ func (t NullableType) String() string {
return "*" + t.UnderlyingType.String()
}

// ArrayType the information of the array type
type ArrayType struct {
ElementType Type
}
Expand All @@ -79,6 +89,10 @@ func (t *ArrayType) Kind() schema.TypeEnum {
return schema.TypeArray
}

func (t *ArrayType) IsAnonymous() bool {
return t.ElementType.IsAnonymous()
}

func (t *ArrayType) Schema() schema.TypeEncoder {
return schema.NewArrayType(t.ElementType.Schema())
}
Expand All @@ -100,6 +114,7 @@ func (t ArrayType) String() string {
return "[]" + t.ElementType.String()
}

// NamedType the information of a named type
type NamedType struct {
Name string
NativeType *TypeInfo
Expand All @@ -115,6 +130,10 @@ func (t *NamedType) Kind() schema.TypeEnum {
return schema.TypeNamed
}

func (t *NamedType) IsAnonymous() bool {
return t.NativeType.IsAnonymous
}

func (t *NamedType) Schema() schema.TypeEncoder {
return schema.NewNamedType(t.Name)
}
Expand Down Expand Up @@ -209,6 +228,7 @@ type Field struct {
Description *string
Embedded bool
Type Type
TypeAST types.Type
}

// ObjectInfo represents the serialization information of an object type.
Expand Down
21 changes: 20 additions & 1 deletion cmd/hasura-ndc-go/command/internal/schema_type_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func (tp *TypeParser) Parse(fieldPaths []string) (*Field, error) {
if err != nil {
return nil, err
}
if ty == nil {
return nil, nil
}
tp.field.Type = ty
return tp.field, nil
}
Expand Down Expand Up @@ -314,6 +317,8 @@ func (tp *TypeParser) parseType(ty types.Type, fieldPaths []string) (Type, error
})

return NewNamedType(string(ScalarJSON), typeInfo), nil
case *types.Chan, *types.Signature, *types.Tuple, *types.Union:
return nil, nil
default:
return nil, fmt.Errorf("unsupported type: %s", ty.String())
}
Expand All @@ -322,6 +327,10 @@ func (tp *TypeParser) parseType(ty types.Type, fieldPaths []string) (Type, error
func (tp *TypeParser) parseStructType(objectInfo *ObjectInfo, inferredType *types.Struct, fieldPaths []string) error {
for i := 0; i < inferredType.NumFields(); i++ {
fieldVar := inferredType.Field(i)
if !fieldVar.Exported() {
continue
}

fieldTag := inferredType.Tag(i)
fieldKey, jsonOption := getFieldNameOrTag(fieldVar.Name(), fieldTag)
if jsonOption == jsonIgnore {
Expand All @@ -330,11 +339,15 @@ func (tp *TypeParser) parseStructType(objectInfo *ObjectInfo, inferredType *type
fieldParser := NewTypeParser(tp.schemaParser, &Field{
Name: fieldVar.Name(),
Embedded: fieldVar.Embedded(),
TypeAST: fieldVar.Type(),
}, fieldVar.Type(), tp.argumentFor)
field, err := fieldParser.Parse(append(fieldPaths, fieldVar.Name()))
if err != nil {
return err
}
if field == nil {
continue
}
embeddedObject, ok := tp.schemaParser.rawSchema.Objects[field.Type.SchemaName(false)]
if field.Embedded && ok {
// flatten embedded object fields to the parent object
Expand Down Expand Up @@ -466,7 +479,13 @@ func (tp *TypeParser) parseTypeInfoFromComments(typeInfo *TypeInfo, scope *types

func parseTypeParameters(rootType *TypeInfo, input string) error {
paramsString := strings.TrimPrefix(input, rootType.PackagePath+"."+rootType.Name)
rawParams := strings.Split(paramsString[1:len(paramsString)-1], ",")
if paramsString[0] == '[' {
paramsString = paramsString[1:]
}
if paramsString[len(paramsString)-1] == ']' {
paramsString = paramsString[:len(paramsString)-1]
}
rawParams := strings.Split(paramsString, ",")

for _, param := range rawParams {
param = strings.TrimSpace(param)
Expand Down
Loading