diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
index 8f4292c1bc8..e355cea43b9 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
@@ -8,8 +8,6 @@
 
 package org.pytorch.executorch.training;
 
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
 import com.facebook.soloader.nativeloader.NativeLoader;
 import com.facebook.soloader.nativeloader.SystemDelegate;
 import java.util.Map;
@@ -32,10 +30,9 @@ public class SGD {
     NativeLoader.loadLibrary("executorch");
   }
 
-  private final HybridData mHybridData;
+  private long mNativeHandle;
 
-  @DoNotStrip
-  private static native HybridData initHybrid(
+  private static native long nativeCreate(
       Map<String, Tensor> namedParameters,
       double learningRate,
       double momentum,
@@ -43,6 +40,8 @@ private static native HybridData initHybrid(
       double weightDecay,
       boolean nesterov);
 
+  private static native void nativeDestroy(long nativeHandle);
+
   private SGD(
       Map<String, Tensor> namedParameters,
       double learningRate,
@@ -50,8 +49,8 @@ private SGD(
       double dampening,
       double weightDecay,
       boolean nesterov) {
-    mHybridData =
-        initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
+    mNativeHandle =
+        nativeCreate(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
   }
 
   /**
@@ -92,12 +91,34 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
    * @param namedGradients Map of parameter names to gradient tensors
    */
   public void step(Map<String, Tensor> namedGradients) {
-    if (!mHybridData.isValid()) {
+    if (mNativeHandle == 0) {
       throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
     }
-    stepNative(namedGradients);
+    nativeStep(mNativeHandle, namedGradients);
   }
 
-  @DoNotStrip
-  private native void stepNative(Map<String, Tensor> namedGradients);
+  private static native void nativeStep(long nativeHandle, Map<String, Tensor> namedGradients);
+
+  /**
+   * Explicitly destroys the native SGD optimizer object. Calling this method is not required, as
+   * the native object will be destroyed when this object is garbage-collected. However, the timing
+   * of garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
+   * more quickly.
+   */
+  public void destroy() {
+    if (mNativeHandle != 0) {
+      nativeDestroy(mNativeHandle);
+      mNativeHandle = 0;
+    }
+  }
+
+  @SuppressWarnings("deprecation")
+  @Override
+  protected void finalize() throws Throwable {
+    if (mNativeHandle != 0) {
+      nativeDestroy(mNativeHandle);
+      mNativeHandle = 0;
+    }
+    super.finalize();
+  }
 }
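Reviewer note: the Java-visible contract of SGD is unchanged apart from the new explicit destroy(). A minimal usage sketch, where the module variable and the 0.01 learning rate are illustrative assumptions and the parameter map comes from TrainingModule.namedParameters as before:

    // Build the optimizer over the module's parameters (two-arg create per the hunk above).
    Map<String, Tensor> params = module.namedParameters("forward");
    SGD optimizer = SGD.create(params, 0.01);
    try {
      // Apply one update from the gradients computed by the last forward/backward pass.
      optimizer.step(module.namedGradients("forward"));
    } finally {
      // Optional: releases the native optimizer eagerly instead of waiting for GC.
      optimizer.destroy();
    }

After destroy(), step() throws RuntimeException("Attempt to use a destroyed SGD optimizer"), matching the old mHybridData.isValid() behavior.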
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java
index 3735fb6f426..c4189a207ed 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java
@@ -9,8 +9,6 @@
 package org.pytorch.executorch.training;
 
 import android.util.Log;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
 import com.facebook.soloader.nativeloader.NativeLoader;
 import com.facebook.soloader.nativeloader.SystemDelegate;
 import java.io.File;
@@ -36,13 +34,14 @@ public class TrainingModule {
     NativeLoader.loadLibrary("executorch");
   }
 
-  private final HybridData mHybridData;
+  private long mNativeHandle;
 
-  @DoNotStrip
-  private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
+  private static native long nativeCreate(String moduleAbsolutePath, String dataAbsolutePath);
+
+  private static native void nativeDestroy(long nativeHandle);
 
   private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
-    mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
+    mNativeHandle = nativeCreate(moduleAbsolutePath, dataAbsolutePath);
   }
 
   /**
@@ -87,35 +86,58 @@ public static TrainingModule load(final String modelPath) {
    * @return return value(s) from the method.
    */
   public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
-    if (!mHybridData.isValid()) {
+    if (mNativeHandle == 0) {
       Log.e("ExecuTorch", "Attempt to use a destroyed module");
       return new EValue[0];
     }
-    return executeForwardBackwardNative(methodName, inputs);
+    return nativeExecuteForwardBackward(mNativeHandle, methodName, inputs);
   }
 
-  @DoNotStrip
-  private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
+  private static native EValue[] nativeExecuteForwardBackward(
+      long nativeHandle, String methodName, EValue... inputs);
 
   public Map<String, Tensor> namedParameters(String methodName) {
-    if (!mHybridData.isValid()) {
+    if (mNativeHandle == 0) {
       Log.e("ExecuTorch", "Attempt to use a destroyed module");
       return new HashMap<String, Tensor>();
     }
-    return namedParametersNative(methodName);
+    return nativeNamedParameters(mNativeHandle, methodName);
  }
 
-  @DoNotStrip
-  private native Map<String, Tensor> namedParametersNative(String methodName);
+  private static native Map<String, Tensor> nativeNamedParameters(
+      long nativeHandle, String methodName);
 
   public Map<String, Tensor> namedGradients(String methodName) {
-    if (!mHybridData.isValid()) {
+    if (mNativeHandle == 0) {
       Log.e("ExecuTorch", "Attempt to use a destroyed module");
       return new HashMap<String, Tensor>();
     }
-    return namedGradientsNative(methodName);
+    return nativeNamedGradients(mNativeHandle, methodName);
   }
 
-  @DoNotStrip
-  private native Map<String, Tensor> namedGradientsNative(String methodName);
+  private static native Map<String, Tensor> nativeNamedGradients(
+      long nativeHandle, String methodName);
+
+  /**
+   * Explicitly destroys the native TrainingModule object. Calling this method is not required, as
+   * the native object will be destroyed when this object is garbage-collected. However, the timing
+   * of garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
+   * more quickly.
+   */
+  public void destroy() {
+    if (mNativeHandle != 0) {
+      nativeDestroy(mNativeHandle);
+      mNativeHandle = 0;
+    }
+  }
+
+  @SuppressWarnings("deprecation")
+  @Override
+  protected void finalize() throws Throwable {
+    if (mNativeHandle != 0) {
+      nativeDestroy(mNativeHandle);
+      mNativeHandle = 0;
+    }
+    super.finalize();
+  }
 }
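Reviewer note: TrainingModule keeps the same public surface, so existing training loops compile unchanged; only the binding underneath moved from fbjni hybrids to raw long handles. A sketch of the intended end-to-end flow; the .pte path, the "forward" method name, numEpochs, and the input/label tensors are illustrative assumptions, and EValue.from(Tensor) is the same factory whose JNI signature is cached in the C++ below:

    TrainingModule module = TrainingModule.load("/data/local/tmp/model.pte");
    SGD optimizer = SGD.create(module.namedParameters("forward"), 0.01);
    for (int epoch = 0; epoch < numEpochs; epoch++) {
      // Runs forward and backward; gradients become available via namedGradients().
      EValue[] outputs =
          module.executeForwardBackward("forward", EValue.from(input), EValue.from(label));
      optimizer.step(module.namedGradients("forward"));
    }
    // Eagerly free native memory once training is done.
    optimizer.destroy();
    module.destroy();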
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index 1f8457e00c5..84f698b1b38 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -43,6 +43,7 @@
 #include <fbjni/fbjni.h>
 
 using namespace executorch::extension;
+using namespace executorch::jni_helper;
 using namespace torch::executor;
 
 namespace executorch::extension {
@@ -543,10 +544,10 @@ void register_natives_for_llm() {}
 extern void register_natives_for_runtime();
 
 #ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
-extern void register_natives_for_training();
+extern void register_natives_for_training(JNIEnv* env);
 #else
 // No op if we don't build training JNI
-void register_natives_for_training() {}
+void register_natives_for_training(JNIEnv* /* env */) {}
 #endif
 
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
@@ -554,6 +555,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
     executorch::extension::ExecuTorchJni::registerNatives();
     register_natives_for_llm();
     register_natives_for_runtime();
-    register_natives_for_training();
+    register_natives_for_training(facebook::jni::Environment::current());
   });
 }
diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp
index 5a5e9f24d2f..ff3a0f6e3be 100644
--- a/extension/android/jni/jni_layer_training.cpp
+++ b/extension/android/jni/jni_layer_training.cpp
@@ -6,92 +6,456 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <jni.h>
+
 #include <cinttypes>
 #include <map>
 #include <memory>
 #include <string>
 #include <string_view>
 #include <vector>
+#include <sstream>
 #include <executorch/extension/data_loader/file_data_loader.h>
 #include <executorch/extension/tensor/tensor.h>
 #include <executorch/extension/training/module/training_module.h>
+#include <executorch/extension/android/jni/jni_layer_constants.h>
 #include <executorch/extension/training/optimizer/sgd.h>
+#include <executorch/runtime/platform/log.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
-#include <fbjni/ByteBuffer.h>
-#include <fbjni/fbjni.h>
-
 using namespace executorch::extension;
 using namespace executorch::extension::training;
 using namespace torch::executor;
 
-namespace executorch::extension {
+namespace {
 
-// Forward declarations from jni_layer.cpp
-class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
- public:
-  constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/Tensor;";
+// EValue type codes (must match Java EValue class)
+constexpr int kTypeCodeNone = 0;
+constexpr int kTypeCodeTensor = 1;
+constexpr int kTypeCodeString = 2;
+constexpr int kTypeCodeDouble = 3;
+constexpr int kTypeCodeInt = 4;
+constexpr int kTypeCodeBool = 5;
 
-  static facebook::jni::local_ref<TensorHybrid::javaobject>
-  newJTensorFromTensor(const executorch::aten::Tensor& tensor);
+// Helper to convert jstring to std::string
+std::string jstring_to_string(JNIEnv* env, jstring jstr) {
+  if (jstr == nullptr) {
+    return "";
+  }
+  const char* chars = env->GetStringUTFChars(jstr, nullptr);
+  if (chars == nullptr) {
+    return "";
+  }
+  std::string result(chars);
+  env->ReleaseStringUTFChars(jstr, chars);
+  return result;
+}
+
+// Helper to throw a Java exception
+void throwJavaException(JNIEnv* env, const char* message) {
+  jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
+  if (exceptionClass != nullptr) {
+    env->ThrowNew(exceptionClass, message);
+    env->DeleteLocalRef(exceptionClass);
+  }
+}
+
+// Cached class and method IDs for training module
+struct TrainingJniCache {
+  jclass tensor_class = nullptr;
+  jclass evalue_class = nullptr;
+  jclass hashmap_class = nullptr;
+  jclass bytebuffer_class = nullptr;
+  jclass byteorder_class = nullptr;
+  jclass long_class = nullptr;
+  jclass double_class = nullptr;
+  jclass boolean_class = nullptr;
+  jmethodID tensor_nativeNewTensor = nullptr;
+  jmethodID tensor_dtypeJniCode = nullptr;
+  jmethodID tensor_getRawDataBuffer = nullptr;
+  jfieldID tensor_shape = nullptr;
+  jmethodID evalue_from_tensor = nullptr;
+  jmethodID evalue_from_long = nullptr;
+  jmethodID evalue_from_double = nullptr;
+  jmethodID evalue_from_bool = nullptr;
+  jmethodID evalue_from_string = nullptr;
+  jmethodID evalue_toTensor = nullptr;
+  jfieldID evalue_mTypeCode = nullptr;
+  jfieldID evalue_mData = nullptr;
+  jmethodID hashmap_init = nullptr;
+  jmethodID hashmap_put = nullptr;
+  jmethodID map_entrySet = nullptr;
+  jmethodID set_iterator = nullptr;
+  jmethodID iterator_hasNext = nullptr;
+  jmethodID iterator_next = nullptr;
+  jmethodID entry_getKey = nullptr;
+  jmethodID entry_getValue = nullptr;
+  jmethodID map_size = nullptr;
+  jmethodID bytebuffer_order = nullptr;
+  jmethodID byteorder_nativeOrder = nullptr;
+  jmethodID long_longValue = nullptr;
+  jmethodID double_doubleValue = nullptr;
+  jmethodID boolean_booleanValue = nullptr;
+
+  bool initialized = false;
+
+  void init(JNIEnv* env) {
+    if (initialized) {
+      return;
+    }
 
-  static TensorPtr newTensorFromJTensor(
-      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor);
-};
+    // Cache Tensor class and methods
+    jclass local_tensor_class = env->FindClass("org/pytorch/executorch/Tensor");
+    if (local_tensor_class != nullptr) {
+      tensor_class = static_cast<jclass>(env->NewGlobalRef(local_tensor_class));
+      env->DeleteLocalRef(local_tensor_class);
+
+      tensor_nativeNewTensor = env->GetStaticMethodID(
+          tensor_class,
+          "nativeNewTensor",
+          "(Ljava/nio/ByteBuffer;[JIJ)Lorg/pytorch/executorch/Tensor;");
+      tensor_dtypeJniCode =
+          env->GetMethodID(tensor_class, "dtypeJniCode", "()I");
+      tensor_getRawDataBuffer = env->GetMethodID(
+          tensor_class, "getRawDataBuffer", "()Ljava/nio/Buffer;");
+      tensor_shape = env->GetFieldID(tensor_class, "shape", "[J");
+    }
 
-class JEValue : public facebook::jni::JavaClass<JEValue> {
- public:
-  constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/EValue;";
+    // Cache EValue class and methods
+    jclass local_evalue_class = env->FindClass("org/pytorch/executorch/EValue");
+    if (local_evalue_class != nullptr) {
+      evalue_class = static_cast<jclass>(env->NewGlobalRef(local_evalue_class));
+      env->DeleteLocalRef(local_evalue_class);
+
+      evalue_from_tensor = env->GetStaticMethodID(
+          evalue_class,
+          "from",
+          "(Lorg/pytorch/executorch/Tensor;)Lorg/pytorch/executorch/EValue;");
+      evalue_from_long = env->GetStaticMethodID(
+          evalue_class, "from", "(J)Lorg/pytorch/executorch/EValue;");
+      evalue_from_double = env->GetStaticMethodID(
+          evalue_class, "from", "(D)Lorg/pytorch/executorch/EValue;");
+      evalue_from_bool = env->GetStaticMethodID(
+          evalue_class, "from", "(Z)Lorg/pytorch/executorch/EValue;");
+      evalue_from_string = env->GetStaticMethodID(
+          evalue_class,
+          "from",
+          "(Ljava/lang/String;)Lorg/pytorch/executorch/EValue;");
+      evalue_toTensor = env->GetMethodID(
+          evalue_class, "toTensor", "()Lorg/pytorch/executorch/Tensor;");
+      evalue_mTypeCode = env->GetFieldID(evalue_class, "mTypeCode", "I");
+      evalue_mData =
+          env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;");
+    }
 
-  constexpr static int kTypeCodeTensor = 1;
-  constexpr static int kTypeCodeString = 2;
-  constexpr static int kTypeCodeDouble = 3;
-  constexpr static int kTypeCodeInt = 4;
-  constexpr static int kTypeCodeBool = 5;
+    // Cache HashMap class and methods
+    jclass local_hashmap_class = env->FindClass("java/util/HashMap");
+    if (local_hashmap_class != nullptr) {
+      hashmap_class =
+          static_cast<jclass>(env->NewGlobalRef(local_hashmap_class));
+      env->DeleteLocalRef(local_hashmap_class);
+
+      hashmap_init = env->GetMethodID(hashmap_class, "<init>", "()V");
+      hashmap_put = env->GetMethodID(
+          hashmap_class,
+          "put",
+          "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
+    }
 
-  static facebook::jni::local_ref<JEValue> newJEValueFromEValue(
-      runtime::EValue evalue);
+    // Cache Map iteration methods
+    jclass map_class = env->FindClass("java/util/Map");
+    if (map_class != nullptr) {
+      map_entrySet =
+          env->GetMethodID(map_class, "entrySet", "()Ljava/util/Set;");
+      map_size = env->GetMethodID(map_class, "size", "()I");
+      env->DeleteLocalRef(map_class);
+    }
+
+    jclass set_class = env->FindClass("java/util/Set");
+    if (set_class != nullptr) {
+      set_iterator =
+          env->GetMethodID(set_class, "iterator", "()Ljava/util/Iterator;");
+      env->DeleteLocalRef(set_class);
+    }
+
+    jclass iterator_class = env->FindClass("java/util/Iterator");
+    if (iterator_class != nullptr) {
+      iterator_hasNext = env->GetMethodID(iterator_class, "hasNext", "()Z");
+      iterator_next =
+          env->GetMethodID(iterator_class, "next", "()Ljava/lang/Object;");
+      env->DeleteLocalRef(iterator_class);
+    }
 
-  static TensorPtr JEValueToTensorImpl(
-      facebook::jni::alias_ref<JEValue> JEValue);
+    jclass entry_class = env->FindClass("java/util/Map$Entry");
+    if (entry_class != nullptr) {
+      entry_getKey =
+          env->GetMethodID(entry_class, "getKey", "()Ljava/lang/Object;");
+      entry_getValue =
+          env->GetMethodID(entry_class, "getValue", "()Ljava/lang/Object;");
+      env->DeleteLocalRef(entry_class);
+    }
+
+    // Cache ByteBuffer and ByteOrder classes and methods
+    jclass local_bytebuffer_class = env->FindClass("java/nio/ByteBuffer");
+    if (local_bytebuffer_class != nullptr) {
+      bytebuffer_class =
+          static_cast<jclass>(env->NewGlobalRef(local_bytebuffer_class));
+      env->DeleteLocalRef(local_bytebuffer_class);
+
+      bytebuffer_order = env->GetMethodID(
+          bytebuffer_class,
+          "order",
+          "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;");
+    }
+
+    jclass local_byteorder_class = env->FindClass("java/nio/ByteOrder");
+    if (local_byteorder_class != nullptr) {
+      byteorder_class =
+          static_cast<jclass>(env->NewGlobalRef(local_byteorder_class));
+      env->DeleteLocalRef(local_byteorder_class);
+
+      byteorder_nativeOrder = env->GetStaticMethodID(
+          byteorder_class, "nativeOrder", "()Ljava/nio/ByteOrder;");
+    }
+
+    // Cache wrapper classes for primitives (Long, Double, Boolean)
+    jclass local_long_class = env->FindClass("java/lang/Long");
+    if (local_long_class != nullptr) {
+      long_class = static_cast<jclass>(env->NewGlobalRef(local_long_class));
+      env->DeleteLocalRef(local_long_class);
+
+      long_longValue = env->GetMethodID(long_class, "longValue", "()J");
+    }
+
+    jclass local_double_class = env->FindClass("java/lang/Double");
+    if (local_double_class != nullptr) {
+      double_class = static_cast<jclass>(env->NewGlobalRef(local_double_class));
+      env->DeleteLocalRef(local_double_class);
+
+      double_doubleValue = env->GetMethodID(double_class, "doubleValue", "()D");
+    }
+
+    jclass local_boolean_class = env->FindClass("java/lang/Boolean");
+    if (local_boolean_class != nullptr) {
+      boolean_class =
+          static_cast<jclass>(env->NewGlobalRef(local_boolean_class));
+      env->DeleteLocalRef(local_boolean_class);
+
+      boolean_booleanValue =
+          env->GetMethodID(boolean_class, "booleanValue", "()Z");
+    }
+
+    initialized = true;
+  }
+};
 
-class ExecuTorchTrainingJni
-    : public facebook::jni::HybridClass<ExecuTorchTrainingJni> {
- private:
-  friend HybridBase;
-  std::unique_ptr<training::TrainingModule> module_;
+TrainingJniCache g_training_cache;
+
+// Helper to create Java Tensor from native tensor
+jobject newJTensorFromTensor(
+    JNIEnv* env,
+    const executorch::aten::Tensor& tensor) {
+  g_training_cache.init(env);
+
+  const auto scalarType = tensor.scalar_type();
+  if (scalar_type_to_java_dtype.count(scalarType) == 0) {
+    std::stringstream ss;
+    ss << "Tensor scalar type " << static_cast<int>(scalarType)
+       << " is not supported on java side";
+    throwJavaException(env, ss.str().c_str());
+    return nullptr;
+  }
+  int jdtype = scalar_type_to_java_dtype.at(scalarType);
+
+  // Create shape array
+  const auto& tensor_shape = tensor.sizes();
+  jlongArray jTensorShape = env->NewLongArray(tensor_shape.size());
+  if (jTensorShape == nullptr) {
+    return nullptr;
+  }
+  std::vector<jlong> shape_vec;
+  for (const auto& s : tensor_shape) {
+    shape_vec.push_back(s);
+  }
+  env->SetLongArrayRegion(jTensorShape, 0, shape_vec.size(), shape_vec.data());
+
+  // Create ByteBuffer wrapping tensor data
+  jobject jTensorBuffer = env->NewDirectByteBuffer(
+      const_cast<void*>(tensor.const_data_ptr()), tensor.nbytes());
+  if (jTensorBuffer == nullptr) {
+    env->DeleteLocalRef(jTensorShape);
+    return nullptr;
+  }
+
+  // Set byte order to native order (using cached classes/methods)
+  jobject nativeOrder = env->CallStaticObjectMethod(
+      g_training_cache.byteorder_class,
+      g_training_cache.byteorder_nativeOrder);
+  env->CallObjectMethod(
+      jTensorBuffer, g_training_cache.bytebuffer_order, nativeOrder);
+  env->DeleteLocalRef(nativeOrder);
+
+  // Call nativeNewTensor static method
+  jobject result = env->CallStaticObjectMethod(
+      g_training_cache.tensor_class,
+      g_training_cache.tensor_nativeNewTensor,
+      jTensorBuffer,
+      jTensorShape,
+      jdtype,
+      static_cast<jlong>(0));
+
+  env->DeleteLocalRef(jTensorBuffer);
+  env->DeleteLocalRef(jTensorShape);
+
+  return result;
+}
+
+// Helper to create native TensorPtr from Java Tensor
+TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) {
+  g_training_cache.init(env);
+
+  jint jdtype =
+      env->CallIntMethod(jtensor, g_training_cache.tensor_dtypeJniCode);
+
+  jlongArray jshape = static_cast<jlongArray>(
+      env->GetObjectField(jtensor, g_training_cache.tensor_shape));
+
+  jobject jbuffer =
+      env->CallObjectMethod(jtensor, g_training_cache.tensor_getRawDataBuffer);
+
+  jsize rank = env->GetArrayLength(jshape);
+
+  std::vector<jlong> shapeArr(rank);
+  env->GetLongArrayRegion(jshape, 0, rank, shapeArr.data());
+
+  std::vector<executorch::aten::SizesType> shape_vec;
+  shape_vec.reserve(rank);
+
+  for (int i = 0; i < rank; i++) {
+    shape_vec.push_back(shapeArr[i]);
+  }
+
+  if (java_dtype_to_scalar_type.count(jdtype) == 0) {
+    std::stringstream ss;
+    ss << "Unknown Tensor jdtype: " << jdtype;
+    throwJavaException(env, ss.str().c_str());
+    env->DeleteLocalRef(jshape);
+    env->DeleteLocalRef(jbuffer);
+    return nullptr;
+  }
+
+  ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
+  void* data = env->GetDirectBufferAddress(jbuffer);
+  TensorPtr result = from_blob(data, shape_vec, scalar_type);
+
+  env->DeleteLocalRef(jshape);
+  env->DeleteLocalRef(jbuffer);
+
+  return result;
+}
+
+// Helper to create Java EValue from native EValue
+jobject newJEValueFromEValue(JNIEnv* env, runtime::EValue evalue) {
+  g_training_cache.init(env);
+
+  if (evalue.isTensor()) {
+    jobject jtensor = newJTensorFromTensor(env, evalue.toTensor());
+    if (jtensor == nullptr) {
+      return nullptr;
+    }
+    jobject result = env->CallStaticObjectMethod(
+        g_training_cache.evalue_class,
+        g_training_cache.evalue_from_tensor,
+        jtensor);
+    env->DeleteLocalRef(jtensor);
+    return result;
+  } else if (evalue.isInt()) {
+    return env->CallStaticObjectMethod(
+        g_training_cache.evalue_class,
+        g_training_cache.evalue_from_long,
+        evalue.toInt());
+  } else if (evalue.isDouble()) {
+    return env->CallStaticObjectMethod(
+        g_training_cache.evalue_class,
+        g_training_cache.evalue_from_double,
+        evalue.toDouble());
+  } else if (evalue.isBool()) {
+    return env->CallStaticObjectMethod(
+        g_training_cache.evalue_class,
+        g_training_cache.evalue_from_bool,
+        static_cast<jboolean>(evalue.toBool()));
+  } else if (evalue.isString()) {
+    std::string str =
+        std::string(evalue.toString().begin(), evalue.toString().end());
+    jstring jstr = env->NewStringUTF(str.c_str());
+    jobject result = env->CallStaticObjectMethod(
+        g_training_cache.evalue_class,
+        g_training_cache.evalue_from_string,
+        jstr);
+    env->DeleteLocalRef(jstr);
+    return result;
+  }
+
+  std::stringstream ss;
+  ss << "Unknown EValue type: " << static_cast<int>(evalue.tag);
+  throwJavaException(env, ss.str().c_str());
+  return nullptr;
+}
+
+// Helper to get TensorPtr from Java EValue
+TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) {
+  g_training_cache.init(env);
+
+  jint typeCode =
+      env->GetIntField(jevalue, g_training_cache.evalue_mTypeCode);
+  if (typeCode == kTypeCodeTensor) {
+    jobject jtensor =
+        env->CallObjectMethod(jevalue, g_training_cache.evalue_toTensor);
+    TensorPtr result = newTensorFromJTensor(env, jtensor);
+    env->DeleteLocalRef(jtensor);
+    return result;
+  }
+
+  std::stringstream ss;
+  ss << "Unknown EValue typeCode: " << typeCode;
+  throwJavaException(env, ss.str().c_str());
+  return nullptr;
+}
+
+} // anonymous namespace
+
+namespace executorch::extension {
 
+// Native training module handle class
+class TrainingModuleNative {
  public:
-  constexpr static auto kJavaDescriptor =
-      "Lorg/pytorch/executorch/training/TrainingModule;";
+  std::unique_ptr<training::TrainingModule> module_;
 
-  ExecuTorchTrainingJni(
-      facebook::jni::alias_ref<jstring> modelPath,
-      facebook::jni::alias_ref<jstring> dataPath) {
-    auto modelPathString = modelPath->toStdString();
+  TrainingModuleNative(
+      JNIEnv* env,
+      jstring modelPath,
+      jstring dataPath) {
+    std::string modelPathString = jstring_to_string(env, modelPath);
     auto modelLoaderRes = FileDataLoader::from(modelPathString.c_str());
     if (modelLoaderRes.error() != Error::Ok) {
-      facebook::jni::throwNewJavaException(
-          "java/lang/Exception",
-          "Failed to open model file: %s",
-          modelPathString.c_str());
+      std::stringstream ss;
+      ss << "Failed to open model file: " << modelPathString;
+      throwJavaException(env, ss.str().c_str());
+      return;
     }
     auto modelLoader =
         std::make_unique<FileDataLoader>(std::move(modelLoaderRes.get()));
 
     std::unique_ptr<FileDataLoader> dataLoader = nullptr;
-    auto dataPathString = dataPath->toStdString();
+    std::string dataPathString = jstring_to_string(env, dataPath);
     if (!dataPathString.empty()) {
       auto dataLoaderRes = FileDataLoader::from(dataPathString.c_str());
       if (dataLoaderRes.error() != Error::Ok) {
-        facebook::jni::throwNewJavaException(
-            "java/lang/Exception",
-            "Failed to open ptd file: %s",
-            dataPathString.c_str());
+        std::stringstream ss;
+        ss << "Failed to open ptd file: " << dataPathString;
+        throwJavaException(env, ss.str().c_str());
+        return;
       }
       dataLoader =
           std::make_unique<FileDataLoader>(std::move(dataLoaderRes.get()));
@@ -104,174 +468,53 @@ class ExecuTorchTrainingJni
         nullptr,
         std::move(dataLoader));
   }
-
-  static facebook::jni::local_ref<jhybriddata> initHybrid(
-      facebook::jni::alias_ref<jclass>,
-      facebook::jni::alias_ref<jstring> modelPath,
-      facebook::jni::alias_ref<jstring> dataPath) {
-    return makeCxxInstance(modelPath, dataPath);
-  }
-
-  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>>
-  executeForwardBackward(
-      facebook::jni::alias_ref<jstring> methodName,
-      facebook::jni::alias_ref<
-          facebook::jni::JArrayClass<JEValue>::javaobject>
-          jinputs) {
-    std::vector<EValue> evalues;
-    std::vector<TensorPtr> tensors;
-
-    static const auto typeCodeField =
-        JEValue::javaClassStatic()->getField<jint>("mTypeCode");
-
-    for (int i = 0; i < jinputs->size(); i++) {
-      auto jevalue = jinputs->getElement(i);
-      const auto typeCode = jevalue->getFieldValue(typeCodeField);
-      if (typeCode == JEValue::kTypeCodeTensor) {
-        tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
-        evalues.emplace_back(tensors.back());
-      } else if (typeCode == JEValue::kTypeCodeInt) {
-        int64_t value = jevalue->getFieldValue(typeCodeField);
-        evalues.emplace_back(value);
-      } else if (typeCode == JEValue::kTypeCodeDouble) {
-        double value = jevalue->getFieldValue(typeCodeField);
-        evalues.emplace_back(value);
-      } else if (typeCode == JEValue::kTypeCodeBool) {
-        bool value = jevalue->getFieldValue(typeCodeField);
-        evalues.emplace_back(value);
-      }
-    }
-
-    auto result =
-        module_->execute_forward_backward(methodName->toStdString(), evalues);
-    if (!result.ok()) {
-      facebook::jni::throwNewJavaException(
-          "java/lang/Exception",
-          "Execution of forward_backward for method %s failed with status 0x%" PRIx32,
-          methodName->toStdString().c_str(),
-          static_cast<uint32_t>(result.error()));
-    }
-
-    facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
-        facebook::jni::JArrayClass<JEValue>::newArray(result.get().size());
-
-    for (int i = 0; i < result.get().size(); i++) {
-      auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
-      jresult->setElement(i, *jevalue);
-    }
-    return jresult;
-  }
-
-  facebook::jni::local_ref<
-      facebook::jni::JMap<jstring, TensorHybrid::javaobject>>
-  namedParameters(facebook::jni::alias_ref<jstring> methodName) {
-    auto method = methodName->toStdString();
-    auto result = module_->named_parameters(method);
-    if (!result.ok()) {
-      facebook::jni::throwNewJavaException(
-          "java/lang/Exception",
-          "Getting named parameters for method %s failed with status 0x%" PRIx32,
-          method.c_str(),
-          static_cast<uint32_t>(result.error()));
-    }
-    facebook::jni::local_ref<
-        facebook::jni::JHashMap<jstring, TensorHybrid::javaobject>>
-        parameters = facebook::jni::
-            JHashMap<jstring, TensorHybrid::javaobject>::create();
-    for (auto& [layer, tensor] : result.get()) {
-      parameters->put(
-          facebook::jni::make_jstring(layer.data()),
-          TensorHybrid::newJTensorFromTensor(tensor));
-    }
-    return parameters;
-  }
-
-  facebook::jni::local_ref<
-      facebook::jni::JMap<jstring, TensorHybrid::javaobject>>
-  namedGradients(facebook::jni::alias_ref<jstring> methodName) {
-    auto method = methodName->toStdString();
-    auto result = module_->named_gradients(method);
-    if (!result.ok()) {
-      facebook::jni::throwNewJavaException(
-          "java/lang/Exception",
-          "Getting named gradients for method %s failed with status 0x%" PRIx32,
-          method.c_str(),
-          static_cast<uint32_t>(result.error()));
-    }
-    facebook::jni::local_ref<
-        facebook::jni::JHashMap<jstring, TensorHybrid::javaobject>>
-        gradients = facebook::jni::JHashMap<jstring, TensorHybrid::javaobject>::
-            create();
-    for (auto& [layer, tensor] : result.get()) {
-      gradients->put(
-          facebook::jni::make_jstring(layer.data()),
-          TensorHybrid::newJTensorFromTensor(tensor));
-    }
-    return gradients;
-  }
-
-  static void registerNatives() {
-    registerHybrid({
-        makeNativeMethod("initHybrid", ExecuTorchTrainingJni::initHybrid),
-        makeNativeMethod(
-            "executeForwardBackwardNative",
-            ExecuTorchTrainingJni::executeForwardBackward),
-        makeNativeMethod(
-            "namedParametersNative", ExecuTorchTrainingJni::namedParameters),
-        makeNativeMethod(
-            "namedGradientsNative", ExecuTorchTrainingJni::namedGradients),
-    });
-  }
 };
 
-class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
+// Native SGD optimizer handle class
+class SGDNative {
  public:
-  constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/training/SGD;";
-
-  static facebook::jni::local_ref<jhybriddata> initHybrid(
-      facebook::jni::alias_ref<jclass>,
-      facebook::jni::alias_ref<
-          facebook::jni::JMap<jstring, TensorHybrid::javaobject>>
-          namedParameters,
-      jdouble learningRate,
-      jdouble momentum,
-      jdouble dampening,
-      jdouble weightDecay,
-      jboolean nesterov) {
-    return makeCxxInstance(
-        namedParameters,
-        learningRate,
-        momentum,
-        dampening,
-        weightDecay,
-        nesterov);
-  }
-
-  SGDHybrid(
-      facebook::jni::alias_ref<
-          facebook::jni::JMap<jstring, TensorHybrid::javaobject>>
-          namedParameters,
+  std::unique_ptr<optimizer::SGD> sgdOptimizer_;
+  std::vector<std::string>
+      parameterNames_; // Store parameter names to keep string_view valid
+  std::vector<TensorPtr>
+      paramTensorPtrs_; // Store parameter tensors to keep TensorPtrs valid.
+
+  SGDNative(
+      JNIEnv* env,
+      jobject namedParameters,
       jdouble learningRate,
       jdouble momentum,
      jdouble dampening,
       jdouble weightDecay,
       jboolean nesterov) {
+    g_training_cache.init(env);
+
     std::map<std::string_view, executorch::aten::Tensor> cppNamedParameters;
-    // Avoid vector reallocation to keep string_views valid.
-    parameterNames_.reserve(namedParameters->size());
-    paramTensorPtrs_.reserve(namedParameters->size());
+    // Get the size of the map
+    jint mapSize =
+        env->CallIntMethod(namedParameters, g_training_cache.map_size);
+
+    // Reserve space
+    parameterNames_.reserve(mapSize);
+    paramTensorPtrs_.reserve(mapSize);
 
-    auto iterator = namedParameters->begin();
-    auto end = namedParameters->end();
+    // Get entry set and iterate
+    jobject entrySet =
+        env->CallObjectMethod(namedParameters, g_training_cache.map_entrySet);
+    jobject iterator =
+        env->CallObjectMethod(entrySet, g_training_cache.set_iterator);
 
-    while (iterator != end) {
-      auto key = iterator->first;
-      auto value = iterator->second;
+    while (env->CallBooleanMethod(iterator, g_training_cache.iterator_hasNext)) {
+      jobject entry =
+          env->CallObjectMethod(iterator, g_training_cache.iterator_next);
+      jstring key = static_cast<jstring>(
+          env->CallObjectMethod(entry, g_training_cache.entry_getKey));
+      jobject value =
+          env->CallObjectMethod(entry, g_training_cache.entry_getValue);
 
-      std::string paramName = key->toStdString();
-      TensorPtr tensor = TensorHybrid::newTensorFromJTensor(value);
+      std::string paramName = jstring_to_string(env, key);
+      TensorPtr tensor = newTensorFromJTensor(env, value);
 
       // Store the parameter name and tensor
       parameterNames_.push_back(paramName);
@@ -279,73 +522,388 @@ class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
       cppNamedParameters.emplace(
           std::string_view(parameterNames_.back()), *tensor);
 
-      ++iterator;
+      env->DeleteLocalRef(key);
+      env->DeleteLocalRef(value);
+      env->DeleteLocalRef(entry);
     }
 
+    env->DeleteLocalRef(iterator);
+    env->DeleteLocalRef(entrySet);
+
     optimizer::SGDOptions options(
         learningRate, momentum, dampening, weightDecay, nesterov);
     sgdOptimizer_ = std::make_unique<optimizer::SGD>(cppNamedParameters, options);
   }
+};
 
-  void
-  step(facebook::jni::alias_ref<
-       facebook::jni::JMap<jstring, TensorHybrid::javaobject>> namedGradients) {
-    std::map<std::string_view, executorch::aten::Tensor> cppNamedGradients;
-    std::vector<std::string> gradientNames;
-    std::vector<TensorPtr> tensorKeepalives;
+} // namespace executorch::extension
 
-    gradientNames.reserve(namedGradients->size());
-    tensorKeepalives.reserve(namedGradients->size());
+extern "C" {
+
+JNIEXPORT jlong JNICALL
+Java_org_pytorch_executorch_training_TrainingModule_nativeCreate(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jstring modelPath,
+    jstring dataPath) {
+  auto* native =
+      new executorch::extension::TrainingModuleNative(env, modelPath, dataPath);
+  return reinterpret_cast<jlong>(native);
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_training_TrainingModule_nativeDestroy(
+    JNIEnv* /* env */,
+    jclass /* clazz */,
+    jlong nativeHandle) {
+  if (nativeHandle != 0) {
+    auto* native =
+        reinterpret_cast<executorch::extension::TrainingModuleNative*>(
+            nativeHandle);
+    delete native;
+  }
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_training_TrainingModule_nativeExecuteForwardBackward(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jlong nativeHandle,
+    jstring methodName,
+    jobjectArray jinputs) {
+  auto* native =
+      reinterpret_cast<executorch::extension::TrainingModuleNative*>(
+          nativeHandle);
+  if (native == nullptr) {
+    throwJavaException(env, "Native handle is null");
+    return nullptr;
+  }
 
-    auto iterator = namedGradients->begin();
-    auto end = namedGradients->end();
+  g_training_cache.init(env);
+
+  std::string method = jstring_to_string(env, methodName);
+  jsize inputSize = jinputs != nullptr ? env->GetArrayLength(jinputs) : 0;
+
+  std::vector<EValue> evalues;
+  std::vector<TensorPtr> tensors;
+
+  for (jsize i = 0; i < inputSize; i++) {
+    jobject jevalue = env->GetObjectArrayElement(jinputs, i);
+    jint typeCode =
+        env->GetIntField(jevalue, g_training_cache.evalue_mTypeCode);
+
+    if (typeCode == kTypeCodeTensor) {
+      tensors.emplace_back(JEValueToTensorImpl(env, jevalue));
+      evalues.emplace_back(tensors.back());
+    } else if (typeCode == kTypeCodeInt) {
+      jobject mData =
+          env->GetObjectField(jevalue, g_training_cache.evalue_mData);
+      jlong value =
+          env->CallLongMethod(mData, g_training_cache.long_longValue);
+      evalues.emplace_back(static_cast<int64_t>(value));
+      env->DeleteLocalRef(mData);
+    } else if (typeCode == kTypeCodeDouble) {
+      jobject mData =
+          env->GetObjectField(jevalue, g_training_cache.evalue_mData);
+      jdouble value =
+          env->CallDoubleMethod(mData, g_training_cache.double_doubleValue);
+      evalues.emplace_back(static_cast<double>(value));
+      env->DeleteLocalRef(mData);
+    } else if (typeCode == kTypeCodeBool) {
+      jobject mData =
+          env->GetObjectField(jevalue, g_training_cache.evalue_mData);
+      jboolean value =
+          env->CallBooleanMethod(mData, g_training_cache.boolean_booleanValue);
+      evalues.emplace_back(static_cast<bool>(value));
+      env->DeleteLocalRef(mData);
+    }
+    env->DeleteLocalRef(jevalue);
+  }
 
-    while (iterator != end) {
-      auto key = iterator->first;
-      auto value = iterator->second;
+  auto result = native->module_->execute_forward_backward(method, evalues);
+  if (!result.ok()) {
+    std::stringstream ss;
+    ss << "Execution of forward_backward for method " << method
+       << " failed with status 0x" << std::hex
+       << static_cast<uint32_t>(result.error());
+    throwJavaException(env, ss.str().c_str());
+    return nullptr;
+  }
 
-      std::string gradName = key->toStdString();
-      TensorPtr tensor = TensorHybrid::newTensorFromJTensor(value);
+  jobjectArray jresult = env->NewObjectArray(
+      result.get().size(), g_training_cache.evalue_class, nullptr);
 
-      // Store the gradient name and tensor
-      gradientNames.push_back(gradName);
-      tensorKeepalives.push_back(tensor);
-      cppNamedGradients.emplace(
-          std::string_view(gradientNames.back()), *tensor);
+  for (size_t i = 0; i < result.get().size(); i++) {
+    jobject jevalue = newJEValueFromEValue(env, result.get()[i]);
+    env->SetObjectArrayElement(jresult, i, jevalue);
+    if (jevalue != nullptr) {
+      env->DeleteLocalRef(jevalue);
+    }
+  }
 
-      ++iterator;
-    }
+  return jresult;
+}
+
+JNIEXPORT jobject JNICALL
+Java_org_pytorch_executorch_training_TrainingModule_nativeNamedParameters(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jlong nativeHandle,
+    jstring methodName) {
+  auto* native =
+      reinterpret_cast<executorch::extension::TrainingModuleNative*>(
+          nativeHandle);
+  if (native == nullptr) {
+    throwJavaException(env, "Native handle is null");
+    return nullptr;
+  }
 
-    auto result = sgdOptimizer_->step(cppNamedGradients);
-    if (result != ::executorch::runtime::Error::Ok) {
-      facebook::jni::throwNewJavaException(
-          "java/lang/Exception",
-          "SGD optimization step failed with status 0x%" PRIx32,
-          static_cast<uint32_t>(result));
-    }
-  }
+  g_training_cache.init(env);
+
+  std::string method = jstring_to_string(env, methodName);
+  auto result = native->module_->named_parameters(method);
+  if (!result.ok()) {
+    std::stringstream ss;
+    ss << "Getting named parameters for method " << method
+       << " failed with status 0x" << std::hex
+       << static_cast<uint32_t>(result.error());
+    throwJavaException(env, ss.str().c_str());
+    return nullptr;
+  }
+
+  // Create a new HashMap
+  jobject hashMap = env->NewObject(
+      g_training_cache.hashmap_class, g_training_cache.hashmap_init);
+
+  for (auto& [layer, tensor] : result.get()) {
+    jstring jkey = env->NewStringUTF(std::string(layer).c_str());
+    jobject jtensor = newJTensorFromTensor(env, tensor);
+    env->CallObjectMethod(
+        hashMap, g_training_cache.hashmap_put, jkey, jtensor);
+    env->DeleteLocalRef(jkey);
+    if (jtensor != nullptr) {
+      env->DeleteLocalRef(jtensor);
+    }
+  }
 
-  static void registerNatives() {
-    registerHybrid({
-        makeNativeMethod("initHybrid", SGDHybrid::initHybrid),
-        makeNativeMethod("stepNative", SGDHybrid::step),
-    });
-  }
+  return hashMap;
+}
+
+JNIEXPORT jobject JNICALL
+Java_org_pytorch_executorch_training_TrainingModule_nativeNamedGradients(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jlong nativeHandle,
+    jstring methodName) {
+  auto* native =
+      reinterpret_cast<executorch::extension::TrainingModuleNative*>(
+          nativeHandle);
+  if (native == nullptr) {
+    throwJavaException(env, "Native handle is null");
+    return nullptr;
+  }
+
+  g_training_cache.init(env);
+
+  std::string method = jstring_to_string(env, methodName);
+  auto result = native->module_->named_gradients(method);
+  if (!result.ok()) {
+    std::stringstream ss;
+    ss << "Getting named gradients for method " << method
+       << " failed with status 0x" << std::hex
+       << static_cast<uint32_t>(result.error());
+    throwJavaException(env, ss.str().c_str());
+    return nullptr;
+  }
+
+  // Create a new HashMap
+  jobject hashMap = env->NewObject(
+      g_training_cache.hashmap_class, g_training_cache.hashmap_init);
+
+  for (auto& [layer, tensor] : result.get()) {
+    jstring jkey = env->NewStringUTF(std::string(layer).c_str());
+    jobject jtensor = newJTensorFromTensor(env, tensor);
+    env->CallObjectMethod(
+        hashMap, g_training_cache.hashmap_put, jkey, jtensor);
+    env->DeleteLocalRef(jkey);
+    if (jtensor != nullptr) {
+      env->DeleteLocalRef(jtensor);
+    }
+  }
 
- private:
-  friend HybridBase;
-  std::unique_ptr<optimizer::SGD> sgdOptimizer_;
-  std::vector<std::string>
-      parameterNames_; // Store parameter names to keep string_view valid
-  std::vector<TensorPtr>
-      paramTensorPtrs_; // Store parameter tensors to keep TensorPtrs valid.
-};
+  return hashMap;
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_pytorch_executorch_training_SGD_nativeCreate(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jobject namedParameters,
+    jdouble learningRate,
+    jdouble momentum,
+    jdouble dampening,
+    jdouble weightDecay,
+    jboolean nesterov) {
+  auto* native = new executorch::extension::SGDNative(
+      env,
+      namedParameters,
+      learningRate,
+      momentum,
+      dampening,
+      weightDecay,
+      nesterov);
+  return reinterpret_cast<jlong>(native);
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_training_SGD_nativeDestroy(
+    JNIEnv* /* env */,
+    jclass /* clazz */,
+    jlong nativeHandle) {
+  if (nativeHandle != 0) {
+    auto* native =
+        reinterpret_cast<executorch::extension::SGDNative*>(nativeHandle);
+    delete native;
+  }
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_training_SGD_nativeStep(
+    JNIEnv* env,
+    jclass /* clazz */,
+    jlong nativeHandle,
+    jobject namedGradients) {
+  auto* native =
+      reinterpret_cast<executorch::extension::SGDNative*>(nativeHandle);
+  if (native == nullptr) {
+    throwJavaException(env, "Native handle is null");
+    return;
+  }
 
-} // namespace executorch::extension
+  g_training_cache.init(env);
+
+  std::map<std::string_view, executorch::aten::Tensor> cppNamedGradients;
+  std::vector<std::string> gradientNames;
+  std::vector<TensorPtr> tensorKeepalives;
+
+  // Get the size of the map
+  jint mapSize =
+      env->CallIntMethod(namedGradients, g_training_cache.map_size);
+
+  gradientNames.reserve(mapSize);
+  tensorKeepalives.reserve(mapSize);
+
+  // Get entry set and iterate
+  jobject entrySet =
+      env->CallObjectMethod(namedGradients, g_training_cache.map_entrySet);
+  jobject iterator =
+      env->CallObjectMethod(entrySet, g_training_cache.set_iterator);
+
+  while (env->CallBooleanMethod(iterator, g_training_cache.iterator_hasNext)) {
+    jobject entry =
+        env->CallObjectMethod(iterator, g_training_cache.iterator_next);
+    jstring key = static_cast<jstring>(
+        env->CallObjectMethod(entry, g_training_cache.entry_getKey));
+    jobject value =
+        env->CallObjectMethod(entry, g_training_cache.entry_getValue);
+
+    std::string gradName = jstring_to_string(env, key);
+    TensorPtr tensor = newTensorFromJTensor(env, value);
+
+    // Store the gradient name and tensor
+    gradientNames.push_back(gradName);
+    tensorKeepalives.push_back(tensor);
+    cppNamedGradients.emplace(
+        std::string_view(gradientNames.back()), *tensor);
+
+    env->DeleteLocalRef(key);
+    env->DeleteLocalRef(value);
+    env->DeleteLocalRef(entry);
+  }
+
+  env->DeleteLocalRef(iterator);
+  env->DeleteLocalRef(entrySet);
+
+  auto result = native->sgdOptimizer_->step(cppNamedGradients);
+  if (result != ::executorch::runtime::Error::Ok) {
+    std::stringstream ss;
+    ss << "SGD optimization step failed with status 0x" << std::hex
+       << static_cast<uint32_t>(result);
+    throwJavaException(env, ss.str().c_str());
+  }
+}
+
+} // extern "C"
 
 // Function to register training module natives
-void register_natives_for_training() {
-  executorch::extension::ExecuTorchTrainingJni::registerNatives();
-  executorch::extension::SGDHybrid::registerNatives();
-};
+void register_natives_for_training(JNIEnv* env) {
+  // Register TrainingModule natives
+  jclass training_module_class =
+      env->FindClass("org/pytorch/executorch/training/TrainingModule");
+  if (training_module_class == nullptr) {
+    ET_LOG(Error, "Failed to find TrainingModule class");
+    env->ExceptionClear();
+    return;
+  }
+
+  // clang-format off
+  static const JNINativeMethod training_methods[] = {
+      {"nativeCreate", "(Ljava/lang/String;Ljava/lang/String;)J",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_training_TrainingModule_nativeCreate)},
+      {"nativeDestroy", "(J)V",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_training_TrainingModule_nativeDestroy)},
+      {"nativeExecuteForwardBackward",
+       "(JLjava/lang/String;[Lorg/pytorch/executorch/EValue;)[Lorg/pytorch/executorch/EValue;",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_training_TrainingModule_nativeExecuteForwardBackward)},
+      {"nativeNamedParameters", "(JLjava/lang/String;)Ljava/util/Map;",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_training_TrainingModule_nativeNamedParameters)},
+      {"nativeNamedGradients", "(JLjava/lang/String;)Ljava/util/Map;",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_training_TrainingModule_nativeNamedGradients)},
+  };
+  // clang-format on
+
+  int result = env->RegisterNatives(
+      training_module_class,
+      training_methods,
+      sizeof(training_methods) / sizeof(training_methods[0]));
+  if (result != JNI_OK) {
+    ET_LOG(Error, "Failed to register native methods for TrainingModule");
+  }
+
+  env->DeleteLocalRef(training_module_class);
+
+  // Register SGD natives
+  jclass sgd_class =
env->FindClass("org/pytorch/executorch/training/SGD"); + if (sgd_class == nullptr) { + ET_LOG(Error, "Failed to find SGD class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod sgd_methods[] = { + {"nativeCreate", "(Ljava/util/Map;DDDDZ)J", + reinterpret_cast( + Java_org_pytorch_executorch_training_SGD_nativeCreate)}, + {"nativeDestroy", "(J)V", + reinterpret_cast( + Java_org_pytorch_executorch_training_SGD_nativeDestroy)}, + {"nativeStep", "(JLjava/util/Map;)V", + reinterpret_cast( + Java_org_pytorch_executorch_training_SGD_nativeStep)}, + }; + // clang-format on + + result = env->RegisterNatives( + sgd_class, sgd_methods, sizeof(sgd_methods) / sizeof(sgd_methods[0])); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for SGD"); + } + + env->DeleteLocalRef(sgd_class); +}