Skip to content

Commit

Permalink
Openvino/ep weight sharing (#548)
Browse files Browse the repository at this point in the history
* Rename EP instance context as session_context

* Add support for GetEpContextNodes

* enable config option for ovep weight sharing

* add config option for ovep weight sharing

* Refactor the conditional blocks in OVEP for compilation

* Convert initializers with external data to graph inputs

* create, store and export metadata for ovep weight sharing

* fix error handling in weight sharing

* fix crash issue while setting up inputs for wai model

* pass weight sharing option to OVEP qdq stripping pass

* Aligning OVEP variable names to match the session option value they hold

* Add plumbing for context sharing plus refactoring around option handling

* Store metadata in shared context

* fix: fix provider options

* create ov tensor from meta data and external data

* create ov tensor

* Add support for binding weight as input tensors

* Fix for mapping subgraph to ov compiled network arguments

* Fix for using so_share_ep_contexts without ep.context* flags

* Add remote tensor support for NPU weight sharing

* Use a single ov::Core copy across OVEP

* Decouple provider option cache_dir from session option ep.context_file_path

* Add support for serialization and deserialization of metadata to disk

* Load blobs from relative path stored in ep_cache_context

* Use remote L0 tensors for shared weights

* fix linux ci issues

* fix ci issues

* Fix Windows build failure

* Use ifstream to load weights instead of mmaped file

* Fix for epctx models made up entirely of OVEP epctx nodes

* Limit ov::Core lifetime to that of provider object

* Enforce shared tensors cleanup on shutdown

* Add support for default device type based on project configuration

* fix: Fixed concrete_backend_ pointer double free issue on Linux

* Preetha/weight sharing fix (#545)

* Move variables from subgraph to session context for model specific properties

* Fix for redundant subgraph creation

* Remove unused variable

---------

Co-authored-by: Javier E. Martinez <[email protected]>
Co-authored-by: saurabhkale117 <[email protected]>
Co-authored-by: Preetha Veeramalai <[email protected]>
Co-authored-by: ankitm3k <[email protected]>
Co-authored-by: Eric Crawford <[email protected]>
  • Loading branch information
6 people committed Jan 31, 2025
1 parent 37964db commit 84fd325
Show file tree
Hide file tree
Showing 23 changed files with 1,366 additions and 952 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_openvino.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
endif()

list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES})
if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}))
if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}) AND onnxruntime_USE_OPENVINO_GPU)
add_definitions(-DIO_BUFFER_ENABLED=1)
list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS})
endif()
Expand Down
247 changes: 132 additions & 115 deletions onnxruntime/core/providers/openvino/backend_manager.cc

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions onnxruntime/core/providers/openvino/backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ namespace openvino_ep {
// Singleton class that manages all the backends
class BackendManager {
public:
BackendManager(const GlobalContext& global_context,
BackendManager(SessionContext& session_context,
SharedContext& shared_context,
const onnxruntime::Node& fused_node,
const onnxruntime::GraphViewer& subgraph,
const logging::Logger& logger,
EPCtxHandler& ctx_handle);
void Compute(OrtKernelContext* context);
void ShutdownBackendManager();
void SetGlobalCotext(const GlobalContext& global_context);
GlobalContext& GetGlobalContext();
Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph,
const logging::Logger& logger);
SessionContext& GetSessionContext();
Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
ov::CompiledModel& GetOVCompiledModel();

private:
Expand All @@ -52,9 +51,9 @@ class BackendManager {
std::shared_ptr<IBackend> concrete_backend_;
std::map<std::string, std::shared_ptr<IBackend>> backend_map_;
SubGraphContext subgraph_context_;
GlobalContext global_context_;
EPCtxHandler ep_ctx_handle_{};
std::string openvino_sdk_version_{};
EPCtxHandler& ep_ctx_handle_;
SessionContext& session_context_;
SharedContext& shared_context_;
};

} // namespace openvino_ep
Expand Down
187 changes: 170 additions & 17 deletions onnxruntime/core/providers/openvino/backend_utils.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,107 @@
// Copyright (C) Intel Corporation
// Licensed under the MIT License

#include <algorithm>
#include <sstream>
#include <fstream>
#include <utility>

#include <filesystem>
#include <stdexcept>

#include "openvino/pass/convert_fp32_to_fp16.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp"
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/openvino/backend_utils.h"
#include "core/providers/openvino/ov_interface.h"


using Exception = ov::Exception;

namespace onnxruntime {
namespace openvino_ep {

SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
try {
file_.exceptions(std::ifstream::failbit | std::ifstream::badbit);
weights_size_ = file_.seekg(0, std::ios::end).tellg();
} catch (std::ifstream::failure& e) {
ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what());
}
}

void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) {
ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds.");
file_.seekg(file_offset);
file_.read(reinterpret_cast<char*>(data), size);
}

