Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove virtual methods from ur_mem_handle_t_ #2620

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions source/adapters/level_zero/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,11 +1747,12 @@ ur_result_t urMemRelease(
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
return ze2urResult(ZeResult);
}
delete Image;
} else {
auto Buffer = reinterpret_cast<_ur_buffer *>(Mem);
Buffer->free();
delete Buffer;
}
delete Mem;

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -2081,10 +2082,11 @@ static ur_result_t ZeDeviceMemAllocHelper(void **ResultPtr,
return UR_RESULT_SUCCESS;
}

ur_result_t _ur_buffer::getZeHandle(char *&ZeHandle, access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t _ur_buffer::getBufferZeHandle(char *&ZeHandle,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {

// NOTE: There might be no valid allocation at all yet and we get
// here from piEnqueueKernelLaunch that would be doing the buffer
Expand Down Expand Up @@ -2393,7 +2395,7 @@ ur_result_t _ur_buffer::free() {
// Buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,
bool ImportedHostPtr = false)
: ur_mem_handle_t_(Context), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context), Size(Size) {

// We treat integrated devices (physical memory shared with the CPU)
// differently from discrete devices (those with distinct memories).
Expand Down Expand Up @@ -2422,13 +2424,13 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,

_ur_buffer::_ur_buffer(ur_context_handle_t Context, ur_device_handle_t Device,
size_t Size)
: ur_mem_handle_t_(Context, Device), Size(Size) {}
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {}

// Interop-buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
ur_device_handle_t Device, char *ZeMemHandle,
bool OwnZeMemHandle)
: ur_mem_handle_t_(Context, Device), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {

// Device == nullptr means host allocation
Allocations[Device].ZeHandle = ZeMemHandle;
Expand All @@ -2449,11 +2451,38 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
LastDeviceWithValidAllocation = Device;
}

ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t ur_mem_handle_t_::getZeHandle(char *&ZeHandle, access_mode_t mode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
}
ur::unreachable();
}

ur_result_t ur_mem_handle_t_::getZeHandlePtr(
char **&ZeHandlePtr, access_mode_t mode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
}
ur::unreachable();
}

ur_result_t _ur_buffer::getBufferZeHandlePtr(
char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
char *ZeHandle;
UR_CALL(
getZeHandle(ZeHandle, AccessMode, Device, phWaitEvents, numWaitEvents));
Expand Down
78 changes: 41 additions & 37 deletions source/adapters/level_zero/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,41 @@ struct ur_mem_handle_t_ : _ur_object {
// Keeps device of this memory handle
ur_device_handle_t UrDevice;

// Whether this is an image or buffer
enum mem_type_t { image, buffer };
mem_type_t mem_type;

// Enumerates all possible types of accesses.
enum access_mode_t { unknown, read_write, read_only, write_only };

// Interface of the _ur_mem object

// Get the Level Zero handle of the current memory object
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Get a pointer to the Level Zero handle of the current memory object
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Method to get type of the derived object (image or buffer)
virtual bool isImage() const = 0;

virtual ~ur_mem_handle_t_() = default;
bool isImage() const { return mem_type == mem_type_t::image; }

protected:
ur_mem_handle_t_(ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr} {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr}, mem_type(type) {}

ur_mem_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device) {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context,
ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device), mem_type(type) {}

// Since the destructor isn't virtual, callers must destruct it via _ur_buffer
// or _ur_image
~ur_mem_handle_t_() {};
};

struct _ur_buffer final : ur_mem_handle_t_ {
Expand All @@ -110,7 +117,7 @@ struct _ur_buffer final : ur_mem_handle_t_ {

// Sub-buffer constructor
_ur_buffer(_ur_buffer *Parent, size_t Origin, size_t Size)
: ur_mem_handle_t_(Parent->UrContext), Size(Size),
: ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size),
SubBuffer{{Parent, Origin}} {
// Retain the Parent Buffer due to the Creation of the SubBuffer.
Parent->RefCount.increment();
Expand All @@ -127,16 +134,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
// up-to-date and any data copies needed for that are performed under
// the hood.
//
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
ur_result_t getBufferZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);
ur_result_t getBufferZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

bool isImage() const override { return false; }
bool isSubBuffer() const { return SubBuffer != std::nullopt; }

// Frees all allocations made for the buffer.
Expand Down Expand Up @@ -206,35 +212,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
struct _ur_image final : ur_mem_handle_t_ {
// Image constructor
_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {}

_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
bool OwnZeMemHandle)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {
OwnNativeHandle = OwnZeMemHandle;
}

virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandle = reinterpret_cast<char *>(ZeImage);
return UR_RESULT_SUCCESS;
}
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandlePtr = reinterpret_cast<char **>(&ZeImage);
return UR_RESULT_SUCCESS;
}

bool isImage() const override { return true; }

// Keep the descriptor of the image
ZeStruct<ze_image_desc_t> ZeImageDesc;

Expand Down
Loading