AdaGrad validation test (#334)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-20 17:25:46 +11:00 committed by GitHub
parent 7a2ac800dd
commit 2497290cb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 8 deletions

View File

@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.nativeblas.NativeOpsHolder;
import sun.nio.ch.DirectBuffer;
import java.nio.ByteBuffer;
@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest {
//https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1);
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
System.out.println(bb.getClass());
System.out.println(bb.address());
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
BytePointer bp = new BytePointer(5);
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address());
DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8);
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8);
long before = arr2.data().pointer().address();
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
long after = arr2.data().pointer().address();

View File

@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Map;
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
public class UpdaterJavaCode {
private UpdaterJavaCode(){ }
@ -46,6 +48,14 @@ public class UpdaterJavaCode {
msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
}
public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){
state.addi(gradient.mul(gradient));
INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon);
// lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
gradient.muli(sqrtHistory.rdivi(learningRate));
}
public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
double epsilon, int iteration){

View File

@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest {
}
}
@Test
public void testAdaGradUpdater(){
double lr = 0.1;
double epsilon = 1e-6;
INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("grad", s.dup());
AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true);
assertEquals(s, state.get("grad"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
u.applyUpdater(g2, i, 0);
assertEquals(s, state.get("grad"));
assertEquals(g1, g2);
}
}
@Test
public void testAdamUpdater(){