diff --git a/__mocks__/typedData/example_enumNested.json b/__mocks__/typedData/example_enumNested.json new file mode 100644 index 000000000..6ee6a75bd --- /dev/null +++ b/__mocks__/typedData/example_enumNested.json @@ -0,0 +1,33 @@ +{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example": [{ "name": "someEnum", "type": "enum", "contains": "EnumA" }], + "EnumA": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128,StructA)" } + ], + "StructA": [{ "name": "nestedEnum", "type": "enum", "contains": "EnumB" }], + "EnumB": [ + { "name": "Variant A", "type": "()" }, + { "name": "Variant B", "type": "(StructB*)" } + ], + "StructB": [{ "name": "flag", "type": "bool" }] + }, + "primaryType": "Example", + "domain": { + "name": "StarkNet Mail", + "version": "1", + "chainId": "1", + "revision": "1" + }, + "message": { + "someEnum": { + "Variant 2": [2, { "nestedEnum": { "Variant B": [[{ "flag": true }, { "flag": false }]] } }] + } + } +} diff --git a/__tests__/utils/typedData.test.ts b/__tests__/utils/typedData.test.ts index 891e1aa3a..c21ea251d 100644 --- a/__tests__/utils/typedData.test.ts +++ b/__tests__/utils/typedData.test.ts @@ -3,6 +3,7 @@ import * as starkCurve from '@scure/starknet'; import typedDataExample from '../../__mocks__/typedData/baseExample.json'; import exampleBaseTypes from '../../__mocks__/typedData/example_baseTypes.json'; import exampleEnum from '../../__mocks__/typedData/example_enum.json'; +import exampleEnumNested from '../../__mocks__/typedData/example_enumNested.json'; import examplePresetTypes from '../../__mocks__/typedData/example_presetTypes.json'; import typedDataStructArrayExample from '../../__mocks__/typedData/mail_StructArray.json'; import typedDataSessionExample from '../../__mocks__/typedData/session_MerkleTree.json'; @@ -66,6 +67,10 @@ describe('typedData', () => { expect(encoded).toMatchInlineSnapshot( `"\\"Example\\"(\\"someEnum1\\":\\"EnumA\\",\\"someEnum2\\":\\"EnumB\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"u128*\\"),\\"Variant 3\\"(\\"u128\\"))\\"EnumB\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\"))"` ); + encoded = encodeType(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); + expect(encoded).toMatchInlineSnapshot( + `"\\"Example\\"(\\"someEnum\\":\\"EnumA\\")\\"EnumA\\"(\\"Variant 1\\"(),\\"Variant 2\\"(\\"u128\\",\\"StructA\\"))\\"EnumB\\"(\\"Variant A\\"(),\\"Variant B\\"(\\"StructB*\\"))\\"StructA\\"(\\"nestedEnum\\":\\"EnumB\\")\\"StructB\\"(\\"flag\\":\\"bool\\")"` + ); }); test('should get right type hash', () => { @@ -106,6 +111,10 @@ describe('typedData', () => { expect(typeHash).toMatchInlineSnapshot( `"0x393bf83422ca8626a2932696cfa0acb19dcad6de2fe84a2dd2ca7607ea5329a"` ); + typeHash = getTypeHash(exampleEnumNested.types, 'Example', TypedDataRevision.ACTIVE); + expect(typeHash).toMatchInlineSnapshot( + `"0x267f739fd83d30528a0fafb23df33b6c35ca0a5adbcfb32152721478fa9d0ce"` + ); }); test('should transform type selector', () => { @@ -329,6 +338,11 @@ describe('typedData', () => { `"0x150a589bb56a4fbf4ee01f52e44fd5adde6af94c02b37e383413fed185321a2"` ); + messageHash = getMessageHash(exampleEnumNested, exampleAddress); + expect(messageHash).toMatchInlineSnapshot( + `"0x6e70eb4ef625dda451094716eee7f31fa81ca0ba99d390885e9c7b0d64cd22"` + ); + expect(spyPedersen).not.toHaveBeenCalled(); expect(spyPoseidon).toHaveBeenCalled(); spyPedersen.mockRestore(); diff --git a/src/utils/typedData.ts b/src/utils/typedData.ts index 6a579dbd3..d16f34257 100644 --- a/src/utils/typedData.ts +++ b/src/utils/typedData.ts @@ -167,36 +167,46 @@ export function getDependencies( contains: string = '', revision: Revision = Revision.LEGACY ): string[] { + let dependencyTypes: string[] = [type]; + // Include pointers (struct arrays) if (type[type.length - 1] === '*') { - type = type.slice(0, -1); + dependencyTypes = [type.slice(0, -1)]; } else if (revision === Revision.ACTIVE) { // enum base if (type === 'enum') { - type = contains; + dependencyTypes = [contains]; } // enum element types else if (type.match(/^\(.*\)$/)) { - type = type.slice(1, -1); + dependencyTypes = type + .slice(1, -1) + .split(',') + .map((depType) => (depType[depType.length - 1] === '*' ? depType.slice(0, -1) : depType)); } } - if (dependencies.includes(type) || !types[type]) { - return dependencies; - } - - return [ - type, - ...(types[type] as StarknetEnumType[]).reduce( - (previous, t) => [ - ...previous, - ...getDependencies(types, t.type, previous, t.contains, revision).filter( - (dependency) => !previous.includes(dependency) - ), + return dependencyTypes + .filter((t) => !dependencies.includes(t) && types[t]) + .reduce( + // This comment prevents prettier from rolling everything here into a single line. + (p, depType) => [ + ...p, + ...[ + depType, + ...(types[depType] as StarknetEnumType[]).reduce( + (previous, t) => [ + ...previous, + ...getDependencies(types, t.type, previous, t.contains, revision).filter( + (dependency) => !previous.includes(dependency) + ), + ], + [] + ), + ].filter((dependency) => !p.includes(dependency)), ], [] - ), - ]; + ); } function getMerkleTreeType(types: TypedData['types'], ctx: Context) {