AdaGrad validation test (#334)
Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									7a2ac800dd
								
							
						
					
					
						commit
						2497290cb0
					
				@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4jBackend;
 | 
			
		||||
import org.nd4j.nativeblas.NativeOpsHolder;
 | 
			
		||||
import sun.nio.ch.DirectBuffer;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import java.nio.ByteBuffer;
 | 
			
		||||
@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest {
 | 
			
		||||
        //https://github.com/eclipse/deeplearning4j/issues/8783
 | 
			
		||||
        Nd4j.create(1);
 | 
			
		||||
 | 
			
		||||
        DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5);
 | 
			
		||||
        System.out.println(bb.getClass());
 | 
			
		||||
        System.out.println(bb.address());
 | 
			
		||||
 | 
			
		||||
        Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address());
 | 
			
		||||
        DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE);
 | 
			
		||||
        BytePointer bp = new BytePointer(5);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE);
 | 
			
		||||
        Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address());
 | 
			
		||||
        DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8);
 | 
			
		||||
        long before = arr2.data().pointer().address();
 | 
			
		||||
        Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST);
 | 
			
		||||
        long after = arr2.data().pointer().address();
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
 | 
			
		||||
 | 
			
		||||
public class UpdaterJavaCode {
 | 
			
		||||
 | 
			
		||||
    private UpdaterJavaCode(){ }
 | 
			
		||||
@ -46,6 +48,14 @@ public class UpdaterJavaCode {
 | 
			
		||||
        msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){
 | 
			
		||||
        state.addi(gradient.mul(gradient));
 | 
			
		||||
 | 
			
		||||
        INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon);
 | 
			
		||||
        // lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
 | 
			
		||||
        gradient.muli(sqrtHistory.rdivi(learningRate));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
 | 
			
		||||
                                                         double epsilon, int iteration){
 | 
			
		||||
 | 
			
		||||
@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testAdaGradUpdater(){
 | 
			
		||||
        double lr = 0.1;
 | 
			
		||||
        double epsilon = 1e-6;
 | 
			
		||||
 | 
			
		||||
        INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5);
 | 
			
		||||
 | 
			
		||||
        Map<String,INDArray> state = new HashMap<>();
 | 
			
		||||
        state.put("grad", s.dup());
 | 
			
		||||
        AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true);
 | 
			
		||||
 | 
			
		||||
        assertEquals(s, state.get("grad"));
 | 
			
		||||
 | 
			
		||||
        for( int i=0; i<3; i++ ) {
 | 
			
		||||
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
 | 
			
		||||
            INDArray g2 = g1.dup();
 | 
			
		||||
 | 
			
		||||
            UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
 | 
			
		||||
 | 
			
		||||
            u.applyUpdater(g2, i, 0);
 | 
			
		||||
 | 
			
		||||
            assertEquals(s, state.get("grad"));
 | 
			
		||||
            assertEquals(g1, g2);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testAdamUpdater(){
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user