Skip to content

Commit

Permalink
ZLUDA 3.8.4 (#46)
Browse files Browse the repository at this point in the history
* Restore cublas argument. (injector)

* Implement some Dark API functions (#41)

* Implement some Dark API functions

* Better error handling

* Implement mul24.lo.

* Implement mul24.hi.

* Fix mul24.lo implementation.

* Make mul24 tests more thorough.

* Add ZLUDA_COMGR_LOG_LEVEL.

* Bring back the minimal implementations of runtime API. (#45)

* [Fix] Handle stream correctly.

* WIP

* Fix fatbin.

* Revert.

* wip

* Remove redundant functions.

* Bump version.

---------

Co-authored-by: SEt <[email protected]>
  • Loading branch information
lshqqytiger and SEt-t authored Sep 11, 2024
1 parent 1c238a9 commit c0804ca
Show file tree
Hide file tree
Showing 22 changed files with 6,620 additions and 64 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ members = [
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

Expand Down
15 changes: 13 additions & 2 deletions comgr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use hip_common::CompilationMode;
use itertools::Either;
use std::{
borrow::Borrow,
env,
ffi::{CStr, CString},
iter, mem, ptr,
rc::Rc,
Expand Down Expand Up @@ -39,7 +40,7 @@ macro_rules! call {

pub type Result<T> = std::result::Result<T, sys::amd_comgr_status_t>;

pub struct Comgr(LibComgr, AtomicU64);
pub struct Comgr(LibComgr, AtomicU64, u32);

static WAVE32_MODULE: &'static [u8] = include_bytes!("wave32.ll");
static WAVE32_ON_WAVE64_MODULE: &'static [u8] = include_bytes!("wave32_on_wave64.ll");
Expand All @@ -53,7 +54,14 @@ static OS_MODULE: &'static [u8] = include_bytes!("linux.ll");
impl Comgr {
pub fn find_and_load() -> Result<Self> {
match unsafe { Self::load_library() } {
Ok(libcomgr) => Ok(Self(libcomgr, AtomicU64::new(1))),
Ok(libcomgr) => Ok(Self(
libcomgr,
AtomicU64::new(1),
env::var("ZLUDA_COMGR_LOG_LEVEL")
.unwrap_or("0".into())
.parse()
.expect("Unexpected value for ZLUDA_COMGR_LOG_LEVEL."),
)),
Err(_) => Err(sys::amd_comgr_status_t::AMD_COMGR_STATUS_ERROR),
}
}
Expand Down Expand Up @@ -291,6 +299,9 @@ impl Comgr {
CStr::from_bytes_with_nul_unchecked(b"-mwavefrontsize64\0")
}
};
if self.2 == 1 {
eprintln!("Compiling in progress. Please wait...");
}
let relocatable = self.do_action(
sys::amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE,
bc_linking_output,
Expand Down
21 changes: 1 addition & 20 deletions hip_runtime-sys/src/hip_runtime_api_v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7495,17 +7495,9 @@ extern "C" {
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinary(
data: *mut ::std::os::raw::c_void,
data: *const ::std::os::raw::c_void,
) -> *mut *mut ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinaryEnd(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterFunction(
Expand All @@ -7521,17 +7513,6 @@ extern "C" {
wSize: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterHostVar(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
deviceName: *const ::std::os::raw::c_char,
hostVar: *mut ::std::os::raw::c_char,
size: usize,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterManagedVar(
Expand Down
Binary file modified ptx/lib/zluda_ptx_impl.bc
Binary file not shown.
32 changes: 29 additions & 3 deletions ptx/lib/zluda_ptx_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,9 +1456,11 @@ extern "C"
}

// Keep scanning until we figure out the length of this specifier or if we reach the end of the string
while (*s != 0) {
while (*s != 0)
{
// "The width is not specified in the format string, but as an additional integer value argument preceding the argument that has to be formatted."
if (*s == '*') {
if (*s == '*')
{
s++;
uint64_t value = read_valist(valist_ptr, valist_offset, 4);
handle = __ockl_printf_append_args(handle, 1, value, 0, 0, 0, 0, 0, 0, 0);
Expand All @@ -1476,7 +1478,9 @@ extern "C"
if (specifier_with_length)
{
s += specifier_with_length;
} else {
}
else
{
// Assume the unknown character is a sub-specifier and move on
s++;
continue;
Expand All @@ -1501,6 +1505,17 @@ extern "C"
return __ockl_mul_hi_i64(x, y);
}

int32_t FUNC(mul24_hi_s32)(int32_t x, int32_t y)
{
return static_cast<int32_t>(((static_cast<int64_t>(x) * y) & 0x0000FFFFFFFFFFFF) >> 16);
}

int32_t __ockl_mul24_i32(int32_t x, int32_t y) __attribute__((device));
int32_t FUNC(mul24_lo_s32)(int32_t x, int32_t y)
{
return __ockl_mul24_i32(x, y);
}

int64_t FUNC(mad_hi_s64)(int64_t a, int64_t b, int64_t c)
{
int64_t temp = FUNC_CALL(mul_hi_s64)(a, b);
Expand All @@ -1513,6 +1528,17 @@ extern "C"
return __ockl_mul_hi_u64(x, y);
}

uint32_t FUNC(mul24_hi_u32)(uint32_t x, uint32_t y)
{
return static_cast<uint32_t>(((static_cast<uint64_t>(x) * y) & 0x0000FFFFFFFFFFFF) >> 16);
}

uint32_t __ockl_mul24_u32(uint32_t x, uint32_t y) __attribute__((device));
uint32_t FUNC(mul24_lo_u32)(uint32_t x, uint32_t y)
{
return __ockl_mul24_u32(x, y);
}

uint64_t FUNC(mad_hi_u64)(uint64_t a, uint64_t b, uint64_t c)
{
uint64_t temp = FUNC_CALL(mul_hi_u64)(a, b);
Expand Down
25 changes: 25 additions & 0 deletions ptx/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ pub enum Instruction<P: ArgParams> {
Ld(LdDetails, Arg2Ld<P>),
Mov(MovDetails, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>),
Mul24(Mul24Details, Arg3<P>),
Add(ArithDetails, Arg3<P>),
AddC(CarryInDetails, Arg3<P>),
AddCC(ScalarType, Arg3<P>),
Expand Down Expand Up @@ -830,13 +831,25 @@ pub struct MulIntDesc {
pub control: MulIntControl,
}

#[derive(Copy, Clone)]
pub struct Mul24IntDesc {
pub typ: ScalarType,
pub control: Mul24IntControl,
}

#[derive(Copy, Clone, PartialEq, Eq)]
pub enum MulIntControl {
Low,
High,
Wide,
}

#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Mul24IntControl {
Low,
High,
}

#[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode {
NearestEven,
Expand Down Expand Up @@ -1027,6 +1040,18 @@ pub struct MulInt {
pub control: MulIntControl,
}

#[derive(Copy, Clone)]
pub enum Mul24Details {
Unsigned(Mul24Int),
Signed(Mul24Int),
}

#[derive(Copy, Clone)]
pub struct Mul24Int {
pub typ: ScalarType,
pub control: Mul24IntControl,
}

#[derive(Copy, Clone)]
pub enum ArithDetails {
Unsigned(ScalarType),
Expand Down
37 changes: 31 additions & 6 deletions ptx/src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ fn emit_instruction(
ast::Instruction::Ld(details, args) => emit_inst_ld(ctx, details, args)?,
ast::Instruction::Mov(details, args) => emit_inst_mov(ctx, details, args)?,
ast::Instruction::Mul(details, args) => emit_inst_mul(ctx, details, args)?,
ast::Instruction::Mul24(details, args) => emit_inst_mul24(ctx, details, args)?,
ast::Instruction::Add(details, args) => emit_inst_add(ctx, details, args)?,
ast::Instruction::Setp(details, args) => emit_inst_setp(ctx, details, args, None)?,
ast::Instruction::SetpBool(details, args) => emit_inst_setp_bool(ctx, details, args)?,
Expand Down Expand Up @@ -1730,12 +1731,8 @@ fn emit_inst_sqrt(
(ast::ScalarType::F64, ast::RcpSqrtKind::Approx) => {
(&b"llvm.sqrt.f64\0"[..], FastMathFlags::ApproxFunc)
}
(ast::ScalarType::F32, _) => {
(&b"llvm.sqrt.f32\0"[..], FastMathFlags::empty())
},
(ast::ScalarType::F64, _) => {
(&b"llvm.sqrt.f64\0"[..], FastMathFlags::empty())
},
(ast::ScalarType::F32, _) => (&b"llvm.sqrt.f32\0"[..], FastMathFlags::empty()),
(ast::ScalarType::F64, _) => (&b"llvm.sqrt.f64\0"[..], FastMathFlags::empty()),
_ => return Err(TranslateError::unreachable()),
};
let sqrt_result = emit_intrinsic_arg2(
Expand Down Expand Up @@ -2836,6 +2833,34 @@ fn emit_inst_mul(
}
}

fn emit_inst_mul24(
ctx: &mut EmitContext,
details: &ast::Mul24Details,
args: &ast::Arg3<crate::translate::ExpandedArgParams>,
) -> Result<(), TranslateError> {
match details {
ast::Mul24Details::Unsigned(ast::Mul24Int {
control: ast::Mul24IntControl::Low,
..
})
| ast::Mul24Details::Signed(ast::Mul24Int {
control: ast::Mul24IntControl::Low,
..
}) => emit_inst_mul_lo(ctx, args, LLVMBuildMul),
ast::Mul24Details::Unsigned(ast::Mul24Int {
control: ast::Mul24IntControl::High,
typ,
})
| ast::Mul24Details::Signed(ast::Mul24Int {
control: ast::Mul24IntControl::High,
typ,
}) => {
emit_inst_mul_hi(ctx, *typ, args)?;
Ok(())
}
}
}

fn emit_inst_mul_hi(
ctx: &mut EmitContext,
type_: ast::ScalarType,
Expand Down
24 changes: 24 additions & 0 deletions ptx/src/ptx.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ match {
"min",
"mov",
"mul",
"mul24",
"nanosleep",
"neg",
"not",
Expand Down Expand Up @@ -299,6 +300,7 @@ ExtendedID : &'input str = {
"min",
"mov",
"mul",
"mul24",
"nanosleep",
"neg",
"not",
Expand Down Expand Up @@ -782,6 +784,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
InstMov,
InstMul,
InstMul24,
InstAdd,
InstAddC,
InstAddCC,
Expand Down Expand Up @@ -993,6 +996,27 @@ MulIntControl: ast::MulIntControl = {
".wide" => ast::MulIntControl::Wide
};

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul24
InstMul24: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul24" <d:Mul24Details> <a:Arg3> => ast::Instruction::Mul24(d, a)
};

Mul24Details: ast::Mul24Details = {
<ctr:Mul24IntControl> <t:UIntType> => ast::Mul24Details::Unsigned(ast::Mul24Int{
typ: t,
control: ctr
}),
<ctr:Mul24IntControl> <t:SIntType> => ast::Mul24Details::Signed(ast::Mul24Int{
typ: t,
control: ctr
})
};

Mul24IntControl: ast::Mul24IntControl = {
".hi" => ast::Mul24IntControl::High,
".lo" => ast::Mul24IntControl::Low
};

#[inline]
RoundingModeFloat : ast::RoundingMode = {
".rn" => ast::RoundingMode::NearestEven,
Expand Down
2 changes: 2 additions & 0 deletions ptx/src/test/spirv_run/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ test_ptx!(ld_st_implicit, [0.5f32, 0.25f32], [0.5f32]);
test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
test_ptx!(mul24_lo, [0xEA129Bu32], [0xBAF20C63u32]);
test_ptx!(mul24_hi, [0xEA129Bu32], [0x88F1BAF2u32]);
test_ptx!(add, [1u64], [2u64]);
test_ptx!(add_global, [1f32], [0x408487EEu32]);
test_ptx!(amdgpu_unnamed, [2u64], [3u64]);
Expand Down
35 changes: 35 additions & 0 deletions ptx/src/test/spirv_run/mul24_hi.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"

declare i32 @__zluda_ptx_impl__mul24_hi_u32(i32, i32) #0

define protected amdgpu_kernel void @mul24_hi(ptr addrspace(4) byref(i64) %"18", ptr addrspace(4) byref(i64) %"19") #1 {
%"8" = alloca i1, align 1, addrspace(5)
%"4" = alloca i64, align 8, addrspace(5)
%"5" = alloca i64, align 8, addrspace(5)
%"6" = alloca i32, align 4, addrspace(5)
%"7" = alloca i32, align 4, addrspace(5)
br label %1

1: ; preds = %0
store i1 false, ptr addrspace(5) %"8", align 1
%"9" = load i64, ptr addrspace(4) %"18", align 8
store i64 %"9", ptr addrspace(5) %"4", align 8
%"10" = load i64, ptr addrspace(4) %"19", align 8
store i64 %"10", ptr addrspace(5) %"5", align 8
%"12" = load i64, ptr addrspace(5) %"4", align 8
%"20" = inttoptr i64 %"12" to ptr
%"11" = load i32, ptr %"20", align 4
store i32 %"11", ptr addrspace(5) %"6", align 4
%"14" = load i32, ptr addrspace(5) %"6", align 4
%"13" = call i32 @__zluda_ptx_impl__mul24_hi_u32(i32 %"14", i32 9815513)
store i32 %"13", ptr addrspace(5) %"7", align 4
%"15" = load i64, ptr addrspace(5) %"5", align 8
%"16" = load i32, ptr addrspace(5) %"7", align 4
%"21" = inttoptr i64 %"15" to ptr
store i32 %"16", ptr %"21", align 4
ret void
}

attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "no-trapping-math"="true" "uniform-work-group-size"="true" }
attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
22 changes: 22 additions & 0 deletions ptx/src/test/spirv_run/mul24_hi.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64

.visible .entry mul24_hi(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 temp;
.reg .u32 temp2;

ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];

ld.u32 temp, [in_addr];
mul24.hi.u32 temp2, temp, 9815513;
st.u32 [out_addr], temp2;
ret;
}
Loading

0 comments on commit c0804ca

Please sign in to comment.