Skip to content

Commit

Permalink
remove split bytes logic and just add check null logic
Browse files Browse the repository at this point in the history
  • Loading branch information
zjuwangg committed Feb 10, 2025
1 parent 6940191 commit 0b51f31
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,36 +488,26 @@ class JavaRssClient : public RssClient {
}

int32_t pushPartitionData(int32_t partitionId, const char* bytes, int64_t size) override {
const int64_t maxBufferSize = 128 * 1024 * 1024; // 128 MB
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
}
int32_t totalBytesPushed = 0;
int64_t offset = 0;
while (offset < size) {
int64_t chunkSize = std::min(maxBufferSize, size - offset);

if (chunkSize > env->GetArrayLength(array_)) {
jbyte* byteArray = env->GetByteArrayElements(array_, NULL);
env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT);
env->DeleteGlobalRef(array_);
array_ = env->NewByteArray(chunkSize);
if (array_ == nullptr) {
throw gluten::GlutenException("Failed to allocate new byte array");
}
array_ = static_cast<jbyteArray>(env->NewGlobalRef(array_));
jint length = env->GetArrayLength(array_);
if (size > length) {
jbyte* byteArray = env->GetByteArrayElements(array_, NULL);
env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT);
env->DeleteGlobalRef(array_);
array_ = env->NewByteArray(size);
if (array_ == nullptr) {
LOG(WARNING) << "Failed to allocate new byte array size: " << size;
throw gluten::GlutenException("Failed to allocate new byte array");
}

env->SetByteArrayRegion(array_, 0, chunkSize, (jbyte*)(bytes + offset));
jint javaBytesSize =
env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, chunkSize);
checkException(env);

totalBytesPushed += static_cast<int32_t>(javaBytesSize);
offset += chunkSize;
array_ = static_cast<jbyteArray>(env->NewGlobalRef(array_));
}
return static_cast<int32_t>(totalBytesPushed);
env->SetByteArrayRegion(array_, 0, size, (jbyte*)bytes);
jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, size);
checkException(env);
return static_cast<int32_t>(javaBytesSize);
}

void stop() override {}
Expand Down

0 comments on commit 0b51f31

Please sign in to comment.