parent
7a2ac800dd
commit
2497290cb0
|
@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import sun.nio.ch.DirectBuffer;
|
||||
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest {
|
|||
//https://github.com/eclipse/deeplearning4j/issues/8783
|
||||
Nd4j.create(1);
|
||||
|
||||
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
|
||||
System.out.println(bb.getClass());
|
||||
System.out.println(bb.address());
|
||||
|
||||
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
|
||||
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
|
||||
BytePointer bp = new BytePointer(5);
|
||||
|
||||
|
||||
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
|
||||
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address());
|
||||
DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8);
|
||||
|
||||
|
||||
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8);
|
||||
long before = arr2.data().pointer().address();
|
||||
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
|
||||
long after = arr2.data().pointer().address();
|
||||
|
|
|
@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
|
||||
|
||||
public class UpdaterJavaCode {
|
||||
|
||||
private UpdaterJavaCode(){ }
|
||||
|
@ -46,6 +48,14 @@ public class UpdaterJavaCode {
|
|||
msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
|
||||
}
|
||||
|
||||
public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){
|
||||
state.addi(gradient.mul(gradient));
|
||||
|
||||
INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon);
|
||||
// lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
|
||||
gradient.muli(sqrtHistory.rdivi(learningRate));
|
||||
}
|
||||
|
||||
|
||||
public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
|
||||
double epsilon, int iteration){
|
||||
|
|
|
@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAdaGradUpdater(){
|
||||
double lr = 0.1;
|
||||
double epsilon = 1e-6;
|
||||
|
||||
INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5);
|
||||
|
||||
Map<String,INDArray> state = new HashMap<>();
|
||||
state.put("grad", s.dup());
|
||||
AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true);
|
||||
|
||||
assertEquals(s, state.get("grad"));
|
||||
|
||||
for( int i=0; i<3; i++ ) {
|
||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||
INDArray g2 = g1.dup();
|
||||
|
||||
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
|
||||
|
||||
u.applyUpdater(g2, i, 0);
|
||||
|
||||
assertEquals(s, state.get("grad"));
|
||||
assertEquals(g1, g2);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testAdamUpdater(){
|
||||
|
|
Loading…
Reference in New Issue