diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
index ab3b77ff1fb..e0122e3979e 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
@@ -8,7 +8,6 @@
package org.pytorch.executorch;
-import com.facebook.jni.annotations.DoNotStrip;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
@@ -33,7 +32,6 @@
*
Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
-@DoNotStrip
public class EValue {
private static final int TYPE_CODE_NONE = 0;
@@ -47,52 +45,50 @@ public class EValue {
"None", "Tensor", "String", "Double", "Int", "Bool",
};
- @DoNotStrip private final int mTypeCode;
- @DoNotStrip private Object mData;
+  // Both fields are looked up by name from the JNI layer. JNI field access does not honor
+  // Java access modifiers, so they stay private as they were before this change.
+  // NOTE(review): with @DoNotStrip removed, a keep rule is presumably required so that
+  // R8/ProGuard does not rename or strip these members - confirm.
+  private final int mTypeCode;
+  private Object mData;
- @DoNotStrip
private EValue(int typeCode) {
this.mTypeCode = typeCode;
}
- @DoNotStrip
public boolean isNone() {
return TYPE_CODE_NONE == this.mTypeCode;
}
- @DoNotStrip
+  /** Returns {@code true} if this {@code EValue} has the {@code Tensor} type code. */
public boolean isTensor() {
return TYPE_CODE_TENSOR == this.mTypeCode;
}
- @DoNotStrip
+  /** Returns {@code true} if this {@code EValue} has the {@code Bool} type code. */
public boolean isBool() {
return TYPE_CODE_BOOL == this.mTypeCode;
}
- @DoNotStrip
+  /** Returns {@code true} if this {@code EValue} has the {@code Int} type code. */
public boolean isInt() {
return TYPE_CODE_INT == this.mTypeCode;
}
- @DoNotStrip
+  /** Returns {@code true} if this {@code EValue} has the {@code Double} type code. */
public boolean isDouble() {
return TYPE_CODE_DOUBLE == this.mTypeCode;
}
- @DoNotStrip
+  /** Returns {@code true} if this {@code EValue} has the {@code String} type code. */
public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}
/** Creates a new {@code EValue} of type {@code Optional} that contains no value. */
- @DoNotStrip
+
public static EValue optionalNone() {
return new EValue(TYPE_CODE_NONE);
}
/** Creates a new {@code EValue} of type {@code Tensor}. */
- @DoNotStrip
+
public static EValue from(Tensor tensor) {
final EValue iv = new EValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
@@ -100,7 +96,7 @@ public static EValue from(Tensor tensor) {
}
/** Creates a new {@code EValue} of type {@code bool}. */
- @DoNotStrip
+
public static EValue from(boolean value) {
final EValue iv = new EValue(TYPE_CODE_BOOL);
iv.mData = value;
@@ -108,7 +104,7 @@ public static EValue from(boolean value) {
}
/** Creates a new {@code EValue} of type {@code int}. */
- @DoNotStrip
+
public static EValue from(long value) {
final EValue iv = new EValue(TYPE_CODE_INT);
iv.mData = value;
@@ -116,7 +112,7 @@ public static EValue from(long value) {
}
/** Creates a new {@code EValue} of type {@code double}. */
- @DoNotStrip
+
public static EValue from(double value) {
final EValue iv = new EValue(TYPE_CODE_DOUBLE);
iv.mData = value;
@@ -124,38 +120,38 @@ public static EValue from(double value) {
}
/** Creates a new {@code EValue} of type {@code str}. */
- @DoNotStrip
+
public static EValue from(String value) {
final EValue iv = new EValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
}
- @DoNotStrip
+  /** Returns the held {@code Tensor}; the type code is first checked by {@code preconditionType}. */
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}
- @DoNotStrip
+  /** Returns the held {@code boolean}; the type code is first checked by {@code preconditionType}. */
public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}
- @DoNotStrip
+  /** Returns the held {@code long}; the type code is first checked by {@code preconditionType}. */
public long toInt() {
preconditionType(TYPE_CODE_INT, mTypeCode);
return (long) mData;
}
- @DoNotStrip
+  /** Returns the held {@code double}; the type code is first checked by {@code preconditionType}. */
public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}
- @DoNotStrip
+
public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
index 8e2f259ef3a..dfa9f77b6dd 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
@@ -8,7 +8,6 @@
package org.pytorch.executorch;
-import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
@@ -33,10 +32,16 @@ public static ExecuTorchRuntime getRuntime() {
}
/** Get all registered ops. */
- @DoNotStrip
- public static native String[] getRegisteredOps();
+ public static String[] getRegisteredOps() {
+ return nativeGetRegisteredOps();
+ }
+
+ private static native String[] nativeGetRegisteredOps();
/** Get all registered backends. */
- @DoNotStrip
- public static native String[] getRegisteredBackends();
+ public static String[] getRegisteredBackends() {
+ return nativeGetRegisteredBackends();
+ }
+
+ private static native String[] nativeGetRegisteredBackends();
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
index 6da76bf4b74..481165f4e21 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
@@ -9,8 +9,6 @@
package org.pytorch.executorch;
import android.util.Log;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
@@ -48,18 +46,18 @@ public class Module {
/** Load mode for the module. Use memory locking and ignore errors. */
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
- private final HybridData mHybridData;
+ private long mNativeHandle;
private final Map mMethodMetadata;
- @DoNotStrip
- private static native HybridData initHybrid(
- String moduleAbsolutePath, int loadMode, int initHybrid);
+ private static native long nativeCreate(String moduleAbsolutePath, int loadMode, int numThreads);
+
+ private static native void nativeDestroy(long nativeHandle);
private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
- mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
+ mNativeHandle = nativeCreate(moduleAbsolutePath, loadMode, numThreads);
mMethodMetadata = populateMethodMeta();
}
@@ -75,7 +73,7 @@ Map populateMethodMeta() {
return metadata;
}
- /** Lock protecting the non-thread safe methods in mHybridData. */
+ /** Lock protecting the non-thread safe methods in native handle. */
private Lock mLock = new ReentrantLock();
/**
@@ -138,18 +136,18 @@ public EValue[] forward(EValue... inputs) {
public EValue[] execute(String methodName, EValue... inputs) {
try {
mLock.lock();
- if (!mHybridData.isValid()) {
+ if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
- return executeNative(methodName, inputs);
+ return nativeExecute(mNativeHandle, methodName, inputs);
} finally {
mLock.unlock();
}
}
- @DoNotStrip
- private native EValue[] executeNative(String methodName, EValue... inputs);
+ private static native EValue[] nativeExecute(
+ long nativeHandle, String methodName, EValue... inputs);
/**
* Load a method on this module. This might help with the first time inference performance,
@@ -163,18 +161,17 @@ public EValue[] execute(String methodName, EValue... inputs) {
public int loadMethod(String methodName) {
try {
mLock.lock();
- if (!mHybridData.isValid()) {
+ if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return 0x2; // InvalidState
}
- return loadMethodNative(methodName);
+ return nativeLoadMethod(mNativeHandle, methodName);
} finally {
mLock.unlock();
}
}
- @DoNotStrip
- private native int loadMethodNative(String methodName);
+ private static native int nativeLoadMethod(long nativeHandle, String methodName);
/**
* Returns the names of the backends in a certain method.
@@ -182,16 +179,22 @@ public int loadMethod(String methodName) {
* @param methodName method name to query
* @return an array of backend name
*/
- @DoNotStrip
- private native String[] getUsedBackends(String methodName);
+  public String[] getUsedBackends(String methodName) {
+    // Same discipline as execute()/loadMethod(): hold the lock and never pass a handle that
+    // destroy() has already released (and zeroed) to native code.
+    mLock.lock();
+    try {
+      if (mNativeHandle == 0) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return new String[0];
+      }
+      return nativeGetUsedBackends(mNativeHandle, methodName);
+    } finally {
+      mLock.unlock();
+    }
+  }
+
+  private static native String[] nativeGetUsedBackends(long nativeHandle, String methodName);
/**
* Returns the names of methods.
*
* @return name of methods in this Module
*/
- @DoNotStrip
- public native String[] getMethods();
+  public String[] getMethods() {
+    // Same discipline as execute()/loadMethod(): hold the lock and never pass a handle that
+    // destroy() has already released (and zeroed) to native code.
+    mLock.lock();
+    try {
+      if (mNativeHandle == 0) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return new String[0];
+      }
+      return nativeGetMethods(mNativeHandle);
+    } finally {
+      mLock.unlock();
+    }
+  }
+
+  private static native String[] nativeGetMethods(long nativeHandle);
/**
* Get the corresponding @MethodMetadata for a method
@@ -211,20 +214,18 @@ public MethodMetadata getMethodMetadata(String name) {
return methodMetadata;
}
- @DoNotStrip
- private static native String[] readLogBufferStaticNative();
+ private static native String[] nativeReadLogBufferStatic();
public static String[] readLogBufferStatic() {
- return readLogBufferStaticNative();
+ return nativeReadLogBufferStatic();
}
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
- return readLogBufferNative();
+    // Guard against a handle that destroy() has already released; an empty array is returned
+    // for a destroyed module instead of dereferencing freed native memory.
+    mLock.lock();
+    try {
+      if (mNativeHandle == 0) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return new String[0];
+      }
+      return nativeReadLogBuffer(mNativeHandle);
+    } finally {
+      mLock.unlock();
+    }
}
- @DoNotStrip
- private native String[] readLogBufferNative();
+  private static native String[] nativeReadLogBuffer(long nativeHandle);
/**
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
@@ -234,19 +235,25 @@ public String[] readLogBuffer() {
* @return true if the etdump was successfully written, false otherwise.
*/
@Experimental
- @DoNotStrip
- public native boolean etdump();
+  public boolean etdump() {
+    // A destroyed module cannot write an etdump; report failure rather than handing a
+    // released handle to native code.
+    mLock.lock();
+    try {
+      if (mNativeHandle == 0) {
+        Log.e("ExecuTorch", "Attempt to use a destroyed module");
+        return false;
+      }
+      return nativeEtdump(mNativeHandle);
+    } finally {
+      mLock.unlock();
+    }
+  }
+
+  private static native boolean nativeEtdump(long nativeHandle);
/**
* Explicitly destroys the native Module object. Calling this method is not required, as the
* native object will be destroyed when this object is garbage-collected. However, the timing of
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
- * more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
+ * more quickly.
*/
public void destroy() {
if (mLock.tryLock()) {
try {
- mHybridData.resetNative();
+ if (mNativeHandle != 0) {
+ nativeDestroy(mNativeHandle);
+ mNativeHandle = 0;
+ }
} finally {
mLock.unlock();
}
@@ -257,4 +264,13 @@ public void destroy() {
+ " released.");
}
}
+
+  /**
+   * Releases the native module if {@link #destroy()} was never called (or could not acquire the
+   * lock). The handle is read under {@code mLock} so that a value zeroed by {@code destroy()} on
+   * another thread is visible here; an unsynchronized read could observe a stale handle and free
+   * the native object twice.
+   */
+  @Override
+  protected void finalize() throws Throwable {
+    mLock.lock();
+    try {
+      if (mNativeHandle != 0) {
+        nativeDestroy(mNativeHandle);
+        mNativeHandle = 0;
+      }
+    } finally {
+      mLock.unlock();
+      // Always chain to the superclass finalizer, even if native cleanup throws.
+      super.finalize();
+    }
+  }
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
index e8c0a918b13..a103e3691c2 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
@@ -9,8 +9,6 @@
package org.pytorch.executorch;
import android.util.Log;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -53,7 +51,7 @@ public abstract class Tensor {
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
- @DoNotStrip final long[] shape;
+ final long[] shape;
private static final int BYTE_SIZE_BYTES = 1;
private static final int INT_SIZE_BYTES = 4;
@@ -468,7 +466,8 @@ public static Tensor zeros(long[] shape, DType dtype) {
}
}
- @DoNotStrip private HybridData mHybridData;
+ // Native handle for tensor data (unused in pure JNI but kept for API compatibility)
+ private long mNativeHandle;
private Tensor(long[] shape) {
checkShape(shape);
@@ -501,7 +500,6 @@ public long[] shape() {
public abstract DType dtype();
// Called from native
- @DoNotStrip
int dtypeJniCode() {
return dtype().jniCode;
}
@@ -572,7 +570,6 @@ public double[] getDataAsDoubleArray() {
"Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
}
- @DoNotStrip
Buffer getRawDataBuffer() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
@@ -889,9 +886,8 @@ private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[
// endregion checks
// Called from native
- @DoNotStrip
private static Tensor nativeNewTensor(
- ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
+ ByteBuffer data, long[] shape, int dtype, long nativeHandle) {
Tensor tensor = null;
if (DType.FLOAT.jniCode == dtype) {
@@ -911,7 +907,7 @@ private static Tensor nativeNewTensor(
} else {
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
}
- tensor.mHybridData = hybridData;
+ tensor.mNativeHandle = nativeHandle;
return tensor;
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 5e080e0c369..54494979766 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -8,8 +8,6 @@
package org.pytorch.executorch.extension.llm;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
import java.io.File;
import java.util.List;
import org.pytorch.executorch.ExecuTorchRuntime;
@@ -28,18 +26,19 @@ public class LlmModule {
public static final int MODEL_TYPE_TEXT_VISION = 2;
public static final int MODEL_TYPE_MULTIMODAL = 2;
- private final HybridData mHybridData;
+ private long mNativeHandle;
private static final int DEFAULT_SEQ_LEN = 128;
private static final boolean DEFAULT_ECHO = true;
- @DoNotStrip
- private static native HybridData initHybrid(
+ private static native long nativeCreate(
int modelType,
String modulePath,
String tokenizerPath,
float temperature,
List dataFiles);
+ private static native void nativeDestroy(long nativeHandle);
+
/**
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
* dataFiles.
@@ -61,7 +60,7 @@ public LlmModule(
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
}
- mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles);
+ mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles);
}
/**
@@ -107,7 +106,16 @@ public LlmModule(LlmModuleConfig config) {
}
public void resetNative() {
- mHybridData.resetNative();
+ if (mNativeHandle != 0) {
+ nativeDestroy(mNativeHandle);
+ mNativeHandle = 0;
+ }
+ }
+
+  /** Releases the native runner if {@link #resetNative()} was never called explicitly. */
+  @Override
+  protected void finalize() throws Throwable {
+    try {
+      resetNative();
+    } finally {
+      // Always chain to the superclass finalizer, even if native cleanup throws.
+      super.finalize();
+    }
+  }
/**
@@ -150,7 +158,12 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
* @param llmCallback callback object to receive results
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
- public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
+  public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo);
+  }
+
+  private static native int nativeGenerate(
+      long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
/**
* Start generating tokens from the module.
@@ -206,14 +219,15 @@ public int generate(
*/
@Experimental
public long prefillImages(int[] image, int width, int height, int channels) {
- int nativeResult = appendImagesInput(image, width, height, channels);
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}
- private native int appendImagesInput(int[] image, int width, int height, int channels);
+  private static native int nativeAppendImagesInput(
+      long nativeHandle, int[] image, int width, int height, int channels);
/**
* Prefill a multimodal Module with the given images input.
@@ -228,15 +242,16 @@ public long prefillImages(int[] image, int width, int height, int channels) {
*/
@Experimental
public long prefillImages(float[] image, int width, int height, int channels) {
- int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    int nativeResult =
+        nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}
- private native int appendNormalizedImagesInput(
- float[] image, int width, int height, int channels);
+  private static native int nativeAppendNormalizedImagesInput(
+      long nativeHandle, float[] image, int width, int height, int channels);
/**
* Prefill a multimodal Module with the given audio input.
@@ -251,14 +266,15 @@ private native int appendNormalizedImagesInput(
*/
@Experimental
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
- int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}
- private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
+  private static native int nativeAppendAudioInput(
+      long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames);
/**
* Prefill a multimodal Module with the given audio input.
@@ -273,14 +289,16 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
*/
@Experimental
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
- int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    int nativeResult =
+        nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}
- private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
+  private static native int nativeAppendAudioInputFloat(
+      long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames);
/**
* Prefill a multimodal Module with the given raw audio input.
@@ -295,15 +313,16 @@ public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames
*/
@Experimental
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
- int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    int nativeResult =
+        nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}
- private native int appendRawAudioInput(
- byte[] audio, int batch_size, int n_channels, int n_samples);
+  private static native int nativeAppendRawAudioInput(
+      long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples);
/**
* Prefill a multimodal Module with the given text input.
@@ -315,7 +334,7 @@ private native int appendRawAudioInput(
*/
@Experimental
public long prefillPrompt(String prompt) {
- int nativeResult = appendTextInput(prompt);
+ int nativeResult = nativeAppendTextInput(mNativeHandle, prompt);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
@@ -323,20 +342,30 @@ public long prefillPrompt(String prompt) {
}
// returns status
- private native int appendTextInput(String prompt);
+ private static native int nativeAppendTextInput(long nativeHandle, String prompt);
/**
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
*
* The startPos will be reset to 0.
*/
- public native void resetContext();
+  public void resetContext() {
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    nativeResetContext(mNativeHandle);
+  }
+
+  private static native void nativeResetContext(long nativeHandle);
/** Stop current generate() before it finishes. */
- @DoNotStrip
- public native void stop();
+  public void stop() {
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    nativeStop(mNativeHandle);
+  }
+
+  private static native void nativeStop(long nativeHandle);
/** Force loading the module. Otherwise the model is loaded during first generate(). */
- @DoNotStrip
- public native int load();
+  public int load() {
+    // resetNative() zeroes the handle; fail fast in Java instead of passing 0 to native code.
+    if (mNativeHandle == 0) {
+      throw new IllegalStateException("Attempt to use a destroyed LlmModule");
+    }
+    return nativeLoad(mNativeHandle);
+  }
+
+  private static native int nativeLoad(long nativeHandle);
}
diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp
index 6491524c7ac..37f9b271e52 100644
--- a/extension/android/jni/jni_helper.cpp
+++ b/extension/android/jni/jni_helper.cpp
@@ -10,6 +10,60 @@
namespace executorch::jni_helper {
+// Pure-JNI variant: raises ExecutorchRuntimeException on `env` using the Java
+// factory ExecutorchRuntimeException.makeExecutorchException(int, String).
+// If that exception cannot be built (class or factory lookup fails, the
+// details string cannot be allocated, or the factory returns null), a
+// java.lang.RuntimeException carrying the error code and details is thrown
+// instead, so a failure is never silently dropped. Does nothing when `env` is
+// null.
+void throwExecutorchException(
+    JNIEnv* env,
+    uint32_t errorCode,
+    const std::string& details) {
+  if (!env) {
+    return;
+  }
+
+  // Fallback used when the ExecuTorch exception cannot be produced: discard
+  // any lookup failure and raise a generic RuntimeException in its place.
+  const auto throwFallback = [env, errorCode, &details]() {
+    env->ExceptionClear();
+    jclass runtimeExceptionClass = env->FindClass("java/lang/RuntimeException");
+    if (runtimeExceptionClass == nullptr) {
+      return; // FindClass left its own exception pending; let that propagate.
+    }
+    const std::string message =
+        "ExecuTorch error " + std::to_string(errorCode) + ": " + details;
+    env->ThrowNew(runtimeExceptionClass, message.c_str());
+    env->DeleteLocalRef(runtimeExceptionClass);
+  };
+
+  // Find the exception class
+  jclass exceptionClass =
+      env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException");
+  if (exceptionClass == nullptr) {
+    throwFallback();
+    return;
+  }
+
+  // Find the static factory method: makeExecutorchException(int, String)
+  jmethodID makeExceptionMethod = env->GetStaticMethodID(
+      exceptionClass,
+      "makeExecutorchException",
+      "(ILjava/lang/String;)Ljava/lang/RuntimeException;");
+  if (makeExceptionMethod == nullptr) {
+    env->DeleteLocalRef(exceptionClass);
+    throwFallback();
+    return;
+  }
+
+  // Create the details string
+  jstring jDetails = env->NewStringUTF(details.c_str());
+  if (jDetails == nullptr) {
+    env->DeleteLocalRef(exceptionClass);
+    throwFallback();
+    return;
+  }
+
+  // Call the factory method to create the exception object. If the factory
+  // itself throws, `exception` is null and the factory's exception stays
+  // pending, which still surfaces the failure to Java.
+  jobject exception = env->CallStaticObjectMethod(
+      exceptionClass,
+      makeExceptionMethod,
+      static_cast<jint>(errorCode),
+      jDetails);
+
+  env->DeleteLocalRef(jDetails);
+
+  if (exception != nullptr) {
+    env->Throw(static_cast<jthrowable>(exception));
+    env->DeleteLocalRef(exception);
+  } else if (!env->ExceptionCheck()) {
+    // The factory returned null without throwing; still report the error.
+    throwFallback();
+  }
+
+  env->DeleteLocalRef(exceptionClass);
+}
+
+#if EXECUTORCH_HAS_FBJNI
void throwExecutorchException(uint32_t errorCode, const std::string& details) {
// Get the current JNI environment
auto env = facebook::jni::Environment::current();
@@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) {
auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
facebook::jni::throwNewJavaException(exception.get());
}
+#endif
} // namespace executorch::jni_helper
diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h
index 898c1619d9c..683a3cfe447 100644
--- a/extension/android/jni/jni_helper.h
+++ b/extension/android/jni/jni_helper.h
@@ -8,9 +8,16 @@
#pragma once
-#include <fbjni/fbjni.h>
+#include <jni.h>
+#include <cstdint>
#include <string>
+// fbjni is optional: the fbjni-based helper declarations below are compiled
+// only when its header can be found on the include path.
+#if __has_include(<fbjni/fbjni.h>)
+#include <fbjni/fbjni.h>
+#define EXECUTORCH_HAS_FBJNI 1
+#else
+#define EXECUTORCH_HAS_FBJNI 0
+#endif
+
namespace executorch::jni_helper {
/**
@@ -18,6 +25,25 @@ namespace executorch::jni_helper {
* code and details. Uses the Java factory method
* ExecutorchRuntimeException.makeExecutorchException(int, String).
*
+ * This version takes JNIEnv* directly and works with pure JNI.
+ *
+ * @param env The JNI environment.
+ * @param errorCode The error code from the C++ Executorch runtime.
+ * @param details Additional details to include in the exception message.
+ */
+void throwExecutorchException(
+ JNIEnv* env,
+ uint32_t errorCode,
+ const std::string& details);
+
+#if EXECUTORCH_HAS_FBJNI
+/**
+ * Throws a Java ExecutorchRuntimeException corresponding to the given error
+ * code and details. Uses the Java factory method
+ * ExecutorchRuntimeException.makeExecutorchException(int, String).
+ *
+ * This version uses fbjni to get the current JNI environment.
+ *
* @param errorCode The error code from the C++ Executorch runtime.
* @param details Additional details to include in the exception message.
*/
@@ -29,5 +55,6 @@ struct JExecutorchRuntimeException
static constexpr auto kJavaDescriptor =
"Lorg/pytorch/executorch/ExecutorchRuntimeException;";
};
+#endif
} // namespace executorch::jni_helper
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index 1f8457e00c5..636e4de0b36 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/
+#include
+
#include
#include
@@ -39,223 +41,120 @@
#include
#endif
-#include
-#include
-
using namespace executorch::extension;
using namespace torch::executor;
-namespace executorch::extension {
-class TensorHybrid : public facebook::jni::HybridClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/Tensor;";
-
- explicit TensorHybrid(executorch::aten::Tensor tensor) {}
-
- static facebook::jni::local_ref
- newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
- // Java wrapper currently only supports contiguous tensors.
+// Copies the contents of a Java string into a std::string. Declared at file
+// scope (outside the namespaces below) so every part of this file can use it.
+// A null reference, or a failure to obtain the characters, yields "".
+static std::string jstring_to_string(JNIEnv* env, jstring jstr) {
+  std::string converted;
+  if (jstr != nullptr) {
+    const char* utf = env->GetStringUTFChars(jstr, nullptr);
+    if (utf != nullptr) {
+      converted.assign(utf);
+      // The JVM-owned character buffer must be released once it has been copied.
+      env->ReleaseStringUTFChars(jstr, utf);
+    }
+  }
+  return converted;
+}
- const auto scalarType = tensor.scalar_type();
- int jdtype = scalar_type_to_java_dtype.at(scalarType);
- if (scalar_type_to_java_dtype.count(scalarType) == 0) {
- std::stringstream ss;
- ss << "executorch::aten::Tensor scalar [java] type: " << jdtype
- << " is not supported on java side";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
+namespace {
+
+// Global JavaVM pointer for obtaining JNIEnv in callbacks
+JavaVM* g_jvm = nullptr;
+
+// EValue type codes (must match Java EValue class)
+constexpr int kTypeCodeNone = 0;
+constexpr int kTypeCodeTensor = 1;
+constexpr int kTypeCodeString = 2;
+constexpr int kTypeCodeDouble = 3;
+constexpr int kTypeCodeInt = 4;
+constexpr int kTypeCodeBool = 5;
+
+// Cached class and method IDs for performance
+struct JniCache {
+ jclass tensor_class = nullptr;
+ jclass evalue_class = nullptr;
+ jmethodID tensor_nativeNewTensor = nullptr;
+ jmethodID tensor_dtypeJniCode = nullptr;
+ jmethodID tensor_getRawDataBuffer = nullptr;
+ jfieldID tensor_shape = nullptr;
+ jmethodID evalue_from_tensor = nullptr;
+ jmethodID evalue_from_long = nullptr;
+ jmethodID evalue_from_double = nullptr;
+ jmethodID evalue_from_bool = nullptr;
+ jmethodID evalue_from_string = nullptr;
+ jmethodID evalue_toTensor = nullptr;
+ jfieldID evalue_mTypeCode = nullptr;
+ jfieldID evalue_mData = nullptr;
+
+ bool initialized = false;
+
+ void init(JNIEnv* env) {
+ if (initialized) {
+ return;
}
- const auto& tensor_shape = tensor.sizes();
- std::vector tensor_shape_vec;
- for (const auto& s : tensor_shape) {
- tensor_shape_vec.push_back(s);
- }
- facebook::jni::local_ref jTensorShape =
- facebook::jni::make_long_array(tensor_shape_vec.size());
- jTensorShape->setRegion(
- 0, tensor_shape_vec.size(), tensor_shape_vec.data());
-
- static auto cls = TensorHybrid::javaClassStatic();
- // Note: this is safe as long as the data stored in tensor is valid; the
- // data won't go out of scope as long as the Method for the inference is
- // valid and there is no other inference call. Java layer picks up this
- // value immediately so the data is valid.
- facebook::jni::local_ref jTensorBuffer =
- facebook::jni::JByteBuffer::wrapBytes(
- (uint8_t*)tensor.data_ptr(), tensor.nbytes());
- jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
-
- static const auto jMethodNewTensor =
- cls->getStaticMethod(
- facebook::jni::alias_ref,
- facebook::jni::alias_ref,
- jint,
- facebook::jni::alias_ref)>("nativeNewTensor");
- return jMethodNewTensor(
- cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
- }
-
- static TensorPtr newTensorFromJTensor(
- facebook::jni::alias_ref jtensor) {
- static auto cls = TensorHybrid::javaClassStatic();
- static const auto dtypeMethod = cls->getMethod("dtypeJniCode");
- jint jdtype = dtypeMethod(jtensor);
-
- static const auto shapeField = cls->getField("shape");
- auto jshape = jtensor->getFieldValue(shapeField);
-
- static auto dataBufferMethod = cls->getMethod<
- facebook::jni::local_ref()>(
- "getRawDataBuffer");
- facebook::jni::local_ref jbuffer =
- dataBufferMethod(jtensor);
-
- const auto rank = jshape->size();
-
- const auto shapeArr = jshape->getRegion(0, rank);
- std::vector shape_vec;
- shape_vec.reserve(rank);
-
- int64_t numel = 1;
- for (int i = 0; i < rank; i++) {
- shape_vec.push_back(shapeArr[i]);
+ // Cache Tensor class and methods
+ jclass local_tensor_class = env->FindClass("org/pytorch/executorch/Tensor");
+ if (local_tensor_class != nullptr) {
+ tensor_class = static_cast(env->NewGlobalRef(local_tensor_class));
+ env->DeleteLocalRef(local_tensor_class);
+
+ tensor_nativeNewTensor = env->GetStaticMethodID(
+ tensor_class,
+ "nativeNewTensor",
+ "(Ljava/nio/ByteBuffer;[JIJ)Lorg/pytorch/executorch/Tensor;");
+ tensor_dtypeJniCode = env->GetMethodID(tensor_class, "dtypeJniCode", "()I");
+ tensor_getRawDataBuffer =
+ env->GetMethodID(tensor_class, "getRawDataBuffer", "()Ljava/nio/Buffer;");
+ tensor_shape = env->GetFieldID(tensor_class, "shape", "[J");
}
- for (int i = rank - 1; i >= 0; --i) {
- numel *= shapeArr[i];
- }
- JNIEnv* jni = facebook::jni::Environment::current();
- if (java_dtype_to_scalar_type.count(jdtype) == 0) {
- std::stringstream ss;
- ss << "Unknown Tensor jdtype: [" << jdtype << "]";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
- const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
- if (dataCapacity < 0) {
- std::stringstream ss;
- ss << "Tensor buffer is not direct or has invalid capacity";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- const size_t elementSize = executorch::runtime::elementSize(scalar_type);
- const jlong expectedElements = static_cast(numel);
- const jlong expectedBytes =
- expectedElements * static_cast(elementSize);
- const bool matchesElements = dataCapacity == expectedElements;
- const bool matchesBytes = dataCapacity == expectedBytes;
- if (!matchesElements && !matchesBytes) {
- std::stringstream ss;
- ss << "Tensor dimensions(elements number: " << numel
- << ") inconsistent with buffer capacity " << dataCapacity
- << " (element size bytes: " << elementSize << ")";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- return from_blob(
- jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
- }
-
- private:
- friend HybridBase;
-};
-class JEValue : public facebook::jni::JavaClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/EValue;";
-
- constexpr static int kTypeCodeTensor = 1;
- constexpr static int kTypeCodeString = 2;
- constexpr static int kTypeCodeDouble = 3;
- constexpr static int kTypeCodeInt = 4;
- constexpr static int kTypeCodeBool = 5;
-
- static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) {
- if (evalue.isTensor()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(
- facebook::jni::local_ref)>("from");
- return jMethodTensor(
- JEValue::javaClassStatic(),
- TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
- } else if (evalue.isInt()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jlong)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
- } else if (evalue.isDouble()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jdouble)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
- } else if (evalue.isBool()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jboolean)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
- } else if (evalue.isString()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(
- facebook::jni::local_ref)>("from");
- std::string str =
- std::string(evalue.toString().begin(), evalue.toString().end());
- return jMethodTensor(
- JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
- }
- std::stringstream ss;
- ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- return {};
- }
-
- static TensorPtr JEValueToTensorImpl(
- facebook::jni::alias_ref JEValue) {
- static const auto typeCodeField =
- JEValue::javaClassStatic()->getField("mTypeCode");
- const auto typeCode = JEValue->getFieldValue(typeCodeField);
- if (JEValue::kTypeCodeTensor == typeCode) {
- static const auto jMethodGetTensor =
- JEValue::javaClassStatic()
- ->getMethod()>(
- "toTensor");
- auto jtensor = jMethodGetTensor(JEValue);
- return TensorHybrid::newTensorFromJTensor(jtensor);
+    // Cache EValue class and methods
+    jclass local_evalue_class = env->FindClass("org/pytorch/executorch/EValue");
+    if (local_evalue_class != nullptr) {
+      // NewGlobalRef returns jobject; cast back to jclass. The global reference
+      // keeps the class usable after this native call returns.
+      evalue_class =
+          static_cast<jclass>(env->NewGlobalRef(local_evalue_class));
+      env->DeleteLocalRef(local_evalue_class);
+
+      // One static factory overload of EValue.from(...) per supported payload.
+      evalue_from_tensor = env->GetStaticMethodID(
+          evalue_class,
+          "from",
+          "(Lorg/pytorch/executorch/Tensor;)Lorg/pytorch/executorch/EValue;");
+      evalue_from_long = env->GetStaticMethodID(
+          evalue_class, "from", "(J)Lorg/pytorch/executorch/EValue;");
+      evalue_from_double = env->GetStaticMethodID(
+          evalue_class, "from", "(D)Lorg/pytorch/executorch/EValue;");
+      evalue_from_bool = env->GetStaticMethodID(
+          evalue_class, "from", "(Z)Lorg/pytorch/executorch/EValue;");
+      evalue_from_string = env->GetStaticMethodID(
+          evalue_class,
+          "from",
+          "(Ljava/lang/String;)Lorg/pytorch/executorch/EValue;");
+      evalue_toTensor = env->GetMethodID(
+          evalue_class, "toTensor", "()Lorg/pytorch/executorch/Tensor;");
+      evalue_mTypeCode = env->GetFieldID(evalue_class, "mTypeCode", "I");
+      evalue_mData =
+          env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;");
+    }
- std::stringstream ss;
- ss << "Unknown EValue typeCode: " << typeCode;
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- return {};
+
+ initialized = true;
}
};
-class ExecuTorchJni : public facebook::jni::HybridClass {
- private:
- friend HybridBase;
- std::unique_ptr module_;
+JniCache g_jni_cache;
- public:
- constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
+} // anonymous namespace
- static facebook::jni::local_ref initHybrid(
- facebook::jni::alias_ref,
- facebook::jni::alias_ref modelPath,
- jint loadMode,
- jint numThreads) {
- return makeCxxInstance(modelPath, loadMode, numThreads);
- }
+namespace executorch::extension {
+
+// Native module handle class - named ExecuTorchJni to match friend declaration in Module
+class ExecuTorchJni {
+ public:
+ std::unique_ptr module_;
ExecuTorchJni(
- facebook::jni::alias_ref modelPath,
+ JNIEnv* env,
+ jstring modelPath,
jint loadMode,
jint numThreads) {
Module::LoadMode load_mode = Module::LoadMode::Mmap;
@@ -273,17 +172,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass {
#else
auto etdump_gen = nullptr;
#endif
- module_ = std::make_unique(
- modelPath->toStdString(), load_mode, std::move(etdump_gen));
+ std::string path = jstring_to_string(env, modelPath);
+ module_ = std::make_unique(path, load_mode, std::move(etdump_gen));
#ifdef ET_USE_THREADPOOL
- // Default to using cores/2 threadpool threads. The long-term plan is to
- // improve performant core detection in CPUInfo, but for now we can use
- // cores/2 as a sane default.
- //
- // Based on testing, this is almost universally faster than using all
- // cores, as efficiency cores can be quite slow. In extreme cases, using
- // all cores can be 10x slower than using cores/2.
auto threadpool = executorch::extension::threadpool::get_threadpool();
if (threadpool) {
int thread_count =
@@ -295,265 +187,607 @@ class ExecuTorchJni : public facebook::jni::HybridClass {
#endif
}
- facebook::jni::local_ref> execute(
- facebook::jni::alias_ref methodName,
- facebook::jni::alias_ref<
- facebook::jni::JArrayClass::javaobject>
- jinputs) {
- return execute_method(methodName->toStdString(), jinputs);
+ // Looks up method_name in module_->methods_ (reachable here through the friend declaration) and returns the raw Method pointer, or nullptr when no entry exists.
+ Method* get_method(const std::string& method_name) {
+ auto it = module_->methods_.find(method_name);
+ if (it != module_->methods_.end()) {
+ return it->second.method.get(); // non-owning pointer; the map entry keeps the Method
+ }
+ return nullptr;
}
+};
+
+} // namespace executorch::extension
+
+namespace {
+
+// Builds a Java Tensor for a native tensor: maps the scalar type to a Java dtype code, copies the sizes into a jlongArray, wraps the data pointer in a direct ByteBuffer and calls the cached nativeNewTensor factory. Returns nullptr when the dtype is unsupported or an allocation fails.
+jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor) {
+ g_jni_cache.init(env);
- jint load_method(facebook::jni::alias_ref methodName) {
- return static_cast(module_->load_method(methodName->toStdString()));
+ const auto scalarType = tensor.scalar_type();
+ if (scalar_type_to_java_dtype.count(scalarType) == 0) {
+ std::stringstream ss;
+ ss << "executorch::aten::Tensor scalar type is not supported on java side";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr; // a Java exception is pending at this point
}
+ int jdtype = scalar_type_to_java_dtype.at(scalarType);
- facebook::jni::local_ref> execute_method(
- std::string method,
- facebook::jni::alias_ref<
- facebook::jni::JArrayClass::javaobject>
- jinputs) {
- // If no inputs is given, it will run with sample inputs (ones)
- if (jinputs->size() == 0) {
- auto result = module_->load_method(method);
- if (result != Error::Ok) {
- // Format hex string
- std::stringstream ss;
- ss << "Cannot get method names [Native Error: 0x" << std::hex
- << std::uppercase << static_cast(result) << "]";
+ // Create the Java long[] that carries the tensor sizes
+ const auto& tensor_shape = tensor.sizes();
+ jlongArray jTensorShape = env->NewLongArray(tensor_shape.size());
+ if (jTensorShape == nullptr) {
+ return nullptr;
+ }
+ std::vector shape_vec; // staging copy of the sizes handed to SetLongArrayRegion below
+ for (const auto& s : tensor_shape) {
+ shape_vec.push_back(s);
+ }
+ env->SetLongArrayRegion(jTensorShape, 0, shape_vec.size(), shape_vec.data());
+
+ // Create a direct ByteBuffer over the tensor's data pointer (tensor.nbytes() bytes)
+ jobject jTensorBuffer = env->NewDirectByteBuffer(
+ const_cast(tensor.const_data_ptr()), tensor.nbytes());
+ if (jTensorBuffer == nullptr) {
+ env->DeleteLocalRef(jTensorShape);
+ return nullptr;
+ }
- jni_helper::throwExecutorchException(
- static_cast(result), ss.str());
- return {};
- }
- auto&& underlying_method = module_->methods_[method].method;
- auto&& buf = prepare_input_tensors(*underlying_method);
- result = underlying_method->execute();
- if (result != Error::Ok) {
- jni_helper::throwExecutorchException(
- static_cast(result),
- "Execution failed for method: " + method);
- return {};
- }
- facebook::jni::local_ref> jresult =
- facebook::jni::JArrayClass::newArray(
- underlying_method->outputs_size());
-
- for (int i = 0; i < underlying_method->outputs_size(); i++) {
- auto jevalue =
- JEValue::newJEValueFromEValue(underlying_method->get_output(i));
- jresult->setElement(i, *jevalue);
- }
- return jresult;
- }
+ // Set byte order to native order. NOTE(review): the class and method lookups below are repeated on every call and their results are not null-checked.
+ jclass byteBufferClass = env->FindClass("java/nio/ByteBuffer");
+ jmethodID orderMethod =
+ env->GetMethodID(byteBufferClass, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;");
+ jclass byteOrderClass = env->FindClass("java/nio/ByteOrder");
+ jmethodID nativeOrderMethod =
+ env->GetStaticMethodID(byteOrderClass, "nativeOrder", "()Ljava/nio/ByteOrder;");
+ jobject nativeOrder = env->CallStaticObjectMethod(byteOrderClass, nativeOrderMethod);
+ env->CallObjectMethod(jTensorBuffer, orderMethod, nativeOrder); // NOTE(review): the local reference returned by order() is never released
+
+ env->DeleteLocalRef(byteBufferClass);
+ env->DeleteLocalRef(byteOrderClass);
+ env->DeleteLocalRef(nativeOrder);
+
+ // Call nativeNewTensor static method (pass 0 for nativeHandle since we don't need it)
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.tensor_class,
+ g_jni_cache.tensor_nativeNewTensor,
+ jTensorBuffer,
+ jTensorShape,
+ jdtype,
+ static_cast(0));
+
+ env->DeleteLocalRef(jTensorBuffer);
+ env->DeleteLocalRef(jTensorShape);
+
+ return result; // local reference owned by the caller
+}
- std::vector evalues;
- std::vector tensors;
-
- static const auto typeCodeField =
- JEValue::javaClassStatic()->getField("mTypeCode");
-
- for (int i = 0; i < jinputs->size(); i++) {
- auto jevalue = jinputs->getElement(i);
- const auto typeCode = jevalue->getFieldValue(typeCodeField);
- if (typeCode == JEValue::kTypeCodeTensor) {
- tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
- evalues.emplace_back(tensors.back());
- } else if (typeCode == JEValue::kTypeCodeInt) {
- int64_t value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- } else if (typeCode == JEValue::kTypeCodeDouble) {
- double value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- } else if (typeCode == JEValue::kTypeCodeBool) {
- bool value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- }
- }
+// Builds a native TensorPtr over a Java Tensor: reads the dtype code, shape and raw data buffer from the Java object, validates the dtype and buffer capacity, then wraps the buffer address with from_blob. On a validation failure it throws an ExecuTorch Java exception and returns nullptr.
+TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) {
+ g_jni_cache.init(env);
-#ifdef EXECUTORCH_ANDROID_PROFILING
- auto start = std::chrono::high_resolution_clock::now();
- auto result = module_->execute(method, evalues);
- auto end = std::chrono::high_resolution_clock::now();
- auto duration =
- std::chrono::duration_cast(end - start)
- .count();
- ET_LOG(Debug, "Execution time: %lld ms.", duration);
+ jint jdtype = env->CallIntMethod(jtensor, g_jni_cache.tensor_dtypeJniCode); // NOTE(review): jtensor is not null-checked before use
-#else
- auto result = module_->execute(method, evalues);
+ jlongArray jshape =
+ static_cast(env->GetObjectField(jtensor, g_jni_cache.tensor_shape));
-#endif
+ jobject jbuffer = env->CallObjectMethod(jtensor, g_jni_cache.tensor_getRawDataBuffer);
- if (!result.ok()) {
- jni_helper::throwExecutorchException(
- static_cast(result.error()),
- "Execution failed for method: " + method);
- return {};
- }
+ jsize rank = env->GetArrayLength(jshape); // NOTE(review): jshape and jbuffer are used without null checks
+
+ std::vector shapeArr(rank);
+ env->GetLongArrayRegion(jshape, 0, rank, shapeArr.data());
+
+ std::vector shape_vec; // the shape later passed to from_blob
+ shape_vec.reserve(rank);
+
+ int64_t numel = 1;
+ for (int i = 0; i < rank; i++) {
+ shape_vec.push_back(shapeArr[i]);
+ }
+ for (int i = rank - 1; i >= 0; --i) {
+ numel *= shapeArr[i]; // element count = product of all dimensions; NOTE(review): no overflow or negative-size check
+ }
- facebook::jni::local_ref> jresult =
- facebook::jni::JArrayClass::newArray(result.get().size());
+ if (java_dtype_to_scalar_type.count(jdtype) == 0) {
+ std::stringstream ss;
+ ss << "Unknown Tensor jdtype: [" << jdtype << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
+ const jlong dataCapacity = env->GetDirectBufferCapacity(jbuffer);
+ if (dataCapacity < 0) {
+ std::stringstream ss;
+ ss << "Tensor buffer is not direct or has invalid capacity";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ const size_t elementSize = executorch::runtime::elementSize(scalar_type);
+ const jlong expectedElements = static_cast(numel);
+ const jlong expectedBytes = expectedElements * static_cast(elementSize);
+ const bool matchesElements = dataCapacity == expectedElements; // capacity counted in elements
+ const bool matchesBytes = dataCapacity == expectedBytes; // capacity counted in bytes
+
+ if (!matchesElements && !matchesBytes) { // either interpretation of the capacity is accepted
+ std::stringstream ss;
+ ss << "Tensor dimensions(elements number: " << numel
+ << ") inconsistent with buffer capacity " << dataCapacity
+ << " (element size bytes: " << elementSize << ")";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ void* data = env->GetDirectBufferAddress(jbuffer); // NOTE(review): the returned address is not null-checked
+ TensorPtr result = from_blob(data, shape_vec, scalar_type); // NOTE(review): presumably aliases the Java buffer's memory without copying, so the Java side must keep the buffer alive — confirm
- for (int i = 0; i < result.get().size(); i++) {
- auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
- jresult->setElement(i, *jevalue);
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+
+ return result;
+}
+
+// Converts a native EValue (Tensor, Int, Double, Bool or String) into a Java EValue through the cached static EValue.from(...) overloads; any other tag throws an ExecuTorch Java exception (InvalidArgument) and returns nullptr.
+jobject newJEValueFromEValue(JNIEnv* env, EValue evalue) {
+ g_jni_cache.init(env);
+
+ if (evalue.isTensor()) {
+ jobject jtensor = newJTensorFromTensor(env, evalue.toTensor());
+ if (jtensor == nullptr) {
+ return nullptr;
}
- return jresult;
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_tensor, jtensor);
+ env->DeleteLocalRef(jtensor); // the wrapping Java EValue now references the tensor
+ return result;
+ } else if (evalue.isInt()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_long, evalue.toInt());
+ } else if (evalue.isDouble()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_double, evalue.toDouble());
+ } else if (evalue.isBool()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class,
+ g_jni_cache.evalue_from_bool,
+ static_cast(evalue.toBool()));
+ } else if (evalue.isString()) {
+ std::string str =
+ std::string(evalue.toString().begin(), evalue.toString().end()); // copy into a NUL-terminated std::string for NewStringUTF
+ jstring jstr = env->NewStringUTF(str.c_str()); // NOTE(review): result is not null-checked before the call below
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_string, jstr);
+ env->DeleteLocalRef(jstr);
+ return result;
}
- facebook::jni::local_ref>
- readLogBuffer() {
- return readLogBufferUtil();
+ std::stringstream ss;
+ ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+}
+
+// Extracts the Tensor held by a Java EValue as a native TensorPtr; for any type code other than kTypeCodeTensor it throws an ExecuTorch Java exception (InvalidArgument) and returns nullptr.
+TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) {
+ g_jni_cache.init(env);
+
+ jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode);
+ if (typeCode == kTypeCodeTensor) {
+ jobject jtensor =
+ env->CallObjectMethod(jevalue, g_jni_cache.evalue_toTensor);
+ TensorPtr result = newTensorFromJTensor(env, jtensor); // can itself be nullptr with a Java exception pending
+ env->DeleteLocalRef(jtensor);
+ return result;
+ }
+
+ std::stringstream ss;
+ ss << "Unknown EValue typeCode: " << typeCode;
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+}
+
+} // namespace
+
+extern "C" {
+
+JNIEXPORT jlong JNICALL // JNI entry point: constructs the native module wrapper and returns its address as an opaque jlong handle
+Java_org_pytorch_executorch_Module_nativeCreate(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jstring modelPath,
+ jint loadMode,
+ jint numThreads) {
+ auto* native = new executorch::extension::ExecuTorchJni(env, modelPath, loadMode, numThreads); // heap-allocated; nativeDestroy deletes it
+ return reinterpret_cast(native);
+}
+
+JNIEXPORT void JNICALL // JNI entry point: deletes the native object behind nativeHandle; a zero handle is a no-op
+Java_org_pytorch_executorch_Module_nativeDestroy(
+ JNIEnv* /* env */,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+ if (nativeHandle != 0) {
+ auto* native = reinterpret_cast(nativeHandle);
+ delete native; // NOTE(review): nothing here prevents a second call with the same handle; presumably the Java side guards against that — confirm
}
+}
- static facebook::jni::local_ref>
- readLogBufferStatic(facebook::jni::alias_ref) {
- return readLogBufferUtil();
+// JNI entry point: runs `methodName` on the module behind `nativeHandle` and
+// returns its outputs as an EValue[].
+//
+// With no inputs the method is loaded and executed on sample inputs; otherwise
+// every Java EValue in `jinputs` is converted to a native EValue first.
+// Returns nullptr when the handle is null, or with a Java exception pending when
+// loading, input conversion, execution or output conversion fails.
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeExecute(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName,
+ jobjectArray jinputs) {
+ auto* native = reinterpret_cast<executorch::extension::ExecuTorchJni*>(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
}
- static facebook::jni::local_ref>
- readLogBufferUtil() {
-#ifdef __ANDROID__
+ g_jni_cache.init(env);
+
+ std::string method = jstring_to_string(env, methodName);
+ jsize inputSize = jinputs != nullptr ? env->GetArrayLength(jinputs) : 0;
- facebook::jni::local_ref> ret;
-
- access_log_buffer([&](std::vector& buffer) {
- const auto size = buffer.size();
- ret = facebook::jni::JArrayClass::newArray(size);
- for (auto i = 0u; i < size; i++) {
- const auto& entry = buffer[i];
- // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
- // MESSAGE".
- std::stringstream ss;
- ss << "[" << entry.timestamp << " " << entry.function << " "
- << entry.filename << ":" << entry.line << "] "
- << static_cast(entry.level) << " " << entry.message;
-
- facebook::jni::local_ref jstr_message =
- facebook::jni::make_jstring(ss.str().c_str());
- (*ret)[i] = jstr_message;
+ // If no inputs are given, the method runs with sample inputs (ones).
+ if (inputSize == 0) {
+ auto result = native->module_->load_method(method);
+ if (result != Error::Ok) {
+ std::stringstream ss;
+ ss << "Cannot get method names [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast<uint32_t>(result) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast<uint32_t>(result), ss.str());
+ return nullptr;
+ }
+ auto* underlying_method = native->get_method(method);
+ if (underlying_method == nullptr) {
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast<uint32_t>(Error::InvalidArgument), "Method not found: " + method);
+ return nullptr;
+ }
+ // `buf` is kept alive until the end of this scope, i.e. across execute().
+ // NOTE(review): the status carried by prepare_input_tensors' result is not inspected.
+ auto&& buf = prepare_input_tensors(*underlying_method);
+ result = underlying_method->execute();
+ if (result != Error::Ok) {
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast<uint32_t>(result), "Execution failed for method: " + method);
+ return nullptr;
+ }
+
+ jobjectArray jresult =
+ env->NewObjectArray(underlying_method->outputs_size(), g_jni_cache.evalue_class, nullptr);
+ if (jresult == nullptr) {
+ // Allocation failed; the VM has already raised the exception.
+ return nullptr;
+ }
+
+ for (int i = 0; i < underlying_method->outputs_size(); i++) {
+ jobject jevalue = newJEValueFromEValue(env, underlying_method->get_output(i));
+ if (jevalue == nullptr) {
+ // Conversion failed with a Java exception pending: stop here instead of
+ // issuing further array stores while the exception is set.
+ env->DeleteLocalRef(jresult);
+ return nullptr;
}
- });
+ env->SetObjectArrayElement(jresult, i, jevalue);
+ env->DeleteLocalRef(jevalue);
+ }
+ return jresult;
+ }
- return ret;
-#else
- return facebook::jni::JArrayClass::newArray(0);
-#endif
+ std::vector<EValue> evalues;
+ // Keeps the TensorPtr objects alive while `evalues` refers to their tensors.
+ std::vector<TensorPtr> tensors;
+
+ for (int i = 0; i < inputSize; i++) {
+ // NOTE(review): a null array element is not rejected before GetIntField.
+ jobject jevalue = env->GetObjectArrayElement(jinputs, i);
+ jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode);
+
+ if (typeCode == kTypeCodeTensor) {
+ TensorPtr tensor = JEValueToTensorImpl(env, jevalue);
+ if (tensor == nullptr || env->ExceptionCheck()) {
+ // The conversion reported its failure through a Java exception; an empty
+ // TensorPtr must not be wrapped in an EValue.
+ env->DeleteLocalRef(jevalue);
+ return nullptr;
+ }
+ tensors.emplace_back(std::move(tensor));
+ evalues.emplace_back(tensors.back());
+ } else if (typeCode == kTypeCodeInt) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass longClass = env->FindClass("java/lang/Long");
+ jmethodID longValue = env->GetMethodID(longClass, "longValue", "()J");
+ jlong value = env->CallLongMethod(mData, longValue);
+ evalues.emplace_back(static_cast<int64_t>(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(longClass);
+ } else if (typeCode == kTypeCodeDouble) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass doubleClass = env->FindClass("java/lang/Double");
+ jmethodID doubleValue = env->GetMethodID(doubleClass, "doubleValue", "()D");
+ jdouble value = env->CallDoubleMethod(mData, doubleValue);
+ evalues.emplace_back(static_cast<double>(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(doubleClass);
+ } else if (typeCode == kTypeCodeBool) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass boolClass = env->FindClass("java/lang/Boolean");
+ jmethodID boolValue = env->GetMethodID(boolClass, "booleanValue", "()Z");
+ jboolean value = env->CallBooleanMethod(mData, boolValue);
+ evalues.emplace_back(static_cast<bool>(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(boolClass);
+ }
+ // Inputs with any other type code are skipped.
+ env->DeleteLocalRef(jevalue);
}
- jboolean etdump() {
#ifdef EXECUTORCH_ANDROID_PROFILING
- executorch::etdump::ETDumpGen* etdumpgen =
- (executorch::etdump::ETDumpGen*)module_->event_tracer();
- auto etdump_data = etdumpgen->get_etdump_data();
-
- if (etdump_data.buf != nullptr && etdump_data.size > 0) {
- int etdump_file =
- open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644);
- if (etdump_file == -1) {
- ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
- return false;
- }
- ssize_t bytes_written =
- write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
- if (bytes_written == -1) {
- ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
- return false;
- } else {
- ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written);
- }
- close(etdump_file);
- free(etdump_data.buf);
- return true;
- } else {
- ET_LOG(Error, "No ETDump data available!");
- }
+ auto start = std::chrono::high_resolution_clock::now();
+ auto result = native->module_->execute(method, evalues);
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration =
+ std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+ ET_LOG(Debug, "Execution time: %lld ms.", duration);
+#else
+ auto result = native->module_->execute(method, evalues);
#endif
- return false;
+
+ if (!result.ok()) {
+ executorch::jni_helper::throwExecutorchException(
+ env,
+ static_cast<uint32_t>(result.error()),
+ "Execution failed for method: " + method);
+ return nullptr;
}
- facebook::jni::local_ref> getMethods() {
- const auto& names_result = module_->method_names();
- if (!names_result.ok()) {
- // Format hex string
- std::stringstream ss;
- ss << "Cannot get load module [Native Error: 0x" << std::hex
- << std::uppercase << static_cast(names_result.error())
- << "]";
+ jobjectArray jresult =
+ env->NewObjectArray(result.get().size(), g_jni_cache.evalue_class, nullptr);
+ if (jresult == nullptr) {
+ // Allocation failed; the VM has already raised the exception.
+ return nullptr;
+ }
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str());
- return {};
- }
- const auto& methods = names_result.get();
- facebook::jni::local_ref> ret =
- facebook::jni::JArrayClass::newArray(methods.size());
- int i = 0;
- for (auto s : methods) {
- facebook::jni::local_ref method_name =
- facebook::jni::make_jstring(s.c_str());
- (*ret)[i] = method_name;
- i++;
+ for (size_t i = 0; i < result.get().size(); i++) {
+ jobject jevalue = newJEValueFromEValue(env, result.get()[i]);
+ if (jevalue == nullptr) {
+ // Same rule as above: no further JNI array stores once a conversion failed.
+ env->DeleteLocalRef(jresult);
+ return nullptr;
}
- return ret;
+ env->SetObjectArrayElement(jresult, i, jevalue);
+ env->DeleteLocalRef(jevalue);
}
+ return jresult;
+}
+
+JNIEXPORT jint JNICALL // JNI entry point: loads the named method and returns the load_method status converted to jint
+Java_org_pytorch_executorch_Module_nativeLoadMethod(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return -1; // sentinel for a null handle; it does not come from load_method
+ }
+ std::string method = jstring_to_string(env, methodName);
+ return static_cast(native->module_->load_method(method));
+}
+
+JNIEXPORT jobjectArray JNICALL // JNI entry point: returns the module's method names as a String[]; nullptr for a null handle or when the names cannot be read
+Java_org_pytorch_executorch_Module_nativeGetMethods(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
+ }
+
+ const auto& names_result = native->module_->method_names();
+ if (!names_result.ok()) {
+ std::stringstream ss;
+ ss << "Cannot get load module [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast(names_result.error()) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str()); // NOTE(review): throws InvalidArgument although the message carries names_result.error() — confirm this is intended
+ return nullptr;
+ }
+
+ const auto& methods = names_result.get();
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(methods.size(), stringClass, nullptr);
+
+ int i = 0;
+ for (auto s : methods) {
+ jstring method_name = env->NewStringUTF(s.c_str());
+ env->SetObjectArrayElement(ret, i, method_name);
+ env->DeleteLocalRef(method_name); // the array holds its own reference after the store
+ i++;
+ }
+ env->DeleteLocalRef(stringClass);
+ return ret;
+}
- facebook::jni::local_ref> getUsedBackends(
- facebook::jni::alias_ref methodName) {
- auto methodMeta = module_->method_meta(methodName->toStdString()).get();
- std::unordered_set backends;
- for (auto i = 0; i < methodMeta.num_backends(); i++) {
- backends.insert(methodMeta.get_backend_name(i).get());
+// JNI entry point: returns the distinct backend names used by `methodName` as a
+// String[] (order unspecified, taken from an unordered_set).
+//
+// Returns nullptr when the handle is null, or with a Java exception pending when
+// the metadata of the method cannot be obtained.
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeGetUsedBackends(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName) {
+ auto* native = reinterpret_cast<executorch::extension::ExecuTorchJni*>(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
+ }
+
+ std::string method = jstring_to_string(env, methodName);
+ // method_meta() yields a Result that has to be checked before get(); an unknown
+ // method name must surface as a Java exception rather than trip the Result check.
+ auto meta_result = native->module_->method_meta(method);
+ if (!meta_result.ok()) {
+ executorch::jni_helper::throwExecutorchException(
+ env,
+ static_cast<uint32_t>(meta_result.error()),
+ "Cannot get method meta for method: " + method);
+ return nullptr;
+ }
+ auto methodMeta = meta_result.get();
+ std::unordered_set<std::string> backends;
+ for (auto i = 0; i < methodMeta.num_backends(); i++) {
+ backends.insert(methodMeta.get_backend_name(i).get());
+ }
+
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(backends.size(), stringClass, nullptr);
+
+ int i = 0;
+ for (const auto& s : backends) {
+ jstring backend_name = env->NewStringUTF(s.c_str());
+ env->SetObjectArrayElement(ret, i, backend_name);
+ env->DeleteLocalRef(backend_name);
+ i++;
+ }
+ env->DeleteLocalRef(stringClass);
+ return ret;
+}
+
+JNIEXPORT jobjectArray JNICALL // JNI entry point: returns the buffered log entries as a String[]; outside __ANDROID__ builds it returns an empty array
+Java_org_pytorch_executorch_Module_nativeReadLogBuffer(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong /* nativeHandle */) {
+#ifdef __ANDROID__
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = nullptr; // stays nullptr unless the callback below runs
+
+ access_log_buffer([&](std::vector& buffer) {
+ const auto size = buffer.size();
+ ret = env->NewObjectArray(size, stringClass, nullptr);
+ for (auto i = 0u; i < size; i++) {
+ const auto& entry = buffer[i];
+ std::stringstream ss; // entry format: "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE"
+ ss << "[" << entry.timestamp << " " << entry.function << " "
+ << entry.filename << ":" << entry.line << "] "
+ << static_cast(entry.level) << " " << entry.message;
+ jstring jstr_message = env->NewStringUTF(ss.str().c_str());
+ env->SetObjectArrayElement(ret, i, jstr_message);
+ env->DeleteLocalRef(jstr_message);
}
+ });
+
+ env->DeleteLocalRef(stringClass);
+ return ret;
+#else
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(0, stringClass, nullptr);
+ env->DeleteLocalRef(stringClass);
+ return ret;
+#endif
+}
+
+JNIEXPORT jobjectArray JNICALL // JNI entry point: handle-free variant that forwards to nativeReadLogBuffer
+Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic(
+ JNIEnv* env,
+ jclass clazz) {
+ return Java_org_pytorch_executorch_Module_nativeReadLogBuffer(env, clazz, 0); // the handle argument is unused by the callee, so 0 is passed
+}
+
+// JNI entry point: writes the module's ETDump data to
+// /data/local/tmp/result.etdump. Only active when EXECUTORCH_ANDROID_PROFILING
+// is defined; otherwise it always returns JNI_FALSE.
+//
+// Returns JNI_TRUE when the dump was written, JNI_FALSE for a null handle, when
+// no dump data is available, or when the file cannot be created or written.
+JNIEXPORT jboolean JNICALL
+Java_org_pytorch_executorch_Module_nativeEtdump(
+ JNIEnv* /* env */,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+#ifdef EXECUTORCH_ANDROID_PROFILING
+ auto* native = reinterpret_cast<executorch::extension::ExecuTorchJni*>(nativeHandle);
+ if (native == nullptr) {
+ return JNI_FALSE;
+ }
- facebook::jni::local_ref> ret =
- facebook::jni::JArrayClass::newArray(backends.size());
- int i = 0;
- for (auto s : backends) {
- facebook::jni::local_ref backend_name =
- facebook::jni::make_jstring(s.c_str());
- (*ret)[i] = backend_name;
- i++;
+ executorch::etdump::ETDumpGen* etdumpgen =
+ (executorch::etdump::ETDumpGen*)native->module_->event_tracer();
+ auto etdump_data = etdumpgen->get_etdump_data();
+
+ if (etdump_data.buf != nullptr && etdump_data.size > 0) {
+ // O_TRUNC: a shorter dump must not keep trailing bytes of an older file.
+ int etdump_file =
+ open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT | O_TRUNC, 0644);
+ if (etdump_file == -1) {
+ ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
+ free(etdump_data.buf);
+ return JNI_FALSE;
+ }
+ ssize_t bytes_written =
+ write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
+ if (bytes_written == -1) {
+ ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
+ } else {
+ ET_LOG(Info, "ETDump written %zd bytes to file.", bytes_written);
}
- return ret;
- }
-
- static void registerNatives() {
- registerHybrid({
- makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
- makeNativeMethod("executeNative", ExecuTorchJni::execute),
- makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
- makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
- makeNativeMethod(
- "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
- makeNativeMethod("etdump", ExecuTorchJni::etdump),
- makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
- makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
- });
+ // Release the descriptor and the dump buffer on success and failure alike.
+ close(etdump_file);
+ free(etdump_data.buf);
+ return bytes_written == -1 ? JNI_FALSE : JNI_TRUE;
+ } else {
+ ET_LOG(Error, "No ETDump data available!");
}
-};
-} // namespace executorch::extension
+#endif
+ return JNI_FALSE;
+}
+
+} // extern "C"
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
-extern void register_natives_for_llm();
+extern void register_natives_for_llm(JNIEnv* env);
#else
// No op if we don't build LLM
-void register_natives_for_llm() {}
+void register_natives_for_llm(JNIEnv* /* env */) {}
#endif
-extern void register_natives_for_runtime();
#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
-extern void register_natives_for_training();
+extern void register_natives_for_training(JNIEnv* env);
#else
// No op if we don't build training JNI
-void register_natives_for_training() {}
+void register_natives_for_training(JNIEnv* /* env */) {}
#endif
+void register_natives_for_runtime(JNIEnv* env);
+
+void register_natives_for_module(JNIEnv* env) { // binds the Module native* methods to the C functions above via RegisterNatives; failures are logged only
+ jclass module_class = env->FindClass("org/pytorch/executorch/Module");
+ if (module_class == nullptr) {
+ ET_LOG(Error, "Failed to find Module class");
+ env->ExceptionClear(); // drop the pending lookup exception; registration is skipped
+ return;
+ }
+
+ // clang-format off
+ static const JNINativeMethod methods[] = {
+ {"nativeCreate", "(Ljava/lang/String;II)J",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeCreate)},
+ {"nativeDestroy", "(J)V",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeDestroy)},
+ {"nativeExecute",
+ "(JLjava/lang/String;[Lorg/pytorch/executorch/EValue;)[Lorg/pytorch/executorch/EValue;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeExecute)},
+ {"nativeLoadMethod", "(JLjava/lang/String;)I",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeLoadMethod)},
+ {"nativeGetMethods", "(J)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetMethods)},
+ {"nativeGetUsedBackends", "(JLjava/lang/String;)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetUsedBackends)},
+ {"nativeReadLogBuffer", "(J)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBuffer)},
+ {"nativeReadLogBufferStatic", "()[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic)},
+ {"nativeEtdump", "(J)Z",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeEtdump)},
+ };
+ // clang-format on
+
+ int num_methods = sizeof(methods) / sizeof(methods[0]);
+ int result = env->RegisterNatives(module_class, methods, num_methods);
+ if (result != JNI_OK) {
+ ET_LOG(Error, "Failed to register native methods for Module"); // NOTE(review): any exception left pending by RegisterNatives is not cleared here
+ }
+
+ env->DeleteLocalRef(module_class);
+}
+
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
- return facebook::jni::initialize(vm, [] {
- executorch::extension::ExecuTorchJni::registerNatives();
- register_natives_for_llm();
- register_natives_for_runtime();
- register_natives_for_training();
- });
+ g_jvm = vm; // stored globally for later use
+ JNIEnv* env = nullptr;
+ if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) {
+ return JNI_ERR;
+ }
+
+ // Initialize the JNI cache once at load time, before any native method is registered
+ g_jni_cache.init(env);
+
+ // Register native methods for each component; the LLM and training calls are no-op stubs when those components are not built
+ register_natives_for_module(env);
+ register_natives_for_llm(env);
+ register_natives_for_runtime(env);
+ register_natives_for_training(env);
+
+ return JNI_VERSION_1_6;
}
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 888e09e7989..e26b40f4f4b 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -6,9 +6,12 @@
* LICENSE file in the root directory of this source tree.
*/
+#include
+
#include
#include
#include
+#include
#include
#include
#include
@@ -30,9 +33,6 @@
#include
#endif
-#include
-#include
-
#if defined(EXECUTORCH_BUILD_QNN)
#include
#endif
@@ -45,6 +45,10 @@ namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;
namespace {
+
+// Global JavaVM pointer for obtaining JNIEnv in callbacks
+JavaVM* g_jvm = nullptr;
+
bool utf8_check_validity(const char* str, size_t length) {
for (size_t i = 0; i < length; ++i) {
uint8_t byte = static_cast(str[i]);
@@ -79,47 +83,70 @@ bool utf8_check_validity(const char* str, size_t length) {
}
std::string token_buffer;
-} // namespace
-namespace executorch_jni {
+// Copies a Java string into a std::string through GetStringUTFChars; returns an empty string when jstr is null or the characters cannot be obtained.
+std::string jstring_to_string(JNIEnv* env, jstring jstr) {
+ if (jstr == nullptr) {
+ return "";
+ }
+ const char* chars = env->GetStringUTFChars(jstr, nullptr);
+ if (chars == nullptr) {
+ return "";
+ }
+ std::string result(chars); // copy first so the JNI buffer can be released right away
+ env->ReleaseStringUTFChars(jstr, chars);
+ return result;
+}
-class ExecuTorchLlmCallbackJni
- : public facebook::jni::JavaClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/extension/llm/LlmCallback;";
+// Converts a java.util.List into a vector of std::string by calling size() and get(i) through JNI; returns an empty vector when jlist is null or the List class or its methods cannot be resolved, and skips null elements.
+std::vector jlist_to_string_vector(JNIEnv* env, jobject jlist) {
+ std::vector result;
+ if (jlist == nullptr) {
+ return result;
+ }
- void onResult(std::string result) const {
- static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
- static const auto method =
- cls->getMethod)>("onResult");
+ jclass list_class = env->FindClass("java/util/List");
+ if (list_class == nullptr) {
+ env->ExceptionClear(); // lookup failure is swallowed; the caller just sees an empty vector
+ return result;
+ }
- token_buffer += result;
- if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
- ET_LOG(
- Info, "Current token buffer is not valid UTF-8. Waiting for more.");
- return;
- }
- result = token_buffer;
- token_buffer = "";
- facebook::jni::local_ref s = facebook::jni::make_jstring(result);
- method(self(), s);
+ jmethodID size_method = env->GetMethodID(list_class, "size", "()I");
+ jmethodID get_method =
+ env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
+
+ if (size_method == nullptr || get_method == nullptr) {
+ env->ExceptionClear();
+ env->DeleteLocalRef(list_class);
+ return result;
}
- void onStats(const llm::Stats& result) const {
- static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
- static const auto on_stats_method =
- cls->getMethod)>("onStats");
- on_stats_method(
- self(),
- facebook::jni::make_jstring(
- executorch::extension::llm::stats_to_json_string(result)));
+ jint size = env->CallIntMethod(jlist, size_method);
+ for (jint i = 0; i < size; ++i) {
+ jobject str_obj = env->CallObjectMethod(jlist, get_method, i);
+ if (str_obj != nullptr) {
+ result.push_back(jstring_to_string(env, static_cast(str_obj))); // NOTE(review): elements are cast to a string without a type check — presumably the list only holds Strings; confirm
+ env->DeleteLocalRef(str_obj);
+ }
}
-};
-class ExecuTorchLlmJni : public facebook::jni::HybridClass {
- private:
- friend HybridBase;
+ env->DeleteLocalRef(list_class);
+ return result;
+}
+
+} // namespace
+
+namespace executorch_jni {
+
+// Model type category constants
+constexpr int MODEL_TYPE_CATEGORY_LLM = 1;
+constexpr int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
+constexpr int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+constexpr int MODEL_TYPE_QNN_LLAMA = 4;
+
+// Native handle class that holds the runner state.
+// nativeCreate allocates one instance with `new` and returns its address to
+// Java as a jlong; nativeDestroy deletes it. All members are public and the
+// extern "C" JNI entry points read and write them directly.
+class ExecuTorchLlmNative {
+ public:
+ // Sampling temperature; copied into the GenerationConfig on every generate.
float temperature_ = 0.0f;
+ // Active runner kind; compared against the MODEL_TYPE_* constants.
int model_type_category_;
+ // Text runner; filled for the LLM category and for the QNN / MediaTek
+ // variants when those are compiled in.
std::unique_ptr runner_;
@@ -127,37 +154,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
multi_modal_runner_;
+ // Inputs queued by the nativeAppend*Input entry points; the multimodal
+ // branch of nativeGenerate copies and then clears this vector.
std::vector prefill_inputs_;
- public:
- constexpr static auto kJavaDescriptor =
- "Lorg/pytorch/executorch/extension/llm/LlmModule;";
-
- constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
- constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
- constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
- constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
-
- static facebook::jni::local_ref initHybrid(
- facebook::jni::alias_ref,
+ // Builds the runner selected by model_type_category.
+ // env is only used to convert model_path, tokenizer_path and data_files
+ // into C++ values. If model_type_category matches none of the compiled-in
+ // branches below, neither runner pointer is assigned.
+ ExecuTorchLlmNative(
+ JNIEnv* env,
jint model_type_category,
- facebook::jni::alias_ref model_path,
- facebook::jni::alias_ref tokenizer_path,
+ jstring model_path,
+ jstring tokenizer_path,
jfloat temperature,
- facebook::jni::alias_ref::javaobject>
- data_files) {
- return makeCxxInstance(
- model_type_category,
- model_path,
- tokenizer_path,
- temperature,
- data_files);
- }
-
- ExecuTorchLlmJni(
- jint model_type_category,
- facebook::jni::alias_ref model_path,
- facebook::jni::alias_ref tokenizer_path,
- jfloat temperature,
- facebook::jni::alias_ref data_files = nullptr) {
+ jobject data_files) {
temperature_ = temperature;
#if defined(ET_USE_THREADPOOL)
// Reserve 1 thread for the main thread.
@@ -171,44 +174,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
#endif
model_type_category_ = model_type_category;
- std::vector data_files_vector;
+ // Convert every Java argument once, before any branch, so all runner
+ // kinds share the same C++ strings and data-file list.
+ std::string model_path_str = jstring_to_string(env, model_path);
+ std::string tokenizer_path_str = jstring_to_string(env, tokenizer_path);
+ std::vector data_files_vector =
+ jlist_to_string_vector(env, data_files);
+
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_ = llm::create_multimodal_runner(
- model_path->toStdString().c_str(),
- llm::load_tokenizer(tokenizer_path->toStdString()));
+ model_path_str.c_str(), llm::load_tokenizer(tokenizer_path_str));
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
- if (data_files != nullptr) {
- // Convert Java List to C++ std::vector
- auto list_class = facebook::jni::findClassStatic("java/util/List");
- auto size_method = list_class->getMethod("size");
- auto get_method =
- list_class->getMethod(jint)>(
- "get");
-
- jint size = size_method(data_files);
- for (jint i = 0; i < size; ++i) {
- auto str_obj = get_method(data_files, i);
- auto jstr = facebook::jni::static_ref_cast(str_obj);
- data_files_vector.push_back(jstr->toStdString());
- }
- }
runner_ = executorch::extension::llm::create_text_llm_runner(
- model_path->toStdString(),
- llm::load_tokenizer(tokenizer_path->toStdString()),
- data_files_vector);
+ model_path_str, llm::load_tokenizer(tokenizer_path_str), data_files_vector);
#if defined(EXECUTORCH_BUILD_QNN)
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
std::unique_ptr module = std::make_unique<
executorch::extension::Module>(
- model_path->toStdString().c_str(),
+ model_path_str.c_str(),
data_files_vector,
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
std::string decoder_model = "llama3"; // use llama3 for now
runner_ = std::make_unique>( // QNN runner
std::move(module),
decoder_model.c_str(),
- model_path->toStdString().c_str(),
- tokenizer_path->toStdString().c_str(),
+ model_path_str.c_str(),
+ tokenizer_path_str.c_str(),
"",
"");
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
@@ -216,249 +205,530 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
#if defined(EXECUTORCH_BUILD_MEDIATEK)
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
runner_ = std::make_unique(
- model_path->toStdString().c_str(),
- tokenizer_path->toStdString().c_str());
+ model_path_str.c_str(), tokenizer_path_str.c_str());
// Interpret the model type as LLM
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
}
}
+};
- jint generate(
- facebook::jni::alias_ref prompt,
- jint seq_len,
- facebook::jni::alias_ref callback,
- jboolean echo) {
- if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
- std::vector inputs = prefill_inputs_;
- prefill_inputs_.clear();
- if (!prompt->toStdString().empty()) {
- inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
- }
- executorch::extension::llm::GenerationConfig config{
- .echo = static_cast(echo),
- .seq_len = seq_len,
- .temperature = temperature_,
- };
- multi_modal_runner_->generate(
- std::move(inputs),
- config,
- [callback](const std::string& result) { callback->onResult(result); },
- [callback](const llm::Stats& result) { callback->onStats(result); });
- } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
- executorch::extension::llm::GenerationConfig config{
- .echo = static_cast(echo),
- .seq_len = seq_len,
- .temperature = temperature_,
- };
- runner_->generate(
- prompt->toStdString(),
- config,
- [callback](std::string result) { callback->onResult(result); },
- [callback](const llm::Stats& result) { callback->onStats(result); });
+// Helper class for callback invocation.
+// Holds global references to a Java callback object and its class, caches the
+// method IDs of its onResult(String) and onStats(String) methods, and forwards
+// generated text and stats to them. nativeGenerate creates one instance per
+// call and shares it with the runner callbacks through a shared_ptr.
+class CallbackHelper {
+ public:
+ // Pins `callback` and its class with global references and looks up the two
+ // callback methods. A null `callback` leaves callback_ and both method IDs
+ // null, which makes onResult/onStats return without doing anything.
+ // NOTE(review): the GetMethodID results are not followed by an exception
+ // check; a failed lookup presumably leaves a Java exception pending -
+ // confirm how that is meant to surface.
+ CallbackHelper(JNIEnv* env, jobject callback)
+ : env_(env), callback_(nullptr), callback_class_(nullptr) {
+ if (callback != nullptr) {
+ callback_ = env_->NewGlobalRef(callback);
+ jclass local_class = env_->GetObjectClass(callback);
+ callback_class_ = static_cast(env_->NewGlobalRef(local_class));
+ env_->DeleteLocalRef(local_class);
+ on_result_method_ = env_->GetMethodID(
+ callback_class_, "onResult", "(Ljava/lang/String;)V");
+ on_stats_method_ =
+ env_->GetMethodID(callback_class_, "onStats", "(Ljava/lang/String;)V");
+ }
- return 0;
- }
-
- // Returns status_code
- // Contract is valid within an AAR (JNI + corresponding Java code)
- jint append_text_input(facebook::jni::alias_ref prompt) {
- prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
- return 0;
}
- // Returns status_code
- jint append_images_input(
- facebook::jni::alias_ref image,
- jint width,
- jint height,
- jint channels) {
- std::vector images;
- if (image == nullptr) {
- return static_cast(Error::EndOfMethod);
+ // Releases the two global references. The JNIEnv is fetched again from
+ // g_jvm instead of reusing env_; if g_jvm is null or no env can be
+ // obtained, the references are not released.
+ // NOTE(review): a thread attached here is not detached in this block -
+ // confirm that is intended.
+ ~CallbackHelper() {
+ if (g_jvm == nullptr) {
+ return;
}
- auto image_size = image->size();
- if (image_size != 0) {
- std::vector image_data_jint(image_size);
- std::vector image_data(image_size);
- image->getRegion(0, image_size, image_data_jint.data());
- for (int i = 0; i < image_size; i++) {
- image_data[i] = image_data_jint[i];
+ // Get the current JNIEnv (might be different thread)
+ JNIEnv* env = nullptr;
+ int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+ if (status == JNI_EDETACHED) {
+ g_jvm->AttachCurrentThread(&env, nullptr);
+ }
+ if (env != nullptr) {
+ if (callback_ != nullptr) {
+ env->DeleteGlobalRef(callback_);
+ }
+ if (callback_class_ != nullptr) {
+ env->DeleteGlobalRef(callback_class_);
}
- llm::Image image_runner{std::move(image_data), width, height, channels};
- prefill_inputs_.emplace_back(
- llm::MultimodalInput{std::move(image_runner)});
}
-
- return 0;
}
- // Returns status_code
- jint append_normalized_images_input(
- facebook::jni::alias_ref image,
- jint width,
- jint height,
- jint channels) {
- std::vector images;
- if (image == nullptr) {
- return static_cast(Error::EndOfMethod);
+ // Delivers generated text to the Java onResult(String) method.
+ // `result` is appended to token_buffer (declared outside this class). While
+ // the accumulated bytes are not valid UTF-8 nothing is delivered; once they
+ // are, the whole buffer is sent as one string and the buffer is cleared.
+ // NOTE(review): there is no ExceptionCheck after CallVoidMethod; if the
+ // Java callback throws, the exception is still pending on the next JNI
+ // call made from here - confirm the intended handling.
+ void onResult(const std::string& result) {
+ JNIEnv* env = getEnv();
+ if (env == nullptr || callback_ == nullptr || on_result_method_ == nullptr) {
+ return;
}
- auto image_size = image->size();
- if (image_size != 0) {
- std::vector image_data_jfloat(image_size);
- std::vector image_data(image_size);
- image->getRegion(0, image_size, image_data_jfloat.data());
- for (int i = 0; i < image_size; i++) {
- image_data[i] = image_data_jfloat[i];
- }
- llm::Image image_runner{std::move(image_data), width, height, channels};
- prefill_inputs_.emplace_back(
- llm::MultimodalInput{std::move(image_runner)});
+
+ std::string current_result = result;
+ token_buffer += current_result;
+ if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
+ ET_LOG(
+ Info, "Current token buffer is not valid UTF-8. Waiting for more.");
+ return;
}
+ current_result = token_buffer;
+ token_buffer = "";
- return 0;
+ jstring jstr = env->NewStringUTF(current_result.c_str());
+ if (jstr != nullptr) {
+ env->CallVoidMethod(callback_, on_result_method_, jstr);
+ env->DeleteLocalRef(jstr);
+ }
}
- // Returns status_code
- jint append_audio_input(
- facebook::jni::alias_ref data,
- jint batch_size,
- jint n_bins,
- jint n_frames) {
- if (data == nullptr) {
- return static_cast(Error::EndOfMethod);
+ // Serializes `stats` with stats_to_json_string and passes the JSON text to
+ // the Java onStats(String) method. Does nothing when no env, callback or
+ // method ID is available, or when the Java string cannot be created.
+ void onStats(const llm::Stats& stats) {
+ JNIEnv* env = getEnv();
+ if (env == nullptr || callback_ == nullptr || on_stats_method_ == nullptr) {
+ return;
}
- auto data_size = data->size();
- if (data_size != 0) {
- std::vector data_jbyte(data_size);
- std::vector data_u8(data_size);
- data->getRegion(0, data_size, data_jbyte.data());
- for (int i = 0; i < data_size; i++) {
- data_u8[i] = data_jbyte[i];
- }
- llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
- prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+
+ std::string stats_json =
+ executorch::extension::llm::stats_to_json_string(stats);
+ jstring jstr = env->NewStringUTF(stats_json.c_str());
+ if (jstr != nullptr) {
+ env->CallVoidMethod(callback_, on_stats_method_, jstr);
+ env->DeleteLocalRef(jstr);
}
- return 0;
}
- // Returns status_code
- jint append_audio_input_float(
- facebook::jni::alias_ref data,
- jint batch_size,
- jint n_bins,
- jint n_frames) {
- if (data == nullptr) {
- return static_cast(Error::EndOfMethod);
+ private:
+ // Returns the JNIEnv for the calling thread, attaching the thread to the
+ // VM when GetEnv reports JNI_EDETACHED. Returns nullptr when g_jvm is null;
+ // the AttachCurrentThread result itself is not checked.
+ JNIEnv* getEnv() {
+ if (g_jvm == nullptr) {
+ return nullptr;
}
- auto data_size = data->size();
- if (data_size != 0) {
- std::vector data_jfloat(data_size);
- std::vector data_f(data_size);
- data->getRegion(0, data_size, data_jfloat.data());
- for (int i = 0; i < data_size; i++) {
- data_f[i] = data_jfloat[i];
- }
- llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
- prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+ JNIEnv* env = nullptr;
+ int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+ if (status == JNI_EDETACHED) {
+ g_jvm->AttachCurrentThread(&env, nullptr);
+ }
+ return env;
+ }
+
+ // Env passed to the constructor; only used inside the constructor.
+ JNIEnv* env_;
+ // Global reference to the Java callback object, or null.
+ jobject callback_;
+ // Global reference to the callback's class, or null.
+ jclass callback_class_ = nullptr;
+ // Cached method IDs for onResult(String) and onStats(String).
+ jmethodID on_result_method_ = nullptr;
+ jmethodID on_stats_method_ = nullptr;
+};
+
+} // namespace executorch_jni
+
+extern "C" {
+
+// JNI entry point for LlmModule.nativeCreate.
+// Allocates an ExecuTorchLlmNative with `new`, forwarding every argument to
+// its constructor, and returns the pointer as an opaque handle. The handle is
+// what the other native* entry points receive as native_handle; nativeDestroy
+// is the function that deletes it.
+JNIEXPORT jlong JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate(
+ JNIEnv* env,
+ jobject /* this */,
+ jint model_type_category,
+ jstring model_path,
+ jstring tokenizer_path,
+ jfloat temperature,
+ jobject data_files) {
+ auto* native = new executorch_jni::ExecuTorchLlmNative(
+ env, model_type_category, model_path, tokenizer_path, temperature, data_files);
+ return reinterpret_cast(native);
+}
+
+// JNI entry point for LlmModule.nativeDestroy.
+// Deletes the native object behind native_handle. A handle of 0 is ignored.
+// The handle is not cleared anywhere in this function, so the caller must not
+// pass the same value to any native* entry point afterwards.
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy(
+ JNIEnv* /* env */,
+ jobject /* this */,
+ jlong native_handle) {
+ if (native_handle != 0) {
+ auto* native =
+ reinterpret_cast(native_handle);
+ delete native;
+ }
+}
+
+// JNI entry point for LlmModule.nativeGenerate.
+// Runs one generation on the runner selected by the handle's stored category
+// and streams text and stats to `callback` through a CallbackHelper.
+//
+// native_handle: value previously returned by nativeCreate.
+// prompt:        prompt text; for the multimodal runner an empty prompt is
+//                not appended to the queued inputs.
+// seq_len, echo: copied into the GenerationConfig together with the stored
+//                temperature.
+// callback:      Java object receiving onResult/onStats; may be null.
+//
+// Returns -1 when the handle is null or when the runner for the stored
+// category was never created, otherwise 0. Whatever the runner's generate()
+// returns is discarded.
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jstring prompt,
+    jint seq_len,
+    jobject callback,
+    jboolean echo) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  const int category = native->model_type_category_;
+  const bool is_multimodal =
+      category == executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL;
+  const bool is_llm = category == executorch_jni::MODEL_TYPE_CATEGORY_LLM;
+
+  // The constructor can leave a runner pointer unset; bail out before
+  // touching any state instead of dereferencing a null pointer below.
+  if (is_multimodal && native->multi_modal_runner_ == nullptr) {
+    return -1;
+  }
+  if (is_llm && native->runner_ == nullptr) {
+    return -1;
+  }
+
+  std::string prompt_str = jstring_to_string(env, prompt);
+
+  // Create a shared callback helper for use in lambdas
+  auto callback_helper =
+      std::make_shared<executorch_jni::CallbackHelper>(env, callback);
+
+  // Both runner kinds take the same configuration.
+  executorch::extension::llm::GenerationConfig config{
+      .echo = static_cast<bool>(echo),
+      .seq_len = seq_len,
+      .temperature = native->temperature_,
+  };
+
+  if (is_multimodal) {
+    // Take the queued prefill inputs and leave the queue empty for next time.
+    std::vector<llm::MultimodalInput> inputs = native->prefill_inputs_;
+    native->prefill_inputs_.clear();
+    if (!prompt_str.empty()) {
+      inputs.emplace_back(llm::MultimodalInput{prompt_str});
+    }
- return 0;
+    native->multi_modal_runner_->generate(
+        std::move(inputs),
+        config,
+        [callback_helper](const std::string& result) {
+          callback_helper->onResult(result);
+        },
+        [callback_helper](const llm::Stats& result) {
+          callback_helper->onStats(result);
+        });
+  } else if (is_llm) {
+    native->runner_->generate(
+        prompt_str,
+        config,
+        [callback_helper](std::string result) {
+          callback_helper->onResult(result);
+        },
+        [callback_helper](const llm::Stats& result) {
+          callback_helper->onStats(result);
+        });
+  }
+  return 0;
+}
+
+// JNI entry point for LlmModule.nativeStop.
+// Calls stop() on the runner selected by the handle's stored category.
+// A null handle, an unrecognised category, or a runner that was never
+// created all make this a no-op.
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop(
+    JNIEnv* /* env */,
+    jobject /* this */,
+    jlong native_handle) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return;
}
- // Returns status_code
- jint append_raw_audio_input(
- facebook::jni::alias_ref data,
- jint batch_size,
- jint n_channels,
- jint n_samples) {
- if (data == nullptr) {
- return static_cast(Error::EndOfMethod);
+  if (native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) {
+    // The constructor can leave the runner unset; guard the dereference.
+    if (native->multi_modal_runner_ != nullptr) {
+      native->multi_modal_runner_->stop();
+    }
+  } else if (
+      native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_LLM) {
+    if (native->runner_ != nullptr) {
+      native->runner_->stop();
+    }
+  }
+}
+
+// JNI entry point for LlmModule.nativeLoad.
+// Calls load() on the runner selected by the handle's stored category and
+// returns its status converted to an integer; 0 means success.
+// Returns -1 without calling throwExecutorchException when the handle is null.
+// For any other non-zero result, including the initial -1 kept for an
+// unrecognised category, throwExecutorchException is called with
+// Error::InvalidArgument and the message built in `ss`; the actual load
+// status only appears inside that message text.
+// NOTE(review): the runner pointers are dereferenced without a null check -
+// confirm they cannot be null when the category matches.
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad(
+ JNIEnv* env,
+ jobject /* this */,
+ jlong native_handle) {
+ auto* native =
+ reinterpret_cast(native_handle);
+ if (native == nullptr) {
+ return -1;
+ }
+
+ // Stays -1 when no branch below runs a load(), so the invalid-category
+ // path also reaches the throw at the end.
+ int result = -1;
+ std::stringstream ss;
+
+ if (native->model_type_category_ ==
+ executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) {
+ result = static_cast(native->multi_modal_runner_->load());
+ if (result != 0) {
+ ss << "Failed to load multimodal runner: [" << result << "]";
}
- auto data_size = data->size();
- if (data_size != 0) {
- std::vector data_jbyte(data_size);
- std::vector data_u8(data_size);
- data->getRegion(0, data_size, data_jbyte.data());
- for (int i = 0; i < data_size; i++) {
- data_u8[i] = data_jbyte[i];
- }
- llm::RawAudio audio{
- std::move(data_u8), batch_size, n_channels, n_samples};
- prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+ } else if (
+ native->model_type_category_ ==
+ executorch_jni::MODEL_TYPE_CATEGORY_LLM) {
+ result = static_cast(native->runner_->load());
+ if (result != 0) {
+ ss << "Failed to load llm runner: [" << result << "]";
}
- return 0;
+ } else {
+ ss << "Invalid model type category: " << native->model_type_category_
+ << ". Valid values are: "
+ << executorch_jni::MODEL_TYPE_CATEGORY_LLM << " or "
+ << executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL;
+ }
+ if (result != 0) {
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ }
+ return result; // 0 on success to keep backward compatibility
+}
+
+// JNI entry point for LlmModule.nativeAppendTextInput.
+// Converts `prompt` to a C++ string and appends it to prefill_inputs_ as a
+// text MultimodalInput. Returns -1 for a null handle, otherwise 0. The stored
+// model category is not consulted, and an empty prompt is appended as-is.
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput(
+ JNIEnv* env,
+ jobject /* this */,
+ jlong native_handle,
+ jstring prompt) {
+ auto* native =
+ reinterpret_cast(native_handle);
+ if (native == nullptr) {
+ return -1;
+ }
+
+ std::string prompt_str = jstring_to_string(env, prompt);
+ native->prefill_inputs_.emplace_back(llm::MultimodalInput{prompt_str});
+ return 0;
+}
+
+// JNI entry point for LlmModule.nativeAppendImagesInput.
+// Copies the Java int[] `image` into a native vector, converts each element
+// with a static_cast, and appends the result to prefill_inputs_ as an
+// llm::Image with the given width, height and channels.
+// Returns -1 for a null handle, Error::EndOfMethod for a null array, and 0
+// otherwise. A zero-length array returns 0 without appending anything.
+// width/height/channels are passed through without being checked against
+// the array length.
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput(
+ JNIEnv* env,
+ jobject /* this */,
+ jlong native_handle,
+ jintArray image,
+ jint width,
+ jint height,
+ jint channels) {
+ auto* native =
+ reinterpret_cast(native_handle);
+ if (native == nullptr) {
+ return -1;
}
- void stop() {
- if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
- multi_modal_runner_->stop();
- } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
- runner_->stop();
+ if (image == nullptr) {
+ return static_cast(Error::EndOfMethod);
+ }
+
+ jsize image_size = env->GetArrayLength(image);
+ if (image_size != 0) {
+ // Stage the Java ints in a jint buffer, then convert element by element.
+ // The target element type is not visible here - presumably a byte type,
+ // so values outside its range would be truncated; TODO confirm.
+ std::vector image_data_jint(image_size);
+ std::vector image_data(image_size);
+ env->GetIntArrayRegion(image, 0, image_size, image_data_jint.data());
+ for (int i = 0; i < image_size; i++) {
+ image_data[i] = static_cast(image_data_jint[i]);
}
+ llm::Image image_runner{std::move(image_data), width, height, channels};
+ native->prefill_inputs_.emplace_back(
+ llm::MultimodalInput{std::move(image_runner)});
}
- void reset_context() {
- if (runner_ != nullptr) {
- runner_->reset();
+ return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput(
+ JNIEnv* env,
+ jobject /* this */,
+ jlong native_handle,
+ jfloatArray image,
+ jint width,
+ jint height,
+ jint channels) {
+ auto* native =
+ reinterpret_cast(native_handle);
+ if (native == nullptr) {
+ return -1;
+ }
+
+ if (image == nullptr) {
+ return static_cast(Error::EndOfMethod);
+ }
+
+ jsize image_size = env->GetArrayLength(image);
+ if (image_size != 0) {
+ std::vector