std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
try {
stream << metadata.size();

// Write each key-value pair
// Put elements in separate lines to facilitate reading
for (const auto& [key, value] : metadata) {
stream << std::endl
<< key.name;
stream << std::endl
<< value.location;
stream << std::endl
<< value.data_offset;
stream << std::endl
<< value.size;
stream << std::endl
<< value.dimensions.size();
for (const auto& dim : value.dimensions) {
stream << std::endl
<< dim;
}
stream << std::endl
<< value.element_type;
}
} catch (const Exception& e) {
ORT_THROW("Error: Failed to write map data.", e.what());
} catch (...) {
ORT_THROW("Error: Failed to write map data.");
}

ORT_ENFORCE(stream.good(), "Error: Failed to write map data.");
return stream;
}

std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) {
size_t map_size{0};
try {
stream >> map_size;

while (!stream.eof()) {
SharedContext::SharedWeights::Metadata::Key key;
SharedContext::SharedWeights::Metadata::Value value;
stream >> key.name;
stream >> value.location;
stream >> value.data_offset;
stream >> value.size;
size_t num_dimensions;
stream >> num_dimensions;
value.dimensions.resize(num_dimensions);
for (auto& dim : value.dimensions) {
stream >> dim;
}
stream >> value.element_type;
metadata.emplace(key, value);
}
} catch (const Exception& e) {
ORT_THROW("Error: Failed to read map data.", e.what());
} catch (...) {
ORT_THROW("Error: Failed to read map data.");
}

ORT_ENFORCE(metadata.size() == map_size, "Error: Inconsistent map data.");

return stream;
}

namespace backend_utils {

bool IsDebugEnabled() {
Expand All @@ -34,23 +120,18 @@ bool IsCILogEnabled() {
return false;
}

struct static_cast_int64 {
template <typename T1> // T1 models type statically convertible to T
int64_t operator()(const T1& x) const { return static_cast<int64_t>(x); }
};

std::shared_ptr<const OVNetwork>
CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context,
CreateOVModel(const std::string model,
const SessionContext& session_context,
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
if (IsCILogEnabled()) {
std::cout << "CreateNgraphFunc" << std::endl;
}
const std::string model = model_proto.SerializeAsString();
try {
auto ov_model = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name);
auto ov_model = OVCore::ReadModel(model, session_context.onnx_model_path_name.string());

// Check for Constant Folding
if ((global_context.device_type != "NPU") && !global_context.is_wholly_supported_graph) {
if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) {
ov::pass::ConstantFolding pass_const_obj;
pass_const_obj.run_on_model(ov_model);
auto& results = const_cast<ov::ResultVector&>(ov_model.get()->get_results());
Expand Down Expand Up @@ -82,7 +163,7 @@ Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
OVInferRequestPtr infer_request,
std::string output_name,
std::unordered_map<std::string, int> output_names) {
const SubGraphContext::string_index_map_t& output_names) {
auto graph_output_blob = infer_request->GetTensor(output_name);

auto graph_output_dims = graph_output_blob->get_shape();
Expand All @@ -107,7 +188,7 @@ GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context,
std::string output_name,
std::unordered_map<std::string, int> output_names,
const SubGraphContext::string_index_map_t& output_names,
std::shared_ptr<ov::Node> node) {
// Find position of '/' in the output_name
int pos = output_name.find("/");
Expand All @@ -129,13 +210,13 @@ GetOutputTensor(Ort::KernelContext& context,
return context.GetOutput(index, output_shape.get(), num_dims);
}

int GetFirstAvailableDevice(GlobalContext& global_context) {
int GetFirstAvailableDevice(SessionContext& session_context) {
int i = 0;
// Get the first available VAD-M device and set the device to busy
while (i < 8) {
bool device = global_context.deviceAvailableList[i];
bool device = session_context.deviceAvailableList[i];
if (device) {
global_context.deviceAvailableList[i] = false;
session_context.deviceAvailableList[i] = false;
break;
}
i++;
Expand All @@ -144,9 +225,9 @@ int GetFirstAvailableDevice(GlobalContext& global_context) {
// make all remaining devices free
if (i == 8) {
i = 0;
global_context.deviceAvailableList[i] = false;
session_context.deviceAvailableList[i] = false;
for (int j = 1; j < 8; j++) {
global_context.deviceAvailableList[j] = true;
session_context.deviceAvailableList[j] = true;
}
}
return i;
Expand Down Expand Up @@ -267,6 +348,78 @@ void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std
printPerformanceCounts(performanceMap, stream, std::move(deviceName));
}

ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt) {
static std::unordered_map<ONNX_NAMESPACE::TensorProto_DataType, ov::element::Type> map{
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ov::element::f32},
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, ov::element::u8},
{ONNX_NAMESPACE::TensorProto_DataType_INT8, ov::element::i8},
{ONNX_NAMESPACE::TensorProto_DataType_UINT16, ov::element::u16},
{ONNX_NAMESPACE::TensorProto_DataType_INT16, ov::element::i16},
{ONNX_NAMESPACE::TensorProto_DataType_INT32, ov::element::i32},
{ONNX_NAMESPACE::TensorProto_DataType_INT64, ov::element::i64},
{ONNX_NAMESPACE::TensorProto_DataType_STRING, ov::element::string},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, ov::element::boolean},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, ov::element::f16},
{ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, ov::element::f64},
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, ov::element::u32},
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, ov::element::u64},
//{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64, ov::element::undefined},
//{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128, ov::element::undefined},
{ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, ov::element::bf16},
//{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN, ov::element::undefined},
//{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ, ov::element::undefined},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2, ov::element::f8e5m2},
//{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ, ov::element::undefined},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, ov::element::u4},
{ONNX_NAMESPACE::TensorProto_DataType_INT4, ov::element::i4},
};

if (auto result = map.find(dt); result != map.end()) {
return result->second;
} else {
throw std::runtime_error("Unsupported ONNX data type: " + std::to_string(dt));
}
}

