diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index b271c7bff..b7660bc6e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.nativeblas.NativeOpsHolder; -import sun.nio.ch.DirectBuffer; import java.nio.ByteBuffer; @@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest { //https://github.com/eclipse/deeplearning4j/issues/8783 Nd4j.create(1); - DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5); - System.out.println(bb.getClass()); - System.out.println(bb.address()); - - Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address()); - DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE); + BytePointer bp = new BytePointer(5); - INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE); + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address()); + DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8); + + + INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8); long before = arr2.data().pointer().address(); Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST); long after = arr2.data().pointer().address(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java index c80a04c55..5e640ec8b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java @@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Map; +import static org.nd4j.linalg.ops.transforms.Transforms.sqrt; + public class UpdaterJavaCode { private UpdaterJavaCode(){ } @@ -46,6 +48,14 @@ public class UpdaterJavaCode { msdx.muli(rho).addi(update.mul(update).muli(1 - rho)); } + public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){ + state.addi(gradient.mul(gradient)); + + INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon); + // lr * gradient / (sqrt(sumSquaredGradients) + epsilon) + gradient.muli(sqrtHistory.rdivi(learningRate)); + } + public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2, double epsilon, int iteration){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index e8df8d7ae..660b178e4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest { } } + @Test + public void testAdaGradUpdater(){ + double lr = 0.1; + double epsilon = 1e-6; + + INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5); + + Map state = new HashMap<>(); + state.put("grad", s.dup()); + AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true); + + assertEquals(s, state.get("grad")); + + for( int i=0; i<3; i++ ) { + INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); + INDArray g2 = g1.dup(); + + UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon); + + u.applyUpdater(g2, i, 0); + + assertEquals(s, state.get("grad")); + assertEquals(g1, g2); + } + } + @Test public void testAdamUpdater(){