@@ -8,8 +8,6 @@

package org.pytorch.executorch.training;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
@@ -32,26 +30,27 @@ public class SGD {
NativeLoader.loadLibrary("executorch");
}

private final HybridData mHybridData;
private long mNativeHandle;

@DoNotStrip
private static native HybridData initHybrid(
private static native long nativeCreate(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov);

private static native void nativeDestroy(long nativeHandle);

private SGD(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov) {
mHybridData =
initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
mNativeHandle =
nativeCreate(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
}

/**
@@ -92,12 +91,34 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
* @param namedGradients Map of parameter names to gradient tensors
*/
public void step(Map<String, Tensor> namedGradients) {
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
}
stepNative(namedGradients);
nativeStep(mNativeHandle, namedGradients);
}

@DoNotStrip
private native void stepNative(Map<String, Tensor> namedGradients);
private static native void nativeStep(long nativeHandle, Map<String, Tensor> namedGradients);

/**
* Explicitly destroys the native SGD optimizer object. Calling this method is not required, as
* the native object will be destroyed when this object is garbage-collected. However, the timing
* of garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
* more quickly.
*/
public void destroy() {
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
}

@SuppressWarnings("deprecation")
@Override
protected void finalize() throws Throwable {
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
super.finalize();
}
}
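
For context, a minimal sketch of how the migrated SGD API is used from Java, assuming a TrainingModule has already been loaded. The "forward" method name, the learning-rate value, and the Tensor import path are assumptions; create, step, and destroy are the methods shown in the diff above.

import java.util.Map;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.training.SGD;
import org.pytorch.executorch.training.TrainingModule;

public class SgdUsageSketch {
  static void trainOneStep(TrainingModule module) {
    // "forward" is an assumed method name; use whatever the exported
    // training program actually defines.
    Map<String, Tensor> params = module.namedParameters("forward");

    // Two-argument create overload as referenced in the hunk header above;
    // the learning-rate value is only an example.
    SGD optimizer = SGD.create(params, 0.01);

    // Gradients are populated by a prior executeForwardBackward call.
    Map<String, Tensor> grads = module.namedGradients("forward");
    optimizer.step(grads);

    // Optional eager cleanup: releases the native handle now instead of
    // waiting for finalize() to call nativeDestroy() at GC time.
    optimizer.destroy();
  }
}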
@@ -9,8 +9,6 @@
package org.pytorch.executorch.training;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
@@ -36,13 +34,14 @@ public class TrainingModule {
NativeLoader.loadLibrary("executorch");
}

private final HybridData mHybridData;
private long mNativeHandle;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
private static native long nativeCreate(String moduleAbsolutePath, String dataAbsolutePath);

private static native void nativeDestroy(long nativeHandle);

private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
mNativeHandle = nativeCreate(moduleAbsolutePath, dataAbsolutePath);
}

/**
@@ -87,35 +86,58 @@ public static TrainingModule load(final String modelPath) {
* @return return value(s) from the method.
*/
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
return executeForwardBackwardNative(methodName, inputs);
return nativeExecuteForwardBackward(mNativeHandle, methodName, inputs);
}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
private static native EValue[] nativeExecuteForwardBackward(
long nativeHandle, String methodName, EValue... inputs);

public Map<String, Tensor> namedParameters(String methodName) {
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
return namedParametersNative(methodName);
return nativeNamedParameters(mNativeHandle, methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedParametersNative(String methodName);
private static native Map<String, Tensor> nativeNamedParameters(
long nativeHandle, String methodName);

public Map<String, Tensor> namedGradients(String methodName) {
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
return namedGradientsNative(methodName);
return nativeNamedGradients(mNativeHandle, methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedGradientsNative(String methodName);
private static native Map<String, Tensor> nativeNamedGradients(
long nativeHandle, String methodName);

/**
* Explicitly destroys the native TrainingModule object. Calling this method is not required, as
* the native object will be destroyed when this object is garbage-collected. However, the timing
* of garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
* more quickly.
*/
public void destroy() {
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
}

@SuppressWarnings("deprecation")
@Override
protected void finalize() throws Throwable {
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
super.finalize();
}
}
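
Similarly, a minimal sketch of the TrainingModule side of the handle-based API. The "forward" method name, the input layout, and the EValue import path are assumptions; load, executeForwardBackward, namedGradients, and destroy are the methods shown in the diff above.

import org.pytorch.executorch.EValue;
import org.pytorch.executorch.training.TrainingModule;

public class TrainingModuleUsageSketch {
  static void runForwardBackward(String modelPath, EValue input, EValue label) {
    TrainingModule module = TrainingModule.load(modelPath);

    // Assumed method name and inputs; the exported training program defines
    // what "forward" expects and returns.
    EValue[] outputs = module.executeForwardBackward("forward", input, label);

    // ... inspect outputs, then step an optimizer with module.namedGradients(...) ...

    // Optional eager cleanup; otherwise finalize() destroys the native
    // module when this wrapper is garbage-collected.
    module.destroy();
  }
}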
7 changes: 4 additions & 3 deletions extension/android/jni/jni_layer.cpp
@@ -43,6 +43,7 @@
#include <fbjni/fbjni.h>

using namespace executorch::extension;
using namespace executorch::jni_helper;
using namespace torch::executor;

namespace executorch::extension {
@@ -543,17 +544,17 @@ void register_natives_for_llm() {}
extern void register_natives_for_runtime();

#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
extern void register_natives_for_training();
extern void register_natives_for_training(JNIEnv* env);
#else
// No op if we don't build training JNI
void register_natives_for_training() {}
void register_natives_for_training(JNIEnv* /* env */) {}
#endif

JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_llm();
register_natives_for_runtime();
register_natives_for_training();
register_natives_for_training(facebook::jni::Environment::current());
});
}