diff --git a/packages/protobuf/src/codegen-info.ts b/packages/protobuf/src/codegen-info.ts index fd737ecd9..e6314b865 100644 --- a/packages/protobuf/src/codegen-info.ts +++ b/packages/protobuf/src/codegen-info.ts @@ -47,11 +47,11 @@ interface CodegenInfo { | DescOneof | DescField | DescService - | DescMethod, + | DescMethod ) => string; readonly symbols: Record; readonly getUnwrappedFieldType: ( - field: DescField | DescExtension, + field: DescField | DescExtension ) => ScalarType | undefined; readonly wktSourceFiles: readonly string[]; /** @@ -60,7 +60,7 @@ interface CodegenInfo { readonly scalarDefaultValue: (type: ScalarType, longType: LongType) => any; // eslint-disable-line @typescript-eslint/no-explicit-any readonly scalarZeroValue: ( type: T, - longType: L, + longType: L ) => ScalarValue; /** * @deprecated please use reifyWkt from @bufbuild/protoplugin/ecmascript instead @@ -75,6 +75,7 @@ type RuntimeSymbolName = | "proto3" | "Message" | "PartialMessage" + | "PartialStrictMessage" | "PlainMessage" | "FieldList" | "MessageType" @@ -116,6 +117,7 @@ export const codegenInfo: CodegenInfo = { proto3: {typeOnly: false, privateImportPath: "./proto3.js", publicImportPath: packageName}, Message: {typeOnly: false, privateImportPath: "./message.js", publicImportPath: packageName}, PartialMessage: {typeOnly: true, privateImportPath: "./message.js", publicImportPath: packageName}, + PartialStrictMessage: {typeOnly: true, privateImportPath: "./message.js", publicImportPath: packageName}, PlainMessage: {typeOnly: true, privateImportPath: "./message.js", publicImportPath: packageName}, FieldList: {typeOnly: true, privateImportPath: "./field-list.js", publicImportPath: packageName}, MessageType: {typeOnly: true, privateImportPath: "./message-type.js", publicImportPath: packageName}, diff --git a/packages/protobuf/src/index.ts b/packages/protobuf/src/index.ts index 0a07feab7..add914e53 100644 --- a/packages/protobuf/src/index.ts +++ b/packages/protobuf/src/index.ts @@ -21,7 +21,12 @@ export { protoDelimited } from "./proto-delimited.js"; export { codegenInfo } from "./codegen-info.js"; export { Message } from "./message.js"; -export type { AnyMessage, PartialMessage, PlainMessage } from "./message.js"; +export type { + AnyMessage, + PartialMessage, + PartialStrictMessage, + PlainMessage, +} from "./message.js"; export { isMessage } from "./is-message.js"; export type { FieldInfo, OneofInfo } from "./field.js"; diff --git a/packages/protobuf/src/message.ts b/packages/protobuf/src/message.ts index cc9d3a1f1..f735e479d 100644 --- a/packages/protobuf/src/message.ts +++ b/packages/protobuf/src/message.ts @@ -46,7 +46,7 @@ export class Message = AnyMessage> { return this.getType().runtime.util.equals( this.getType(), this as unknown as T, - other, + other ); } @@ -96,7 +96,7 @@ export class Message = AnyMessage> { throw new Error( `cannot decode ${this.getType().typeName} from JSON: ${ e instanceof Error ? e.message : String(e) - }`, + }` ); } return this.fromJson(json, options); @@ -204,6 +204,22 @@ export type PartialMessage> = { [P in keyof T as T[P] extends Function ? never : P]?: PartialField; }; +export type PartialStrictMessage> = { + // eslint-disable-next-line @typescript-eslint/ban-types -- we use `Function` to identify methods + [P in keyof T as T[P] extends Function ? never : P]: PartialStrictField; +}; + +// prettier-ignore +type PartialStrictField = + F extends (Date | Uint8Array | bigint | boolean | string | number) ? F + : F extends Array ? Array> + : F extends ReadonlyArray ? ReadonlyArray> + : F extends Message ? PartialStrictMessage + : F extends OneofSelectedMessage ? {case: C; value: PartialStrictMessage} + : F extends { case: string | undefined; value?: unknown; } ? F + : F extends {[key: string|number]: Message} ? {[key: string|number]: PartialStrictMessage} + : F ; + // prettier-ignore type PartialField = F extends (Date | Uint8Array | bigint | boolean | string | number) ? F diff --git a/packages/protoc-gen-es/src/declaration.ts b/packages/protoc-gen-es/src/declaration.ts index 9c648891d..7c2190181 100644 --- a/packages/protoc-gen-es/src/declaration.ts +++ b/packages/protoc-gen-es/src/declaration.ts @@ -64,10 +64,11 @@ function generateEnum(schema: Schema, f: GeneratedFile, enumeration: DescEnum) { function generateMessage(schema: Schema, f: GeneratedFile, message: DescMessage) { const protoN = getNonEditionRuntime(schema, message.file); const { - PartialMessage, FieldList, Message, PlainMessage, + PartialStrictMessage, + PartialMessage, BinaryReadOptions, JsonReadOptions, JsonValue @@ -86,7 +87,8 @@ function generateMessage(schema: Schema, f: GeneratedFile, message: DescMessage) } f.print(); } - f.print(" constructor(data?: ", PartialMessage, "<", m, ">);"); + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument + f.print(" constructor(data?: ", schema.strict ? PartialStrictMessage : PartialMessage, "<", m, ">);"); f.print(); generateWktMethods(schema, f, message); f.print(" static readonly runtime: typeof ", protoN, ";"); @@ -125,7 +127,7 @@ function generateOneof(schema: Schema, f: GeneratedFile, oneof: DescOneof) { f.print(` } | {`); } f.print(f.jsDoc(field, " ")); - const { typing } = getFieldTypeInfo(field); + const { typing } = getFieldTypeInfo(field, schema.strict); f.print(` value: `, typing, `;`); f.print(` case: "`, localName(field), `";`); } @@ -136,7 +138,7 @@ function generateOneof(schema: Schema, f: GeneratedFile, oneof: DescOneof) { function generateField(schema: Schema, f: GeneratedFile, field: DescField) { f.print(f.jsDoc(field, " ")); const e: Printable = []; - const { typing, optional } = getFieldTypeInfo(field); + const { typing, optional } = getFieldTypeInfo(field, schema.strict); if (!optional) { e.push(" ", localName(field), ": ", typing, ";"); } else { @@ -151,7 +153,7 @@ function generateExtension( f: GeneratedFile, ext: DescExtension, ) { - const { typing } = getFieldTypeInfo(ext); + const { typing } = getFieldTypeInfo(ext, schema.strict); const e = f.import(ext.extendee).toTypeOnly(); f.print(f.jsDoc(ext)); f.print(f.exportDecl("declare const", localName(ext)), ": ", schema.runtime.Extension, "<", e, ", ", typing, ">;"); @@ -232,7 +234,7 @@ function generateWktStaticMethods(schema: Schema, f: GeneratedFile, message: Des case "google.protobuf.BoolValue": case "google.protobuf.StringValue": case "google.protobuf.BytesValue": { - const {typing} = getFieldTypeInfo(ref.value); + const {typing} = getFieldTypeInfo(ref.value, schema.strict); f.print(" static readonly fieldWrapper: {") f.print(" wrapField(value: ", typing, "): ", message, ",") f.print(" unwrapField(value: ", message, "): ", typing, ",") diff --git a/packages/protoc-gen-es/src/typescript.ts b/packages/protoc-gen-es/src/typescript.ts index 0d2b7bc90..3f2a66af7 100644 --- a/packages/protoc-gen-es/src/typescript.ts +++ b/packages/protoc-gen-es/src/typescript.ts @@ -74,6 +74,7 @@ function generateMessage(schema: Schema, f: GeneratedFile, message: DescMessage) const protoN = getNonEditionRuntime(schema, message.file); const { PartialMessage, + PartialStrictMessage, FieldList, Message, PlainMessage, @@ -94,7 +95,7 @@ function generateMessage(schema: Schema, f: GeneratedFile, message: DescMessage) } f.print(); } - f.print(" constructor(data?: ", PartialMessage, "<", message, ">) {"); + f.print(" constructor(data?: ", schema.strict ? PartialStrictMessage : PartialMessage, "<", message, ">) {"); f.print(" super();"); f.print(" ", protoN, ".util.initPartial(data, this);"); f.print(" }"); @@ -148,7 +149,7 @@ function generateOneof(schema: Schema, f: GeneratedFile, oneof: DescOneof) { f.print(` } | {`); } f.print(f.jsDoc(field, " ")); - const { typing } = getFieldTypeInfo(field); + const { typing } = getFieldTypeInfo(field, schema.strict); f.print(` value: `, typing, `;`); f.print(` case: "`, localName(field), `";`); } @@ -159,7 +160,7 @@ function generateOneof(schema: Schema, f: GeneratedFile, oneof: DescOneof) { function generateField(schema: Schema, f: GeneratedFile, field: DescField) { f.print(f.jsDoc(field, " ")); const e: Printable = []; - const { typing, optional, typingInferrableFromZeroValue } = getFieldTypeInfo(field); + const { typing, optional, typingInferrableFromZeroValue } = getFieldTypeInfo(field, schema.strict); if (optional) { e.push(" ", localName(field), "?: ", typing, ";"); } else { @@ -184,7 +185,7 @@ function generateExtension( ext: DescExtension, ) { const protoN = getNonEditionRuntime(schema, ext.file); - const { typing } = getFieldTypeInfo(ext); + const { typing } = getFieldTypeInfo(ext, schema.strict); f.print(f.jsDoc(ext)); f.print(f.exportDecl("const", ext), " = ", protoN, ".makeExtension<", ext.extendee, ", ", typing, ">("); f.print(" ", f.string(ext.typeName), ", "); @@ -651,7 +652,7 @@ function generateWktStaticMethods(schema: Schema, f: GeneratedFile, message: Des case "google.protobuf.BoolValue": case "google.protobuf.StringValue": case "google.protobuf.BytesValue": { - const {typing} = getFieldTypeInfo(ref.value); + const {typing} = getFieldTypeInfo(ref.value, schema.strict); f.print(" static readonly fieldWrapper = {") f.print(" wrapField(value: ", typing, "): ", message, " {") f.print(" return new ", message, "({value});") diff --git a/packages/protoc-gen-es/src/util.ts b/packages/protoc-gen-es/src/util.ts index 5fa1f5a6a..08d4e0042 100644 --- a/packages/protoc-gen-es/src/util.ts +++ b/packages/protoc-gen-es/src/util.ts @@ -25,7 +25,10 @@ import { import type { Printable } from "@bufbuild/protoplugin/ecmascript"; import { localName } from "@bufbuild/protoplugin/ecmascript"; -export function getFieldTypeInfo(field: DescField | DescExtension): { +export function getFieldTypeInfo( + field: DescField | DescExtension, + strict: boolean +): { typing: Printable; optional: boolean; typingInferrableFromZeroValue: boolean; @@ -38,7 +41,7 @@ export function getFieldTypeInfo(field: DescField | DescExtension): { typing.push(scalarTypeScriptType(field.scalar, field.longType)); optional = field.optional || - field.proto.label === FieldDescriptorProto_Label.REQUIRED; + (!strict && field.proto.label === FieldDescriptorProto_Label.REQUIRED); typingInferrableFromZeroValue = true; break; case "message": { @@ -64,7 +67,7 @@ export function getFieldTypeInfo(field: DescField | DescExtension): { }); optional = field.optional || - field.proto.label === FieldDescriptorProto_Label.REQUIRED; + (!strict && field.proto.label === FieldDescriptorProto_Label.REQUIRED); typingInferrableFromZeroValue = true; break; case "map": { @@ -86,7 +89,7 @@ export function getFieldTypeInfo(field: DescField | DescExtension): { case "scalar": valueType = scalarTypeScriptType( field.mapValue.scalar, - LongType.BIGINT, + LongType.BIGINT ); break; case "message": @@ -127,7 +130,7 @@ export function getFieldDefaultValueExpression( enumAs: | "enum_value_as_is" | "enum_value_as_integer" - | "enum_value_as_cast_integer" = "enum_value_as_is", + | "enum_value_as_cast_integer" = "enum_value_as_is" ): Printable | undefined { if (field.repeated) { return undefined; @@ -142,11 +145,11 @@ export function getFieldDefaultValueExpression( switch (field.fieldKind) { case "enum": { const enumValue = field.enum.values.find( - (value) => value.number === defaultValue, + (value) => value.number === defaultValue ); if (enumValue === undefined) { throw new Error( - `invalid enum default value: ${String(defaultValue)} for ${enumValue}`, + `invalid enum default value: ${String(defaultValue)} for ${enumValue}` ); } return literalEnumValue(enumValue, enumAs); @@ -171,7 +174,7 @@ export function getFieldZeroValueExpression( enumAs: | "enum_value_as_is" | "enum_value_as_integer" - | "enum_value_as_cast_integer" = "enum_value_as_is", + | "enum_value_as_cast_integer" = "enum_value_as_is" ): Printable | undefined { if (field.repeated) { return "[]"; @@ -193,7 +196,7 @@ export function getFieldZeroValueExpression( case "scalar": { const defaultValue = codegenInfo.scalarZeroValue( field.scalar, - field.longType, + field.longType ); return literalScalarValue(defaultValue, field); } @@ -202,7 +205,7 @@ export function getFieldZeroValueExpression( function literalScalarValue( value: ScalarValue, - field: (DescField | DescExtension) & { fieldKind: "scalar" }, + field: (DescField | DescExtension) & { fieldKind: "scalar" } ): Printable { switch (field.scalar) { case ScalarType.DOUBLE: @@ -214,28 +217,28 @@ function literalScalarValue( case ScalarType.SINT32: if (typeof value != "number") { throw new Error( - `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}`, + `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}` ); } return value; case ScalarType.BOOL: if (typeof value != "boolean") { throw new Error( - `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}`, + `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}` ); } return value; case ScalarType.STRING: if (typeof value != "string") { throw new Error( - `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}`, + `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}` ); } return { kind: "es_string", value }; case ScalarType.BYTES: if (!(value instanceof Uint8Array)) { throw new Error( - `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}`, + `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}` ); } return value; @@ -246,7 +249,7 @@ function literalScalarValue( case ScalarType.FIXED64: if (typeof value != "bigint" && typeof value != "string") { throw new Error( - `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}`, + `Unexpected value for ${ScalarType[field.scalar]} ${field.toString()}: ${String(value)}` ); } return { @@ -263,7 +266,7 @@ function literalEnumValue( enumAs: | "enum_value_as_is" | "enum_value_as_integer" - | "enum_value_as_cast_integer", + | "enum_value_as_cast_integer" ): Printable { switch (enumAs) { case "enum_value_as_is": diff --git a/packages/protoplugin/src/ecmascript/parameter.ts b/packages/protoplugin/src/ecmascript/parameter.ts index 3ae8bda2f..47b7955a5 100644 --- a/packages/protoplugin/src/ecmascript/parameter.ts +++ b/packages/protoplugin/src/ecmascript/parameter.ts @@ -25,11 +25,12 @@ export interface ParsedParameter { importExtension: string; jsImportStyle: "module" | "legacy_commonjs"; sanitizedParameter: string; + strict: boolean; } export function parseParameter( parameter: string | undefined, - parseExtraOption: ((key: string, value: string) => void) | undefined, + parseExtraOption: ((key: string, value: string) => void) | undefined ): ParsedParameter { let targets: Target[] = ["js", "dts"]; let tsNocheck = true; @@ -38,6 +39,7 @@ export function parseParameter( const rewriteImports: RewriteImports = []; let importExtension = ".js"; let jsImportStyle: "module" | "legacy_commonjs" = "module"; + let strict = false; const rawParameters: string[] = []; for (const { key, value, raw } of splitParameter(parameter)) { // Whether this key/value plugin parameter pair should be @@ -94,7 +96,7 @@ export function parseParameter( if (parts.length !== 2) { throw new PluginOptionError( raw, - "must be in the form of :", + "must be in the form of :" ); } const [pattern, target] = parts; @@ -135,6 +137,21 @@ export function parseParameter( } break; } + case "strict": { + switch (value) { + case "true": + case "1": + strict = true; + break; + case "false": + case "0": + strict = false; + break; + default: + throw new PluginOptionError(raw); + } + break; + } default: if (parseExtraOption === undefined) { throw new PluginOptionError(raw); @@ -154,6 +171,7 @@ export function parseParameter( const sanitizedParameter = rawParameters.join(","); return { + strict, targets, tsNocheck, bootstrapWkt, @@ -166,7 +184,7 @@ export function parseParameter( } function splitParameter( - parameter: string | undefined, + parameter: string | undefined ): { key: string; value: string; raw: string }[] { if (parameter == undefined) { return []; diff --git a/packages/protoplugin/src/ecmascript/runtime-imports.ts b/packages/protoplugin/src/ecmascript/runtime-imports.ts index 28fed57f1..f330ddca9 100644 --- a/packages/protoplugin/src/ecmascript/runtime-imports.ts +++ b/packages/protoplugin/src/ecmascript/runtime-imports.ts @@ -21,6 +21,7 @@ export interface RuntimeImports { proto3: ImportSymbol; Message: ImportSymbol; PartialMessage: ImportSymbol; + PartialStrictMessage: ImportSymbol; PlainMessage: ImportSymbol; FieldList: ImportSymbol; MessageType: ImportSymbol; @@ -47,6 +48,7 @@ export function createRuntimeImports(bootstrapWkt: boolean): RuntimeImports { proto3: infoToSymbol("proto3", bootstrapWkt), Message: infoToSymbol("Message", bootstrapWkt), PartialMessage: infoToSymbol("PartialMessage", bootstrapWkt), + PartialStrictMessage: infoToSymbol("PartialStrictMessage", bootstrapWkt), PlainMessage: infoToSymbol("PlainMessage", bootstrapWkt), FieldList: infoToSymbol("FieldList", bootstrapWkt), MessageType: infoToSymbol("MessageType", bootstrapWkt), @@ -69,12 +71,12 @@ export function createRuntimeImports(bootstrapWkt: boolean): RuntimeImports { function infoToSymbol( name: keyof typeof codegenInfo.symbols, - bootstrapWkt: boolean, + bootstrapWkt: boolean ): ImportSymbol { const info = codegenInfo.symbols[name]; const symbol = createImportSymbol( name, - bootstrapWkt ? info.privateImportPath : info.publicImportPath, + bootstrapWkt ? info.privateImportPath : info.publicImportPath ); return info.typeOnly ? symbol.toTypeOnly() : symbol; } diff --git a/packages/protoplugin/src/ecmascript/schema.ts b/packages/protoplugin/src/ecmascript/schema.ts index a9994d849..d916247fb 100644 --- a/packages/protoplugin/src/ecmascript/schema.ts +++ b/packages/protoplugin/src/ecmascript/schema.ts @@ -78,6 +78,11 @@ export interface Schema { * The original google.protobuf.compiler.CodeGeneratorRequest. */ readonly proto: CodeGeneratorRequest; + + /** + * strict mode with `required` support + */ + readonly strict: boolean; } interface SchemaController extends Schema { @@ -90,7 +95,7 @@ export function createSchema( parameter: ParsedParameter, pluginName: string, pluginVersion: string, - featureSetDefaults: FeatureSetDefaults | undefined, + featureSetDefaults: FeatureSetDefaults | undefined ): SchemaController { const descriptorSet = createDescriptorSet(request.protoFile, { featureSetDefaults, @@ -98,13 +103,13 @@ export function createSchema( const filesToGenerate = findFilesToGenerate(descriptorSet, request); const runtime = createRuntimeImports(parameter.bootstrapWkt); const createTypeImport = ( - desc: DescMessage | DescEnum | DescExtension, + desc: DescMessage | DescEnum | DescExtension ): ImportSymbol => { const name = codegenInfo.localName(desc); const from = makeImportPath( desc.file, parameter.bootstrapWkt, - filesToGenerate, + filesToGenerate ); return createImportSymbol(name, from); }; @@ -114,12 +119,13 @@ export function createSchema( pluginName, pluginVersion, parameter.sanitizedParameter, - parameter.tsNocheck, + parameter.tsNocheck ); let target: Target | undefined; const generatedFiles: GeneratedFileController[] = []; return { targets: parameter.targets, + strict: parameter.strict, runtime, proto: request, files: filesToGenerate, @@ -127,7 +133,7 @@ export function createSchema( generateFile(name) { if (target === undefined) { throw new Error( - "prepareGenerate() must be called before generateFile()", + "prepareGenerate() must be called before generateFile()" ); } const genFile = createGeneratedFile( @@ -138,11 +144,11 @@ export function createSchema( rewriteImportPath( importPath, parameter.rewriteImports, - parameter.importExtension, + parameter.importExtension ), createTypeImport, runtime, - createPreamble, + createPreamble ); generatedFiles.push(genFile); return genFile; @@ -160,17 +166,15 @@ export function createSchema( function findFilesToGenerate( descriptorSet: DescriptorSet, - request: CodeGeneratorRequest, + request: CodeGeneratorRequest ) { const missing = request.fileToGenerate.filter((fileToGenerate) => - descriptorSet.files.every( - (file) => fileToGenerate !== file.name + ".proto", - ), + descriptorSet.files.every((file) => fileToGenerate !== file.name + ".proto") ); if (missing.length) { throw `files_to_generate missing in the request: ${missing.join(", ")}`; } return descriptorSet.files.filter((file) => - request.fileToGenerate.includes(file.name + ".proto"), + request.fileToGenerate.includes(file.name + ".proto") ); }