From b4165826c43d11b1270269752513cb0d9022f7e4 Mon Sep 17 00:00:00 2001 From: Terry Wang Date: Fri, 7 Feb 2025 20:49:23 +0800 Subject: [PATCH 1/3] Fix coredump when rss push partition data size exceed IN.MAVALALUE --- cpp/core/jni/JniCommon.h | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index 8f40398a4132..16bb80207ad2 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -488,22 +488,36 @@ class JavaRssClient : public RssClient { } int32_t pushPartitionData(int32_t partitionId, const char* bytes, int64_t size) override { + const int64_t maxBufferSize = 128 * 1024 * 1024; // 128 MB JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } - jint length = env->GetArrayLength(array_); - if (size > length) { - jbyte* byteArray = env->GetByteArrayElements(array_, NULL); - env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); - env->DeleteGlobalRef(array_); - array_ = env->NewByteArray(size); - array_ = static_cast(env->NewGlobalRef(array_)); + int32_t totalBytesPushed = 0; + int64_t offset = 0; + while (offset < size) { + int64_t chunkSize = std::min(maxBufferSize, size - offset); + + if (chunkSize > env->GetArrayLength(array_)) { + jbyte* byteArray = env->GetByteArrayElements(array_, NULL); + env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); + env->DeleteGlobalRef(array_); + array_ = env->NewByteArray(chunkSize); + if (array_ == nullptr) { + throw gluten::GlutenException("Failed to allocate new byte array"); + } + array_ = static_cast(env->NewGlobalRef(array_)); + } + + env->SetByteArrayRegion(array_, 0, chunkSize, reinterpret_cast(bytes + offset)); + jint javaBytesSize = + env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, chunkSize); + checkException(env); + + totalBytesPushed += static_cast(javaBytesSize); + offset += chunkSize; } - env->SetByteArrayRegion(array_, 0, size, (jbyte*)bytes); - jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, size); - checkException(env); - return static_cast(javaBytesSize); + return static_cast(totalBytesPushed); } void stop() override {} From 6940191d6fc8c4fe2ba9727b7357f8b2bc365ed8 Mon Sep 17 00:00:00 2001 From: Terry Wang Date: Sat, 8 Feb 2025 10:57:16 +0800 Subject: [PATCH 2/3] fix compile problem --- cpp/core/jni/JniCommon.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index 16bb80207ad2..631821ae35a6 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -509,7 +509,7 @@ class JavaRssClient : public RssClient { array_ = static_cast(env->NewGlobalRef(array_)); } - env->SetByteArrayRegion(array_, 0, chunkSize, reinterpret_cast(bytes + offset)); + env->SetByteArrayRegion(array_, 0, chunkSize, (jbyte*)(bytes + offset)); jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, chunkSize); checkException(env); From 0b51f31bd7ba2e25974361e1ac67b2805c9d3b4a Mon Sep 17 00:00:00 2001 From: Terry Wang Date: Sat, 8 Feb 2025 15:32:58 +0800 Subject: [PATCH 3/3] remove split bytes logic and just add check null logic --- cpp/core/jni/JniCommon.h | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index 631821ae35a6..246658834d48 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -488,36 +488,26 @@ class JavaRssClient : public RssClient { } int32_t pushPartitionData(int32_t partitionId, const char* bytes, int64_t size) override { - const int64_t maxBufferSize = 128 * 1024 * 1024; // 128 MB JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } - int32_t totalBytesPushed = 0; - int64_t offset = 0; - while (offset < size) { - int64_t chunkSize = std::min(maxBufferSize, size - offset); - - if (chunkSize > env->GetArrayLength(array_)) { - jbyte* byteArray = env->GetByteArrayElements(array_, NULL); - env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); - env->DeleteGlobalRef(array_); - array_ = env->NewByteArray(chunkSize); - if (array_ == nullptr) { - throw gluten::GlutenException("Failed to allocate new byte array"); - } - array_ = static_cast(env->NewGlobalRef(array_)); + jint length = env->GetArrayLength(array_); + if (size > length) { + jbyte* byteArray = env->GetByteArrayElements(array_, NULL); + env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); + env->DeleteGlobalRef(array_); + array_ = env->NewByteArray(size); + if (array_ == nullptr) { + LOG(WARNING) << "Failed to allocate new byte array size: " << size; + throw gluten::GlutenException("Failed to allocate new byte array"); } - - env->SetByteArrayRegion(array_, 0, chunkSize, (jbyte*)(bytes + offset)); - jint javaBytesSize = - env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, chunkSize); - checkException(env); - - totalBytesPushed += static_cast(javaBytesSize); - offset += chunkSize; + array_ = static_cast(env->NewGlobalRef(array_)); } - return static_cast(totalBytesPushed); + env->SetByteArrayRegion(array_, 0, size, (jbyte*)bytes); + jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, size); + checkException(env); + return static_cast(javaBytesSize); } void stop() override {}