From dd66b2f376dbca5f0794fd0fe307b1227e7c986b Mon Sep 17 00:00:00 2001 From: matteo-cristino Date: Fri, 10 Jan 2025 19:44:29 +0100 Subject: [PATCH] fix(random): use uniform distribution to generate random number with modulo, shuffle arrays and extract subtable Uniform distribution is obtained by using rejection sampling and reservoir sampling Moreover deprecate statements: * create random object of '' bits * create random object of '' bytes * create array of '' random objects * create array of '' random objects of '' bits * create array of '' random objects of '' bytes * pick random object in '' * create random dictionary with '' random objects from '' in favour of, respectively: * create random of '' bits * create random of '' bytes * create array of '' random * create array of '' random of '' bits * create array of '' random of '' bytes * create random pick from '' * create random table with '' random pick from '' --- src/lua/zencode_random.lua | 329 ++++++++++++++++++++----------- test/zencode/array.bats | 10 +- test/zencode/bbs_zkp.bats | 12 +- test/zencode/bbs_zkp_shake.bats | 8 +- test/zencode/cookbook_intro.bats | 10 +- test/zencode/cookbook_when.bats | 54 ++--- test/zencode/dictionary.bats | 16 +- test/zencode/dp3t.bats | 16 +- test/zencode/given.bats | 12 +- test/zencode/output.bats | 2 +- test/zencode/random.bats | 48 ++--- test/zencode/secshare.bats | 21 +- 12 files changed, 313 insertions(+), 225 deletions(-) diff --git a/src/lua/zencode_random.lua b/src/lua/zencode_random.lua index 4cf55c2da..a4cfc9252 100644 --- a/src/lua/zencode_random.lua +++ b/src/lua/zencode_random.lua @@ -20,6 +20,44 @@ --on Saturday, 27th November 2021 --]] +-- utils +local function _get_bytes(n) + local num = tonumber(mayhave(n) or n) + if not num then error("Argument is not a number: "..n, 2) end + return math.ceil(num) +end +local function _get_bytes_from_bits(n) + local num = tonumber(mayhave(n) or n) + if not num then error("Argument is not a number: "..n, 2) end + return 
math.ceil(num/8) +end + +-- uniformity obtained with rejection sampling +local function _random_modulo_uniform_distribution(modulo, max_random, random_f) + if not modulo then + error("modulo input argument is missing", 2) + end + local max_random = max_random or 65536 + if type(max_random) ~= type(modulo) then + error("max_random and modulo have different types: "..type(max_random).." and "..type(modulo), 2) + end + if max_random < modulo then + error("max_random is less than modulo", 2) + end + local random_f = random_f or random_int16 + local max_uniform_random + if type(modulo) == 'zenroom.big' then + max_uniform_random = (max_random/modulo)*modulo + else + max_uniform_random = math.floor(max_random/modulo)*modulo + end + local random = random_f() + while random >= max_uniform_random do + random = random_f() + end + return (random % modulo) +1 +end + -- random operations, mostly on arrays and schemas supported When("seed random with ''", @@ -32,87 +70,129 @@ When("seed random with ''", end ) -When("create random ''", function(dest) - zencode_assert(not ACK[dest], "Cannot overwrite existing value: "..dest) - ACK[dest] = OCTET.random(32) -- TODO: right now hardcoded 256 bit random secrets - new_codec(dest, { zentype = 'e' }) -end) +-- random octets -local function shuffle_array_f(tab) - -- do not enforce CODEC detection since some schemas are also 1st level arrays - local count = isarray(tab) - zencode_assert( count > 0, "Randomized object is not an array") - local res = { } - for i = count,2,-1 do - local r = (random_int16() % i)+1 - table.insert(res,tab[r]) -- limit 16bit lenght for arrays - table.remove(tab, r) - end - table.insert(res,tab[1]) - return res +local function _create_random(dest, bytes) + empty(dest) + ACK[dest] = OCTET.random(bytes) + new_codec(dest, { zentype = 'e' }) end --- random and hashing operations -When("create random object of '' bits", function(n) - empty'random object' - local bits = tonumber(mayhave(n) or n) - zencode_assert(bits, 
'Invalid number of bits: ' .. n) - ACK.random_object = OCTET.random(math.ceil(bits / 8)) - new_codec('random_object', { zentype = 'e' }) -end +When("create random ''", function(dest) + _create_random(dest, 32) +end) +When("create random of '' bits", function(n) + _create_random('random', _get_bytes_from_bits(n)) +end) +When("create random of '' bytes", function(n) + _create_random('random', _get_bytes(n)) +end) + +When( + deprecated( + "create random object of '' bits", + "create random of '' bits", + function(n) _create_random('random_object', n, 8) end + ) ) -When("create random object of '' bytes",function(n) - empty'random object' - local bytes = math.ceil(tonumber(mayhave(n) or n)) - zencode_assert(bytes, 'Invalid number of bytes: ' .. n) - ACK.random_object = OCTET.random(bytes) - new_codec('random_object', { zentype = 'e' }) -end +When( + deprecated( + "create random object of '' bytes", + "create random of '' bytes", + function(n) _create_random('random_object', n) end + ) ) -When("randomize '' array", function(arr) - local A = have(arr) - -- ZEN.assert(ZEN.CODEC[arr].zentype == 'a', "Object is not an array: "..arr) - ACK[arr] = shuffle_array_f(A) -end) +-- array shuffle + +-- Fisher-Yates algorithm +local function shuffle_array_f(arr) + local tab, c_tab = have(arr) + if (c_tab.zentype ~= 'a' and (not c_tab.schema or not isarray(tab))) then + error("Object to be randomized is not an array: "..arr, 2) + end + local tab_len = #tab + local res = { } + for i = tab_len,2,-1 do + local r = _random_modulo_uniform_distribution(i) + r = (r % i) + 1 + tab[i], tab[r] = tab[r], tab[i] + end +end -local function _create_random_array(array_length, fun_input, fun) +When("randomize '' array", shuffle_array_f) + +-- random array + +local function _create_random_array(array_length, fun_input, fun, codec) empty 'array' - ACK.array = { } local length = tonumber(mayhave(array_length) or array_length) zencode_assert(length, "Argument is not a number: "..array_length) + ACK.array 
= {} for i = length,1,-1 do table.insert(ACK.array, fun(fun_input)) end + local n_codec = {zentype = 'a'} + if codec then + for k, v in pairs(codec) do + n_codec[k] = v + end + end + new_codec('array', n_codec) end -When("create array of '' random objects", function(s) +When("create array of '' random", function(s) _create_random_array(s, 64, OCTET.random) - new_codec('array', {zentype = 'a'}) end) - -When("create array of '' random objects of '' bits", function(s, b) - local bits = tonumber(mayhave(b) or b) - zencode_assert(bits, "Argument is not a number: "..b) - local bytes = math.ceil(bits/8) - _create_random_array(s, bytes, OCTET.random) - new_codec('array', {zentype = 'a'}) +When("create array of '' random of '' bits", function(s, b) + _create_random_array(s, _get_bytes_from_bits(b), OCTET.random) end) - -When("create array of '' random objects of '' bytes", function(s, b) - local n_bytes = tonumber(mayhave(b) or b) - zencode_assert(n_bytes, "Argument is not a number: "..b) - local bytes = math.ceil(n_bytes) - _create_random_array(s, bytes, OCTET.random) - new_codec('array', {zentype = 'a'}) +When("create array of '' random of '' bytes", function(s, b) + _create_random_array(s, _get_bytes(b), OCTET.random) end) +When( + deprecated( + "create array of '' random objects", + "create array of '' random", + function(s) + _create_random_array(s, 64, OCTET.random) + end + ) +) +When( + deprecated( + "create array of '' random objects of '' bits", + "create array of '' random of '' bits", + function(s, b) + _create_random_array(s, _get_bytes_from_bits(b), OCTET.random) + end + ) +) +When( + deprecated( + "create array of '' random objects of '' bytes", + "create array of '' random of '' bytes", + function(s, b) + _create_random_array(s, _get_bytes(b), OCTET.random) + end + ) +) + When("create array of '' random numbers", function(s) - _create_random_array(s, null, BIG.random) - new_codec('array', {zentype = 'a', encoding = 'integer' }) + _create_random_array(s, null, 
BIG.random, {encoding = 'integer'}) end) - +local random_generator = { + ['zenroom.big'] = { + fun = function(input_modulo) return _random_modulo_uniform_distribution(input_modulo, ECP.order(), BIG.random) end, + enc = {encoding = 'integer'} + }, + ['zenroom.float'] = { + fun = function(input_modulo) return F.new(_random_modulo_uniform_distribution(tonumber(input_modulo))) end, + enc = {encoding = 'float'} + } +} When("create array of '' random numbers modulo ''", function(s,m) local modulo = mayhave(m) if not modulo then @@ -120,69 +200,90 @@ When("create array of '' random numbers modulo ''", function(s,m) zencode_assert(mod, "Argument is not a number: "..m) modulo = BIG.new(mod) end - local fun - local enc local modulo_type = type(modulo) - if modulo_type == "zenroom.big" then - fun = function(input) return BIG.random() % input end - enc = 'integer' - elseif modulo_type == "zenroom.float" then - fun = function(input) return F.new(math.floor(random_int16() % tonumber(input))) end - enc = 'float' - else + local random_gen = random_generator[modulo_type] + if not random_gen then error("Modulo is not a number nor an integer: "..modulo_type) end - _create_random_array(s, modulo, fun) - new_codec('array', {zentype = 'a', encoding = enc }) + _create_random_array(s, modulo, random_gen.fun, random_gen.enc) end) +-- pick random element + +-- reservoir sampling algorithm +local function _extract_random_elements(dest, num, from) + empty(dest) + local n = tonumber(num) or tonumber(tostring(have(num))) + zencode_assert(n and n>0, "Not a number or not a positive number: "..num) + local src, src_codec = have(from) + zencode_assert(luatype(src) == 'table', "Object is not a table: "..from) + local is_array = isarray(src) -local function _extract_random_elements(num, from, random_fun) - local n = tonumber(num) or tonumber(tostring(have(num))) - zencode_assert(n and n>=0, "Not a number or not a positive number: "..num) - local src = have(from) - zencode_assert(luatype(src) == 
'table', "Object is not a table: "..from) - - local tmp = { } - local keys = { } - for k,v in pairs(src) do - table.insert(keys, k) - table.insert(tmp, v) - end - - local len = #tmp - local max_len = 65536 - zencode_assert(len < max_len, "The number of elements of "..from.." exceed the maximum length: "..max_len) - zencode_assert(n <= len, num.." is grater than the number of elements in "..from) - local max_random = math.floor(max_len/len)*len - - local dst = { } - while(n ~= 0) do - local r = random_fun() - while r >= max_random do - r = random_fun() - end - r = (r % len) +1 - if keys[r] ~= nil then - if tonumber(keys[r]) then - table.insert(dst ,tmp[r]) - else - dst[keys[r]] = tmp[r] - end - keys[r] = nil - tmp[r] = nil - n = n - 1 - end - end - return dst + local keys = {} + local values = {} + for k,v in pairs(src) do + table.insert(keys, k) + table.insert(values, v) + end + + local len = #keys + local max_len = 65536 + zencode_assert(len < max_len, "The number of elements of "..from.." exceed the maximum length: "..max_len) + zencode_assert(n <= len, num.." 
is grater than the number of elements in "..from) + + local dst = {} + for i = 1, n do + if is_array then + dst[i] = values[i] + else + dst[keys[i]] = values[i] + end + end + + for i = n+1, len do + local r = random_int16() + if r % i < n then + local replace_index = (r % n) + 1 + if is_array then + dst[replace_index] = values[i] + else + dst[keys[replace_index]] = nil + dst[keys[i]] = values[i] + end + end + end + + local n_codec = {encoding = src_codec.encoding} + if (n == 1) then + n_codec.name, ACK[dest] = next(dst) + else + ACK[dest] = dst + end + new_codec(dest, n_codec) end -When("pick random object in ''", function(from) - key, ACK.random_object = next(_extract_random_elements(1, from, random_int16)) - new_codec('random_object', {name=key, encoding = CODEC[from].encoding}) +When("create random pick from ''", function(from) + _extract_random_elements('random_pick', 1, from) end) - -When("create random dictionary with '' random objects from ''", function(num, from) - ACK.random_dictionary = _extract_random_elements(num, from, random_int16) - new_codec('random_dictionary', {encoding = CODEC[from].encoding}) +When("create random table with '' random pick from ''", function(num, from) + _extract_random_elements('random_table', num, from) end) + +When( + deprecated( + "pick random object in ''", + "create random pick from ''", + function(from) + _extract_random_elements('random_object', 1, from) + end + ) +) +When( + deprecated( + "create random dictionary with '' random objects from ''", + "create random table with '' random pick from ''", + function(num, from) + _extract_random_elements('random_dictionary', num, from) + end + ) +) diff --git a/test/zencode/array.bats b/test/zencode/array.bats index 003811777..1b16e11bb 100755 --- a/test/zencode/array.bats +++ b/test/zencode/array.bats @@ -6,7 +6,7 @@ SUBDOC=array cat <