AdaGrad validation test (#334)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-20 17:25:46 +11:00 committed by GitHub
parent 7a2ac800dd
commit 2497290cb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 8 deletions

View File

@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import sun.nio.ch.DirectBuffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest {
//https://github.com/eclipse/deeplearning4j/issues/8783 //https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1); Nd4j.create(1);
DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5); BytePointer bp = new BytePointer(5);
System.out.println(bb.getClass());
System.out.println(bb.address());
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE); Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address());
DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8);
INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8);
long before = arr2.data().pointer().address(); long before = arr2.data().pointer().address();
Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST); Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
long after = arr2.data().pointer().address(); long after = arr2.data().pointer().address();

View File

@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Map; import java.util.Map;
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
public class UpdaterJavaCode { public class UpdaterJavaCode {
private UpdaterJavaCode(){ } private UpdaterJavaCode(){ }
@ -46,6 +48,14 @@ public class UpdaterJavaCode {
msdx.muli(rho).addi(update.mul(update).muli(1 - rho)); msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
} }
public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){
state.addi(gradient.mul(gradient));
INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon);
// lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
gradient.muli(sqrtHistory.rdivi(learningRate));
}
public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2, public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
double epsilon, int iteration){ double epsilon, int iteration){

View File

@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest {
} }
} }
@Test
public void testAdaGradUpdater(){
double lr = 0.1;
double epsilon = 1e-6;
INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("grad", s.dup());
AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true);
assertEquals(s, state.get("grad"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
u.applyUpdater(g2, i, 0);
assertEquals(s, state.get("grad"));
assertEquals(g1, g2);
}
}
@Test @Test
public void testAdamUpdater(){ public void testAdamUpdater(){