Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mhmd-azeez committed Feb 6, 2025
1 parent bee95a4 commit 579caa0
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 67 deletions.
1 change: 0 additions & 1 deletion crates/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ fn main() -> Result<()> {

// Create a tmp dir to hold all the library objects
// This can go away once we do all the wasm-merge stuff in process
// use a specfic path instead: /tmp/wizer
let tmp_dir = TempDir::new()?;
let core_path = tmp_dir.path().join("core.wasm");
let shim_path = tmp_dir.path().join("shim.wasm");
Expand Down
108 changes: 45 additions & 63 deletions crates/cli/src/shims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ use anyhow::Result;
use std::path::Path;
use wagen::{BlockType, Instr, ValType};

#[derive(PartialEq)]
enum TypeCode {
Void = 0,
I32 = 1,
I64 = 2,
F32 = 3,
F64 = 4,
}

pub fn generate_wasm_shims(
path: impl AsRef<Path>,
exports: &Interface,
Expand All @@ -11,40 +20,16 @@ pub fn generate_wasm_shims(
let mut module = wagen::Module::new();

// Core imports for argument handling
let __arg_start = module.import("core", "__arg_start", None, vec![], vec![]);
let __arg_i32 = module.import("core", "__arg_i32", None, vec![ValType::I32], vec![]);
let __arg_i64 = module.import("core", "__arg_i64", None, vec![ValType::I64], vec![]);
let __arg_f32 = module.import("core", "__arg_f32", None, vec![ValType::F32], vec![]);
let __arg_f64 = module.import("core", "__arg_f64", None, vec![ValType::F64], vec![]);
let __invoke_i32 = module.import(
"core",
"__invoke_i32",
None,
vec![ValType::I32],
vec![ValType::I32],
);
let __invoke_i64 = module.import(
"core",
"__invoke_i64",
None,
vec![ValType::I32],
vec![ValType::I64],
);
let __invoke_f32 = module.import(
"core",
"__invoke_f32",
None,
vec![ValType::I32],
vec![ValType::F32],
);
let __invoke_f64 = module.import(
"core",
"__invoke_f64",
None,
vec![ValType::I32],
vec![ValType::F64],
);
let __invoke = module.import("core", "__invoke", None, vec![ValType::I32], vec![]);
let __arg_start = module.import("core", "__arg_start", None, [], []);
let __arg_i32 = module.import("core", "__arg_i32", None, [ValType::I32], []);
let __arg_i64 = module.import("core", "__arg_i64", None, [ValType::I64], []);
let __arg_f32 = module.import("core", "__arg_f32", None, [ValType::F32], []);
let __arg_f64 = module.import("core", "__arg_f64", None, [ValType::F64], []);
let __invoke_i32 = module.import("core", "__invoke_i32", None, [ValType::I32], [ValType::I32]);
let __invoke_i64 = module.import("core", "__invoke_i64", None, [ValType::I32], [ValType::I64]);
let __invoke_f32 = module.import("core", "__invoke_f32", None, [ValType::I32], [ValType::F32]);
let __invoke_f64 = module.import("core", "__invoke_f64", None, [ValType::I32], [ValType::F64]);
let __invoke = module.import("core", "__invoke", None, [ValType::I32], []);

// Create import functions vector
let mut import_elements = Vec::new();
Expand Down Expand Up @@ -74,36 +59,33 @@ pub fn generate_wasm_shims(
});

// Add a new function that returns the type code for a function index
let mut type_getter_builder = wagen::Builder::default();
let mut return_type_getter_builder = wagen::Builder::default();

for (func_idx, (_name, _index, _params, results)) in import_funcs.iter().enumerate() {
let type_code = match results.first() {
Some(ValType::I32) => 1,
Some(ValType::I64) => 2,
Some(ValType::F32) => 3,
Some(ValType::F64) => 4,
None => 0,
_ => 2, // Default to I64
};

if type_code == 0 {
// don't emit anything for void functions
let type_code = results.first().map_or(TypeCode::Void, |val_type| match val_type {
ValType::I32 => TypeCode::I32,
ValType::I64 => TypeCode::I64,
ValType::F32 => TypeCode::F32,
ValType::F64 => TypeCode::F64,
_ => TypeCode::Void,
});

if type_code == TypeCode::Void {
continue;
}

// Compare the input function index with the current index.
type_getter_builder.push(Instr::LocalGet(0));
type_getter_builder.push(Instr::I32Const(func_idx as i32));
type_getter_builder.push(Instr::I32Eq);
// Declare the if block with empty result type.
type_getter_builder.push(Instr::If(BlockType::Empty));
type_getter_builder.push(Instr::I32Const(type_code));
type_getter_builder.push(Instr::Return);
type_getter_builder.push(Instr::End);
return_type_getter_builder.push(Instr::LocalGet(0)); // load requested function index
return_type_getter_builder.push(Instr::I32Const(func_idx as i32)); // load func_idx
return_type_getter_builder.push(Instr::I32Eq); // compare
return_type_getter_builder.push(Instr::If(BlockType::Empty)); // if true
return_type_getter_builder.push(Instr::I32Const(type_code as i32)); // load type code
return_type_getter_builder.push(Instr::Return); // early return if match
return_type_getter_builder.push(Instr::End);
}

type_getter_builder.push(Instr::I32Const(0)); // Default to 0
type_getter_builder.push(Instr::Return);
return_type_getter_builder.push(Instr::I32Const(0)); // Default to 0
return_type_getter_builder.push(Instr::Return);

let return_type_getter = module.func(
"__get_function_return_type",
Expand All @@ -112,19 +94,19 @@ pub fn generate_wasm_shims(
vec![],
);
return_type_getter.export("__get_function_return_type");
return_type_getter.body = type_getter_builder;
return_type_getter.body = return_type_getter_builder;

let mut arg_type_getter_builder = wagen::Builder::default();

for (func_idx, (_name, _index, params, _results)) in import_funcs.iter().enumerate() {
// For each function
for arg_idx in 0..params.len() {
let type_code = match params[arg_idx] {
ValType::I32 => 1,
ValType::I64 => 2,
ValType::F32 => 3,
ValType::F64 => 4,
_ => 0, // Default/unknown type
ValType::I32 => TypeCode::I32,
ValType::I64 => TypeCode::I64,
ValType::F32 => TypeCode::F32,
ValType::F64 => TypeCode::F64,
_ => panic!("Unsupported argument type for function {} at index {}", func_idx, arg_idx),
};

// Compare both function index and argument index
Expand All @@ -140,7 +122,7 @@ pub fn generate_wasm_shims(

// If both match, return the type code
arg_type_getter_builder.push(Instr::If(BlockType::Empty));
arg_type_getter_builder.push(Instr::I32Const(type_code));
arg_type_getter_builder.push(Instr::I32Const(type_code as i32));
arg_type_getter_builder.push(Instr::Return);
arg_type_getter_builder.push(Instr::End);
}
Expand Down
16 changes: 13 additions & 3 deletions crates/core/src/globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ fn build_module_object(this: Ctx) -> anyhow::Result<Object> {
Ok(module)
}


#[derive(PartialEq)]
enum TypeCode {
Void = 0,
I32 = 1,
I64 = 2,
F32 = 3,
F64 = 4,
}

fn build_host_object<'js>(this: Ctx<'js>) -> anyhow::Result<Object<'js>> {
let host_input_bytes = Function::new(
this.clone(),
Expand Down Expand Up @@ -231,12 +241,12 @@ fn add_host_functions<'a>(this: Ctx<'a>) -> anyhow::Result<()> {

let return_type = unsafe { __get_function_return_type(func_id) };
Ok(match return_type {
0 => Undefined.into_value(cx.clone()),
TYPE_VOID => Undefined.into_value(cx.clone()),
TYPE_I32 => Value::new_float(cx, (result & 0xFFFFFFFF) as i32 as f64),
TYPE_I64 => Value::new_float(cx, result as f64),
TYPE_F32 => Value::new_float(cx, f32::from_bits(result as u32) as f64),
TYPE_F64 => Value::new_float(cx, f64::from_bits(result)),
_ => Value::new_float(cx, result as f64)
_ => panic!("Unsupported return type: {:?}", return_type)
})
})?;

Expand Down Expand Up @@ -745,7 +755,7 @@ fn encode_js_string_to_utf8_buffer<'js>(
})
}

// Add after imports
const TYPE_VOID: u32 = 0;
const TYPE_I32: u32 = 1;
const TYPE_I64: u32 = 2;
const TYPE_F32: u32 = 3;
Expand Down

0 comments on commit 579caa0

Please sign in to comment.