diff --git a/c_api/include/taichi/taichi_cuda.h b/c_api/include/taichi/taichi_cuda.h index 136f6ce9e1f9e..a67b4f22e8d02 100644 --- a/c_api/include/taichi/taichi_cuda.h +++ b/c_api/include/taichi/taichi_cuda.h @@ -20,10 +20,17 @@ ti_export_cuda_memory(TiRuntime runtime, TiMemory memory, TiCudaMemoryInteropInfo *interop_info); +// Function `ti_import_cuda_memory` TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime, void *ptr, size_t memory_size); +// Function `ti_set_cuda_stream` +TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream); + +// Function `ti_get_cuda_stream` +TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/c_api/src/taichi_llvm_impl.cpp b/c_api/src/taichi_llvm_impl.cpp index 277ba50df70d7..cc36704f9bff1 100644 --- a/c_api/src/taichi_llvm_impl.cpp +++ b/c_api/src/taichi_llvm_impl.cpp @@ -14,6 +14,7 @@ #ifdef TI_WITH_CUDA #include "taichi/rhi/cuda/cuda_device.h" +#include "taichi/rhi/cuda/cuda_context.h" #include "taichi/runtime/cuda/kernel_launcher.h" #endif @@ -242,4 +243,24 @@ TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime, #endif } +// function.set_cuda_stream +TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream) { +#ifdef TI_WITH_CUDA + taichi::lang::CUDAContext::get_instance().set_stream(stream); + +#else + TI_NOT_IMPLEMENTED; +#endif +} + +// function.get_cuda_stream +TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream) { +#ifdef TI_WITH_CUDA + *stream = taichi::lang::CUDAContext::get_instance().get_stream(); +#else + TI_NOT_IMPLEMENTED; + +#endif +} + #endif // TI_WITH_LLVM diff --git a/c_api/tests/c_api_interop_test.cpp b/c_api/tests/c_api_interop_test.cpp index 73ae75c58bd31..16bae248dbb90 100644 --- a/c_api/tests/c_api_interop_test.cpp +++ b/c_api/tests/c_api_interop_test.cpp @@ -160,3 +160,23 @@ TEST_F(CapiTest, TestCUDAImport) { EXPECT_EQ(data_out[3], 4.0); } #endif // TI_WITH_CUDA + +#ifdef TI_WITH_CUDA +TEST_F(CapiTest, TestCUDAStreamSet) { + void *temp_stream = nullptr; + + ti_get_cuda_stream(&temp_stream); + EXPECT_EQ(temp_stream, nullptr); + + void *stream1 = reinterpret_cast(0x12345678); + void *stream2 = reinterpret_cast(0x87654321); + + ti_set_cuda_stream(stream1); + ti_get_cuda_stream(&temp_stream); + EXPECT_EQ(temp_stream, stream1); + + ti_set_cuda_stream(stream2); + ti_get_cuda_stream(&temp_stream); + EXPECT_EQ(temp_stream, stream2); +} +#endif diff --git a/taichi/rhi/cuda/cuda_context.cpp b/taichi/rhi/cuda/cuda_context.cpp index 587f737e935e4..a71833e89c18a 100644 --- a/taichi/rhi/cuda/cuda_context.cpp +++ b/taichi/rhi/cuda/cuda_context.cpp @@ -12,7 +12,9 @@ namespace taichi::lang { CUDAContext::CUDAContext() - : profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) { + : profiler_(nullptr), + driver_(CUDADriver::get_instance_without_context()), + stream_(nullptr) { // CUDA initialization dev_count_ = 0; driver_.init(0); @@ -156,14 +158,14 @@ void CUDAContext::launch(void *func, dynamic_shared_mem_bytes); } driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1, - dynamic_shared_mem_bytes, nullptr, + dynamic_shared_mem_bytes, stream_, arg_pointers.data(), nullptr); } if (profiler_) profiler_->stop(task_handle); if (debug_) { - driver_.stream_synchronize(nullptr); + driver_.stream_synchronize(stream_); } } diff --git a/taichi/rhi/cuda/cuda_context.h b/taichi/rhi/cuda/cuda_context.h index e912cca7aa0c1..fa7d2fa93e936 100644 --- a/taichi/rhi/cuda/cuda_context.h +++ b/taichi/rhi/cuda/cuda_context.h @@ -29,6 +29,7 @@ class CUDAContext { int max_shared_memory_bytes_; bool debug_; bool supports_mem_pool_; + void *stream_; public: CUDAContext(); @@ -108,6 +109,14 @@ class CUDAContext { } static CUDAContext &get_instance(); + + void set_stream(void *stream) { + stream_ = stream; + } + + void *get_stream() const { + return stream_; + } }; } // namespace taichi::lang