// Function to handle tensor creation from external data
void CreateOVTensors(const std::string& device_name,
SharedContext::SharedWeights::Metadata::Map& metadata_map,
SharedContext::SharedWeights::WeightsFile &weights) {
for (auto& [key, value] : metadata_map) {
if (value.tensor) continue;

// Get element data type
auto onnx_element_type = (ONNX_NAMESPACE::TensorProto_DataType)value.element_type;

ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type); // Map to OpenVINO data type

// Create OpenVINO Tensor
if (device_name == "NPU") {
// Use remote tensors
auto npu_context = OVCore::Get().get_default_context("NPU").as<ov::intel_npu::level_zero::ZeroContext>();
auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT);

// Copy data to remote tensor
weights.load_weights(value.data_offset, remote_tensor.get(), value.size);
value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
} else {
// Use vanilla tensors
value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions);
weights.load_weights(value.data_offset, value.tensor->data(), value.size);
}
ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch");
}
}

void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) {
for (auto& [key, value] : metadata_map) {
if (value.tensor) {
value.tensor.reset();
}
}
metadata_map.clear();
}

} // namespace backend_utils
} // namespace openvino_ep
} // namespace onnxruntime
16 changes: 11 additions & 5 deletions onnxruntime/core/providers/openvino/backend_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <vector>
#include <string>
#include <string_view>

#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/openvino/contexts.h"
Expand All @@ -34,7 +35,7 @@ bool IsDebugEnabled();
// Internal diagnostic function.
bool IsCILogEnabled();

int GetFirstAvailableDevice(GlobalContext& global_context);
int GetFirstAvailableDevice(SessionContext& session_context);

void FillOutputsWithConstantData(std::shared_ptr<ov::Node> node, Ort::UnownedValue& out_tensor);

Expand All @@ -44,14 +45,14 @@ void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr<ov::Node> n
Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context,
std::string output_name,
std::unordered_map<std::string, int> output_names,
const SubGraphContext::string_index_map_t& output_names,
std::shared_ptr<ov::Node> node);

Ort::UnownedValue
GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
OVInferRequestPtr infer_request,
std::string output_name,
std::unordered_map<std::string, int> output_names);
const SubGraphContext::string_index_map_t& output_names);

void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
std::string input_name, Ort::KernelContext& context,
Expand All @@ -61,10 +62,15 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
size_t batch_slice_idx);

std::shared_ptr<const OVNetwork>
CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto,
const GlobalContext& global_context,
CreateOVModel(const std::string model,
const SessionContext& session_context,
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);

void CreateOVTensors(const std::string& device_name,
SharedContext::SharedWeights::Metadata::Map& metadata_map,
SharedContext::SharedWeights::WeightsFile& weights);
void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map);

void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
std::ostream& stream, std::string deviceName);

Expand Down
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/openvino/backends/backend_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ namespace openvino_ep {

std::shared_ptr<IBackend>
BackendFactory::MakeBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
GlobalContext& global_context,
SessionContext& session_context,
const SubGraphContext& subgraph_context,
EPCtxHandler& ep_ctx_handle) {
std::string type = global_context.device_type;
SharedContext& shared_context,
ptr_stream_t& model_stream) {
std::string type = session_context.device_type;
if (type == "CPU" || type.find("GPU") != std::string::npos ||
type.find("NPU") != std::string::npos ||
type.find("HETERO") != std::string::npos ||
type.find("MULTI") != std::string::npos ||
type.find("AUTO") != std::string::npos) {
std::shared_ptr<IBackend> concrete_backend_;
try {
concrete_backend_ = std::make_shared<BasicBackend>(model_proto, global_context, subgraph_context, ep_ctx_handle);
concrete_backend_ = std::make_shared<BasicBackend>(model_proto, session_context, subgraph_context, shared_context, model_stream);
} catch (std::string const& msg) {
ORT_THROW(msg);
}
Expand All @@ -32,5 +33,6 @@ BackendFactory::MakeBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_p
ORT_THROW("[OpenVINO-EP] Backend factory error: Unknown backend type: " + type);
}
}

} // namespace openvino_ep
} // namespace onnxruntime
Loading

0 comments on commit 84fd325

Please sign in to comment.