From f25e3e71e5fcfbb74f89005001db77c6ec2e6d59 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Thu, 5 Sep 2019 10:19:39 +0900 Subject: [PATCH 01/10] remove lengthLong (#236) Signed-off-by: Robert Altena --- .../solvers/accumulation/EncodingHandler.java | 10 +++---- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 11 -------- .../linalg/api/ndarray/BaseSparseNDArray.java | 5 ---- .../org/nd4j/linalg/api/ndarray/INDArray.java | 12 +-------- .../nd4j/linalg/jcublas/JCublasNDArray.java | 12 ++++----- .../linalg/jcublas/JCublasNDArrayFactory.java | 22 ++++++++-------- .../ops/executioner/CudaExecutioner.java | 26 +++++++++---------- .../ops/executioner/CudaGridExecutioner.java | 2 +- .../cpu/nativecpu/CpuNDArrayFactory.java | 10 +++---- .../nativecpu/ops/NativeOpExecutioner.java | 22 ++++++++-------- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 2 +- .../java/org/nd4j/linalg/ShufflesTests.java | 4 +-- .../linalg/compression/CompressionTests.java | 2 +- 13 files changed, 57 insertions(+), 83 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java index dc632cecd..24a46117a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java @@ -135,7 +135,7 @@ public class EncodingHandler implements MessageHandler { iterations.get().incrementAndGet(); if (boundary != null && atomicBoundary.get() < 0) - atomicBoundary.compareAndSet(-1, (int) (updates.lengthLong() * boundary)); + atomicBoundary.compareAndSet(-1, (int) (updates.length() * boundary)); INDArray encoded; @@ -160,11 +160,11 @@ public class EncodingHandler implements MessageHandler { double encLen = encoded.data().getInt(0); // if updates are too dense - we fallback to bitmap encoding - if (encLen >= (updates.lengthLong() / 16)) { + if (encLen >= (updates.length() / 16)) { log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", iteration, epoch, currThreshold, encLen); bitmapMode.get().set(true); - DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.lengthLong() / 16 + 5); + DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5); encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer()); Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get()); @@ -186,12 +186,12 @@ public class EncodingHandler implements MessageHandler { } } else { //Dense bitmap updates - DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.lengthLong() / 16 + 5); + DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5); encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer()); long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get()); - if (values < (updates.lengthLong() / 16 + 5) / 2) { + if (values < (updates.length() / 16 + 5) / 2) { boolean current = bitmapMode.get().get(); bitmapMode.get().set(false); if(!current) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java 
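The EncodingHandler hunks above are representative of the whole patch: every change is a one-for-one rename. As the BaseNDArray removal just below confirms, the deprecated lengthLong() had exactly the same body as length() (return jvmShapeInfo.length;) and the same long return type, so call sites migrate mechanically. A minimal sketch, assuming an arbitrary INDArray named arr:

    long before = arr.lengthLong(); // deprecated alias, removed by this patch
    long after  = arr.length();     // identical value, identical long return type

Surrounding integer arithmetic, such as the bitmap buffer sizing updates.length() / 16 + 5 above, is untouched by the rename.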
index 771b74615..126ba2466 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -4641,17 +4641,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return jvmShapeInfo.length; } - /** - * Returns the total number of elements in the ndarray - * - * @return the number of elements in the ndarray - */ - @Override - @Deprecated - public long lengthLong() { - return jvmShapeInfo.length; - } - @Override public INDArray broadcast(INDArray result) { Nd4j.getCompressor().autoDecompress(this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 11a005f91..c9c5cab37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -279,11 +279,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return (int) length(); } - @Override - public long lengthLong() { - return length; - } - protected void init(long[] shape) { if (shape.length == 1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 9288b6d51..f9f04cc43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2377,17 +2377,7 @@ public interface INDArray extends Serializable, AutoCloseable { * @return the number of elements in the ndarray */ long length(); - - /** - * Returns the total number of elements in the ndarray - * - * @return the number of elements in the ndarray - * @deprecated use {@link #length()} - */ - @Deprecated - long lengthLong(); - - + /** * Broadcasts this ndarray to be the specified shape * diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java index 170f22a52..eb0db01a3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java @@ -664,7 +664,7 @@ public class JCublasNDArray extends BaseNDArray { //if (1 < 0) { Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false); + DataBuffer buffer = Nd4j.createBuffer(this.length(), false); AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); @@ -686,10 +686,10 @@ public class JCublasNDArray extends BaseNDArray { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); if (pointSrc.isActualOnDeviceSide()) { - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.lengthLong() * 
Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); } else { - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); direction = MemcpyDirection.HOST_TO_DEVICE; @@ -738,7 +738,7 @@ public class JCublasNDArray extends BaseNDArray { if (!this.isView()) { Nd4j.getExecutioner().commit(); - val buffer = Nd4j.createBuffer(this.dataType(), this.lengthLong(), false); + val buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); @@ -749,10 +749,10 @@ public class JCublasNDArray extends BaseNDArray { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); if (pointSrc.isActualOnDeviceSide()) { - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); } else { - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); direction = MemcpyDirection.HOST_TO_DEVICE; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 73daa679d..daebc041e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -424,7 +424,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - nativeOps.memcpyAsync(point.getDevicePointer(), 
point.getHostPointer(), ret.lengthLong() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); + nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.length() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); context.getSpecialStream().synchronize(); if (nativeOps.lastErrorCode() != 0) @@ -580,7 +580,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (true) { Nd4j.getExecutioner().push(); - long len = target.lengthLong(); + long len = target.length(); AtomicAllocator allocator = AtomicAllocator.getInstance(); @@ -598,7 +598,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); @@ -621,7 +621,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { return target; } else { - long len = target.lengthLong(); + long len = target.length(); Nd4j.getExecutioner().commit(); @@ -637,7 +637,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); ((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer(); @@ -689,7 +689,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { Nd4j.getExecutioner().push(); - long len = target != null ? target.lengthLong() : arrays[0].lengthLong(); + long len = target != null ? target.length() : arrays[0].length(); AtomicAllocator allocator = AtomicAllocator.getInstance(); @@ -707,7 +707,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); @@ -744,7 +744,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { /** * We expect all operations are complete at this point */ - long len = target == null ? arrays[0].lengthLong() : target.lengthLong(); + long len = target == null ? 
arrays[0].length() : target.length(); val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); @@ -758,7 +758,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); ((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer(); @@ -1303,7 +1303,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { } - int numTads = (int)(tensor.lengthLong() / tadLength); + int numTads = (int)(tensor.length() / tadLength); INDArray[] result = new INDArray[numTads]; long[] xPointers = new long[numTads]; @@ -1378,7 +1378,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { new CudaPointer(0)); // we're sending > 10m elements to radixSort - boolean isRadix = !x.isView() && (x.lengthLong() > 1024 * 1024 * 10); + boolean isRadix = !x.isView() && (x.length() > 1024 * 1024 * 10); INDArray tmpX = x; // we need to guarantee all threads are finished here diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 43bbfbdca..6c95d3ce5 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -293,9 +293,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer yDevTadShapeInfo = null; if (op.y() != null) { - if (dimension.length == 0 || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE )|| op.x().tensorAlongDimension(0, dimension).lengthLong() != op.y().lengthLong()) { - if (!op.isComplexAccumulation() && op.x().lengthLong() != op.y().lengthLong()) - throw new ND4JIllegalStateException("Op.X [" + op.x().lengthLong() + "] and Op.Y [" + op.y().lengthLong() + "] lengths should match"); + if (dimension.length == 0 || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE )|| op.x().tensorAlongDimension(0, dimension).length() != op.y().length()) { + if (!op.isComplexAccumulation() && op.x().length() != op.y().length()) + throw new ND4JIllegalStateException("Op.X [" + op.x().length() + "] and Op.Y [" + op.y().length() + "] lengths should match"); if (!op.z().isScalar()) { Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension); @@ -536,7 +536,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { } else { if (op.y() != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if (op.x().lengthLong() == op.y().lengthLong()) { + if (op.x().length() == op.y().length()) { //Pairwise if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + @@ -548,11 +548,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)"); //Every X TAD vs. 
entirety of Y - val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension); + val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); if (xTADSize != op.y().length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong()); + " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); } } } @@ -976,7 +976,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (op.y() != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if (op.x().lengthLong() == op.y().lengthLong()) { + if (op.x().length() == op.y().length()) { //Pairwise if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + @@ -985,11 +985,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } else { //Every X TAD vs. entirety of Y - val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension); + val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); if (xTADSize != op.y().length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong()); + " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); } } } @@ -2031,8 +2031,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { long compressedLength = buffer.getInt(0); long originalLength = buffer.getInt(1); - if (target.lengthLong() != originalLength) - throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.lengthLong()+"]"); + if (target.length() != originalLength) + throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.length()+"]"); DataBuffer result = target.data(); @@ -2056,7 +2056,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public long bitmapEncode(INDArray indArray, INDArray target, double threshold) { - long length = indArray.lengthLong(); + long length = indArray.length(); long tLen = target.data().length(); if (tLen != (length / 16 + 5)) @@ -2117,7 +2117,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.getBufferScalar(), context.getBufferReduction()); - nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); + nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.length(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index ca8e4eb07..850096359 100644 --- 
a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -655,7 +655,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio op.setZ(ret); } else { // compare length - if (op.z().lengthLong() != ArrayUtil.prodLong(retShape)) + if (op.z().length() != ArrayUtil.prodLong(retShape)) throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); ret = op.z(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index cacf32b38..7cd4101ef 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -514,7 +514,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { - int numTads = (int)(tensor.lengthLong() / tadLength); + int numTads = (int)(tensor.length() / tadLength); INDArray[] result = new INDArray[numTads]; PointerPointer targets = new PointerPointer(numTads); @@ -693,7 +693,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays.length == 1) return target.addi(arrays[0]); - long len = target.lengthLong(); + long len = target.length(); PointerPointer dataPointers = new PointerPointer(arrays.length); @@ -703,7 +703,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native accumulation is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for accumulation"); dataPointers.put(i, arrays[i].data().addressPointer()); @@ -744,7 +744,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { return target.assign(arrays[0]); } - long len = target != null ? target.lengthLong() : arrays[0].length(); + long len = target != null ? 
target.length() : arrays[0].length(); PointerPointer dataPointers = new PointerPointer(arrays.length); val firstType = arrays[0].dataType(); @@ -757,7 +757,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { if (arrays[i].elementWiseStride() != 1) throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays"); - if (arrays[i].lengthLong() != len) + if (arrays[i].length() != len) throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); dataPointers.put(i, arrays[i].data().addressPointer()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index e79c21feb..663eb862e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -303,11 +303,11 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } else { //Every X TAD vs. entirety of Y - val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension); + val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); if (xTADSize != op.y().length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong()); + " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); } } } @@ -329,7 +329,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long xT = op.x().tensorsAlongDimension(dimension); long yT = op.y().tensorsAlongDimension(dimension); - if (op.z().lengthLong() != xT * yT) + if (op.z().length() != xT * yT) throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + (xT * yT) + "]"); } @@ -358,7 +358,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { // we're going to check, if that's TAD vs TAD comparison or TAD vs full array. 
if later - we're going slightly different route boolean tvf = false; if (op.y() != null) { - if (op.x().tensorAlongDimension(0, dimension).lengthLong() == op.y().lengthLong()) { + if (op.x().tensorAlongDimension(0, dimension).length() == op.y().length()) { tvf = true; } } @@ -366,10 +366,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (op.isComplexAccumulation()) { yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension); - if (op.x().tensorAlongDimension(0, dimension).lengthLong() != op.y().tensorAlongDimension(0, dimension).lengthLong()) + if (op.x().tensorAlongDimension(0, dimension).length() != op.y().tensorAlongDimension(0, dimension).length()) throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: " + - "x TAD length = " + op.x().tensorAlongDimension(0, dimension).lengthLong() + ", y TAD length " + - op.y().tensorAlongDimension(0, dimension).lengthLong()); + "x TAD length = " + op.x().tensorAlongDimension(0, dimension).length() + ", y TAD length " + + op.y().tensorAlongDimension(0, dimension).length()); } /** @@ -659,7 +659,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { //validateDataType(Nd4j.dataType(), op); - if (op.x().lengthLong() != op.z().lengthLong()) + if (op.x().length() != op.z().length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " + "x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "], z shape info = [" @@ -1449,8 +1449,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long originalLength = buffer.getInt(1); float threshold = buffer.getInt(2); - if (target.lengthLong() != originalLength) - throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.lengthLong()+"]"); + if (target.length() != originalLength) + throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.length()+"]"); DataTypeEx typeDst = AbstractCompressor.getBufferTypeEx(target.data()); @@ -1465,7 +1465,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public long bitmapEncode(INDArray indArray, INDArray target, double threshold) { - long length = indArray.lengthLong(); + long length = indArray.length(); long tLen = target.data().length(); if (tLen != (length / 16 + 5)) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 22b911468..8f0c4de35 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -5155,7 +5155,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray res = x.entropy(1); - assertEquals(10, res.lengthLong()); + assertEquals(10, res.length()); for (int t = 0; t < x.rows(); t++) { double exp = MathUtils.entropy(x.getRow(t).dup().data().asDouble()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java index 23c0ae134..f88e78408 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java @@ -415,7 +415,7 @@ public class ShufflesTests extends BaseNd4jTest { for (int x = 0; x < newData.rows(); x++) { INDArray row = newData.getRow(x); - for (int y = 0; y < row.lengthLong(); y++) { + for (int y = 0; y < row.length(); y++) { if (Math.abs(row.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) { System.out.print("Different data in a row"); return false; @@ -442,7 +442,7 @@ public class ShufflesTests extends BaseNd4jTest { for (int x = 0; x < newData.rows(); x++) { INDArray column = newData.getColumn(x); double val = column.getDouble(0); - for (int y = 0; y < column.lengthLong(); y++) { + for (int y = 0; y < column.length(); y++) { if (Math.abs(column.getFloat(y) - val) > Nd4j.EPS_THRESHOLD) { System.out.print("Different data in a column: " + column.getFloat(y)); return false; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java index 69ac8eacc..ec78c53bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java @@ -168,7 +168,7 @@ public class CompressionTests extends BaseNd4jTest { INDArray decompressed = Nd4j.create(1, initial.length()); Nd4j.getExecutioner().thresholdDecode(compressed, decompressed); - log.info("Decompressed length: {}", decompressed.lengthLong()); + log.info("Decompressed length: {}", decompressed.length()); assertEquals(exp_d, decompressed); } From 7d857759341a1fb5f511b394800c93c0529fb14f Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 11:51:11 +1000 Subject: [PATCH 02/10] Arbiter generic JSON ser/de fixes (#237) * Arbiter generic JSON ser/de fixes Signed-off-by: AlexDBlack * Javadoc fix Signed-off-by: AlexDBlack --- .../optimize/parameter/FixedValue.java | 13 +++-- .../serde/jackson/FixedValueDeserializer.java | 52 +++++++++++++++++++ .../serde/jackson/FixedValueSerializer.java | 51 ++++++++++++++++++ .../serde/jackson/GenericDeserializer.java | 46 ---------------- .../serde/jackson/GenericSerializer.java | 38 -------------- .../optimize/serde/jackson/JsonMapper.java | 4 +- .../arbiter/layers/BaseOutputLayerSpace.java | 2 +- .../ui/data/GlobalConfigPersistable.java | 11 +++- .../ui/listener/ArbiterStatusListener.java | 9 +++- .../linalg/lossfunctions/impl/LossL2.java | 3 +- .../linalg/lossfunctions/impl/LossMSE.java | 3 +- 11 files changed, 133 insertions(+), 99 deletions(-) create mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java create mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java index 0be9613de..6482003e5 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java +++ 
b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java @@ -17,13 +17,14 @@ package org.deeplearning4j.arbiter.optimize.parameter; import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer; -import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer; +import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer; +import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer; import org.deeplearning4j.arbiter.util.ObjectUtils; import org.nd4j.shade.jackson.annotation.JsonCreator; -import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonProperty; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; @@ -37,9 +38,11 @@ import java.util.Map; * @param Type of (fixed) value */ @EqualsAndHashCode +@JsonSerialize(using = FixedValueSerializer.class) +@JsonDeserialize(using = FixedValueDeserializer.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public class FixedValue implements ParameterSpace { - @JsonSerialize(using = GenericSerializer.class) - @JsonDeserialize(using = GenericDeserializer.class) + @Getter private Object value; private int index; diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java new file mode 100644 index 000000000..24b76fd42 --- /dev/null +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java @@ -0,0 +1,52 @@ +package org.deeplearning4j.arbiter.optimize.serde.jackson; + +import org.apache.commons.codec.binary.Base64; +import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; +import org.nd4j.shade.jackson.core.JsonParser; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.DeserializationContext; +import org.nd4j.shade.jackson.databind.JsonDeserializer; +import org.nd4j.shade.jackson.databind.JsonNode; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * A custom deserializer to be used in conjunction with {@link FixedValueSerializer} + * @author Alex Black + */ +public class FixedValueDeserializer extends JsonDeserializer { + @Override + public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + JsonNode node = p.getCodec().readTree(p); + String className = node.get("@valueclass").asText(); + Class c; + try { + c = Class.forName(className); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if(node.has("value")){ + //Number, String, Enum + JsonNode valueNode = node.get("value"); + Object o = new ObjectMapper().treeToValue(valueNode, c); + return new FixedValue<>(o); + } else { + //Everything else + JsonNode valueNode = node.get("data"); + String data = valueNode.asText(); + + byte[] b = new Base64().decode(data); + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b)); + try 
{ + Object o = ois.readObject(); + return new FixedValue<>(o); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + } +} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java new file mode 100644 index 000000000..349177595 --- /dev/null +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java @@ -0,0 +1,51 @@ +package org.deeplearning4j.arbiter.optimize.serde.jackson; + +import org.apache.commons.net.util.Base64; +import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; +import org.nd4j.shade.jackson.core.JsonGenerator; +import org.nd4j.shade.jackson.core.type.WritableTypeId; +import org.nd4j.shade.jackson.databind.JsonSerializer; +import org.nd4j.shade.jackson.databind.SerializerProvider; +import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; + +import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT; + +/** + * A custom serializer to handle arbitrary object types + * Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64) + * The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary + * objects very well + * + * @author Alex Black + */ +public class FixedValueSerializer extends JsonSerializer { + @Override + public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException { + Object o = fixedValue.getValue(); + + j.writeStringField("@valueclass", o.getClass().getName()); + if(o instanceof Number || o instanceof String || o instanceof Enum){ + j.writeObjectField("value", o); + } else { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(o); + baos.close(); + byte[] b = baos.toByteArray(); + String base64 = new Base64().encodeToString(b); + j.writeStringField("data", base64); + } + } + + @Override + public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException { + WritableTypeId typeId = typeSer.typeId(value, START_OBJECT); + typeSer.writeTypePrefix(gen, typeId); + serialize(value, gen, serializers); + typeSer.writeTypeSuffix(gen, typeId); + } +} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java deleted file mode 100644 index de35dba18..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.io.IOException; - -/** - * Created by Alex on 15/02/2017. - */ -public class GenericDeserializer extends JsonDeserializer { - @Override - public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - JsonNode node = p.getCodec().readTree(p); - String className = node.get("@class").asText(); - Class c; - try { - c = Class.forName(className); - } catch (Exception e) { - throw new RuntimeException(e); - } - - JsonNode valueNode = node.get("value"); - Object o = new ObjectMapper().treeToValue(valueNode, c); - return o; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java deleted file mode 100644 index 035ac7c50..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.core.JsonProcessingException; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * Created by Alex on 15/02/2017. 
- */ -public class GenericSerializer extends JsonSerializer { - @Override - public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider) - throws IOException, JsonProcessingException { - j.writeStartObject(); - j.writeStringField("@class", o.getClass().getName()); - j.writeObjectField("value", o); - j.writeEndObject(); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java index 8cfb07723..f30cab109 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java @@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.datatype.joda.JodaModule; -import java.util.Collections; -import java.util.Map; - /** * Created by Alex on 16/11/2016. */ @@ -44,6 +41,7 @@ public class JsonMapper { mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); + mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); yamlMapper = new ObjectMapper(new YAMLFactory()); yamlMapper.registerModule(new JodaModule()); yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java index 3a72156e9..857f729ad 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; */ @Data @EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization +@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization public abstract class BaseOutputLayerSpace extends FeedForwardLayerSpace { protected ParameterSpace lossFunction; diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java index d11c251e3..00a95a628 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java @@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data; import lombok.Getter; import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.ui.misc.JsonMapper; +import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; import org.deeplearning4j.arbiter.ui.module.ArbiterModule; +import org.deeplearning4j.nn.conf.serde.JsonMappers; + +import java.io.IOException; /** * @@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable { public OptimizationConfiguration 
getOptimizationConfiguration(){ - return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class); + try { + return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class); + } catch (IOException e){ + throw new RuntimeException(e); + } } public int getCandidatesQueued(){ diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java index f323f8a7a..7802762c9 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java @@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; +import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; -import org.deeplearning4j.arbiter.ui.misc.JsonMapper; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.primitives.Pair; +import java.io.IOException; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener { // } //TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions) - ocJson = JsonMapper.asJson(r.getConfiguration()); + try { + ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration()); + } catch (IOException e){ + throw new RuntimeException(e); + } GlobalConfigPersistable p = new GlobalConfigPersistable.Builder() .sessionId(sessionId) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java index e9b30328c..a9f3b2c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.shade.jackson.annotation.JsonInclude; +import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; @@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction { * * @param weights Weights array (row vector). May be null. 
*/ - public LossL2(INDArray weights) { + public LossL2(@JsonProperty("weights") INDArray weights) { if (weights != null && !weights.isRowVector()) { throw new IllegalArgumentException("Weights array must be a row vector"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java index 4fc5a7eec..bb64bb777 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.shade.jackson.annotation.JsonProperty; /** * Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2 @@ -38,7 +39,7 @@ public class LossMSE extends LossL2 { * * @param weights Weights array (row vector). May be null. */ - public LossMSE(INDArray weights) { + public LossMSE(@JsonProperty("weights") INDArray weights) { super(weights); } From 79867f5c5a0c7e44ea4a31e0b7f8fc0835cbd262 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 4 Sep 2019 19:25:03 -0700 Subject: [PATCH 03/10] cleanup SDRNN and rnn ops (#238) Signed-off-by: Ryan Nett --- .../declarable/generic/recurrent/sruCell.cpp | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 18 ++ .../org/nd4j/autodiff/samediff/ops/SDRNN.java | 205 ++++++++++++------ .../ops/impl/layers/recurrent/GRUCell.java | 11 +- .../impl/layers/recurrent/LSTMBlockCell.java | 29 ++- .../ops/impl/layers/recurrent/LSTMLayer.java | 16 +- .../api/ops/impl/layers/recurrent/SRU.java | 22 +- .../ops/impl/layers/recurrent/SRUCell.java | 16 +- .../config/LSTMBlockCellConfiguration.java | 57 ----- .../recurrent/config/LSTMConfiguration.java | 38 +++- .../config/SRUCellConfiguration.java | 44 ---- .../recurrent/config/SRUConfiguration.java | 38 ---- .../recurrent/outputs/GRUCellOutputs.java | 62 ++++++ .../recurrent/outputs/LSTMCellOutputs.java | 88 ++++++++ .../recurrent/outputs/LSTMLayerOutputs.java | 180 +++++++++++++++ .../recurrent/outputs/SRUCellOutputs.java | 60 +++++ .../recurrent/outputs/SRULayerOutputs.java | 92 ++++++++ .../layers/recurrent/weights/GRUWeights.java | 51 +++++ .../layers/recurrent/weights/LSTMWeights.java | 57 +++++ .../layers/recurrent/weights/RNNWeights.java | 35 +++ .../layers/recurrent/weights/SRUWeights.java | 37 ++++ .../opvalidation/RnnOpValidation.java | 55 ++--- .../java/org/nd4j/linalg/primitives/Pair.java | 7 + 23 files changed, 943 insertions(+), 277 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp index 2961e3bcf..23b2ec172 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp @@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) { auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize] + auto b = INPUT_VARIABLE(3); // biases [2*inSize] auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 0b5a4c03f..1821a30a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps { public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } + + /** + * Returns an unused variable name of the format <base>_#. + * + * Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}. 
+ */ + public String generateDistinctCustomVariableName(String base){ + if(!variables.containsKey(base)) + return base; + + int inc = 1; + + while(variables.containsKey(base + "_" + inc)){ + inc++; + } + + return base + "_" + inc; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index f47f32b87..de0114b92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.ops; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.layers.recurrent.*; @@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*; import java.util.Arrays; import java.util.List; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; +import org.nd4j.linalg.primitives.Pair; /** * SameDiff Recurrent Neural Network operations
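The hunk below is the core of this patch: the configuration-object entry points (GRUCellConfiguration, LSTMBlockCellConfiguration, SRUConfiguration and friends) are replaced by explicit (input, state, weights, config) signatures that return typed output holders instead of raw SDVariable arrays. A sketch of the new call shape, assuming sd is a SameDiff instance, x/hLast/cLast/yLast are SDVariables with the shapes given in the javadoc below, and the GRUWeights/LSTMWeights/LSTMConfiguration instances are pre-built (those classes are added elsewhere in this patch; their builders are not shown in this excerpt):

    // single GRU step: x is [batchSize, inSize], hLast is [batchSize, numUnits]
    GRUCellOutputs gruOut = sd.rnn().gru("gru", x, hLast, gruWeights);

    // single LSTM step: cLast and yLast carry the previous cell state and output
    LSTMCellOutputs lstmOut = sd.rnn().lstmCell("lstm", x, cLast, yLast, lstmWeights, lstmConfig);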
@@ -39,90 +49,163 @@ public class SDRNN extends SDOps { /** - * The gru cell - * - * @param configuration the configuration to use - * @return + * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}. */ - public List gru(GRUCellConfiguration configuration) { - GRUCell c = new GRUCell(sd, configuration); - return Arrays.asList(c.outputVariables()); + public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { + GRUCell c = new GRUCell(sd, x, hLast, weights); + return new GRUCellOutputs(c.outputVariables()); } /** - * The gru cell + * The GRU cell. Does a single time step operation. * - * @param baseName the base name for the gru cell - * @param configuration the configuration to use - * @return + * @param baseName The base name for the gru cell + * @param x Input, with shape [batchSize, inSize] + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] + * @param weights The cell's weights. + * @return The cell's outputs. */ - public List gru(String baseName, GRUCellConfiguration configuration) { - GRUCell c = new GRUCell(sd, configuration); - return Arrays.asList(c.outputVariables(baseName)); - } - - - /** - * LSTM unit - * - * @param baseName the base name for outputs - * @param configuration the configuration to use - * @return - */ - public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) { - return new LSTMCell(sd, configuration).outputVariables(baseName)[0]; - } - - public List lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){ - SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name); - return Arrays.asList(v); - } - - public List lstmLayer(String name, LSTMConfiguration configuration){ - SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name); - return Arrays.asList(v); + public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { + GRUCell c = new GRUCell(sd, x, hLast, weights); + return new GRUCellOutputs(c.outputVariables(baseName)); } /** - * Simple recurrent unit - * - * @param configuration the configuration for the sru - * @return + * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}. */ - public SDVariable sru(SRUConfiguration configuration) { - return new SRU(sd, configuration).outputVariables()[0]; + public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + LSTMWeights weights, LSTMConfiguration config){ + LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); + return new LSTMCellOutputs(c.outputVariables()); } /** - * Simiple recurrent unit + * The LSTM cell. Does a single time step operation. * - * @param baseName the base name to use for output variables - * @param configuration the configuration for the sru - * @return + * @param baseName The base name for the lstm cell + * @param x Input, with shape [batchSize, inSize] + * @param cLast Previous cell state, with shape [batchSize, numUnits] + * @param yLast Previous cell output, with shape [batchSize, numUnits] + * @param weights The cell's weights. + * @param config The cell's config. + * @return The cell's outputs. 
*/ - public SDVariable sru(String baseName, SRUConfiguration configuration) { - return new SRU(sd, configuration).outputVariables(baseName)[0]; + public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); + return new LSTMCellOutputs(c.outputVariables(baseName)); } /** - * An sru cell - * - * @param configuration the configuration for the sru cell - * @return + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} */ - public SDVariable sruCell(SRUCellConfiguration configuration) { - return new SRUCell(sd, configuration).outputVariables()[0]; + public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, + @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); + return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat()); } /** - * An sru cell - * - * @param baseName the base name to use for the output variables - * @param configuration the configuration for the sru cell - * @return + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} */ - public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) { - return new SRUCell(sd, configuration).outputVariables(baseName)[0]; + public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + return lstmLayer( + sd.scalar("lstm_max_ts_length", maxTSLength), + x, cLast, yLast, weights, config); + } + + /** + * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} + */ + public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, + @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ + if(baseName != null) { + return lstmLayer(baseName, + sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength), + x, cLast, yLast, weights, config); + } else { + return lstmLayer(maxTSLength, x, cLast, yLast, weights, config); + } + } + + /** + * The LSTM layer. Does multiple time steps. + * + * Input shape depends on data format (in config):
+ * TNS -> [timeSteps, batchSize, inSize]
+ * NST -> [batchSize, inSize, timeSteps]
+ * NTS -> [batchSize, timeSteps, inSize]
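+     *
+     * For example (illustrative shapes only): with batchSize = 2, timeSteps = 5 and inSize = 3,
+     * a TNS input has shape [5, 2, 3], an NST input [2, 3, 5], and an NTS input [2, 5, 3].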
+     *
+     * @param baseName    The base name for the lstm layer
+     * @param maxTSLength Maximum time series length to process (the max-timestep input of the underlying op)
+     * @param x           Input, with shape dependent on the data format (in config).
+     * @param cLast       Previous/initial cell state, with shape [batchSize, numUnits]
+     * @param yLast       Previous/initial cell output, with shape [batchSize, numUnits]
+     * @param weights     The layer's weights.
+     * @param config      The layer's config.
+     * @return            The layer's outputs.
+     */
+    public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
+            @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
+            @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
+        LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
+        return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
+    }
+
+    /**
+     * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
+     */
+    public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
+        return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
+    }
+
+    /**
+     * The SRU cell. Does a single time step operation.
+     *
+     * @param baseName The base name for the sru cell
+     * @param x        Input, with shape [batchSize, inSize]
+     * @param cLast    Previous cell state, with shape [batchSize, inSize]
+     * @param weights  The cell's weights.
+     * @return         The cell's outputs.
+     */
+    public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
+        return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
+    }
+
+    /**
+     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}.
+     */
+    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
+        return sru(x, initialC, null, weights);
+    }
+
+    /**
+     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}.
+     */
+    public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
+        return sru(baseName, x, initialC, null, weights);
+    }
+
+    /**
+     * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}.
+     */
+    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
+        return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
+    }
+
+    /**
+     * The SRU layer. Does multiple time steps.
+     *
+     * @param baseName The base name for the sru layer
+     * @param x        Input, with shape [batchSize, inSize, timeSeriesLength]
+     * @param initialC Initial cell state, with shape [batchSize, inSize]
+     * @param mask     An optional dropout mask, with shape [batchSize, inSize]
+     * @param weights  The layer's weights.
+     * @return         The layer's outputs.
+ */ + public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { + return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java index 6c7daca69..2fa99ace5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -39,14 +41,15 @@ import java.util.Map; */ public class GRUCell extends DynamicCustomOp { - private GRUCellConfiguration configuration; + @Getter + private GRUWeights weights; public GRUCell() { } - public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) { + super(null, sameDiff, weights.argsWithInputs(x, hLast)); + this.weights = weights; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java index 36512f610..f88625987 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMBlockCell.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; +import org.nd4j.linalg.primitives.Pair; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -49,10 +52,12 @@ import java.util.Map; * 6: weights - cell peephole (t) connections to output gate, [numUnits]
* 7: biases, shape [4*numUnits]
*
- * Input integer arguments: set via {@link LSTMBlockCellConfiguration}
+ * Weights are set via {@link LSTMWeights}.
+ *
+ * Input integer arguments: set via {@link LSTMConfiguration}
* 0: if not zero, provide peephole connections
*
- * Input float arguments: set via {@link LSTMBlockCellConfiguration}
+ * Input float arguments: set via {@link LSTMConfiguration}
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
*
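A minimal sketch of the call pattern this refactor enables, mirroring the updated RnnOpValidation test further down in this patch. The sizes (batchSize = 2, inSize = 3, numUnits = 4) and the random/zero initializers are illustrative only:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
    import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
    import org.nd4j.linalg.factory.Nd4j;

    SameDiff sd = SameDiff.create();
    SDVariable x     = sd.constant(Nd4j.rand(DataType.FLOAT, 2, 3));   // [batchSize, inSize]
    SDVariable cLast = sd.constant(Nd4j.zeros(DataType.FLOAT, 2, 4));  // [batchSize, numUnits]
    SDVariable yLast = sd.constant(Nd4j.zeros(DataType.FLOAT, 2, 4));  // [batchSize, numUnits]

    // The input/weight/peephole/bias variables move from LSTMBlockCellConfiguration to LSTMWeights:
    LSTMWeights weights = LSTMWeights.builder()
            .weights(sd.constant(Nd4j.rand(DataType.FLOAT, 7, 16)))    // [inSize + numUnits, 4*numUnits]
            .bias(sd.constant(Nd4j.rand(DataType.FLOAT, 16)))          // [4*numUnits]
            .inputPeepholeWeights(sd.constant(Nd4j.rand(DataType.FLOAT, 4)))
            .forgetPeepholeWeights(sd.constant(Nd4j.rand(DataType.FLOAT, 4)))
            .outputPeepholeWeights(sd.constant(Nd4j.rand(DataType.FLOAT, 4)))
            .build();

    // The scalar options stay on the now-shared LSTMConfiguration:
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(true)
            .forgetBias(1.0)
            .clippingCellValue(0.0)
            .build();

    LSTMCellOutputs out = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);
    SDVariable y = out.getOutput();   // [batchSize, numUnits]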
@@ -69,15 +74,19 @@ import java.util.Map; */ public class LSTMBlockCell extends DynamicCustomOp { - private LSTMBlockCellConfiguration configuration; + private LSTMConfiguration configuration; + + @Getter + private LSTMWeights weights; public LSTMBlockCell() { } - public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); + public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast)); this.configuration = configuration; - addIArgument(configuration.iArgs()); + this.weights = weights; + addIArgument(configuration.iArgs(false)); addTArgument(configuration.tArgs()); } @@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - configuration = LSTMBlockCellConfiguration.builder() + configuration = LSTMConfiguration.builder() .forgetBias(attributesForNode.get("forget_bias").getF()) .clippingCellValue(attributesForNode.get("cell_clip").getF()) .peepHole(attributesForNode.get("use_peephole").getB()) .build(); - addIArgument(configuration.iArgs()); + addIArgument(configuration.iArgs(false)); addTArgument(configuration.tArgs()); } @@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp { @Override public Map propertiesForFunction() { - return configuration.toProperties(); + return configuration.toProperties(false); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index 527c0c3ca..1e1ae3c47 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import lombok.Getter; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp { private LSTMConfiguration configuration; + @Getter + private LSTMWeights weights; + public LSTMLayer() { } - public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) { - super(null, sameDiff, configuration.args()); + public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) { + super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast)); this.configuration = configuration; - addIArgument(configuration.iArgs()); + this.weights = weights; + addIArgument(configuration.iArgs(true)); 
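+        // iArgs(true): the layer op also consumes the RnnDataFormat ordinal as an integer arg;
+        // LSTMBlockCell above passes iArgs(false) since the cell op has no data format.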
addTArgument(configuration.tArgs()); } @@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp { .peepHole(attributesForNode.get("use_peephole").getB()) .dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM .build(); - addIArgument(configuration.iArgs()); + addIArgument(configuration.iArgs(true)); addTArgument(configuration.tArgs()); } @@ -118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp { @Override public Map propertiesForFunction() { - return configuration.toProperties(); + return configuration.toProperties(true); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java index b916d4961..a2de2beb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java @@ -16,11 +16,16 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -34,13 +39,18 @@ import java.util.Map; */ public class SRU extends DynamicCustomOp { - private SRUConfiguration configuration; + @Getter + private SRUWeights weights; + + @Getter + private SDVariable mask; public SRU() { } - public SRU(SameDiff sameDiff, SRUConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { + super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask)); + this.mask = mask; + this.weights = weights; } @Override @@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java index 4880b90fe..ac3f6c07f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java @@ -16,17 +16,18 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; +import java.util.Map; +import lombok.Getter; import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import 
org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.Map; - /** * A simple recurrent unit cell. * @@ -34,14 +35,15 @@ import java.util.Map; */ public class SRUCell extends DynamicCustomOp { - private SRUCellConfiguration configuration; + @Getter + private SRUWeights weights; public SRUCell() { } - public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) { - super(null, sameDiff, configuration.args()); - this.configuration = configuration; + public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) { + super(null, sameDiff, weights.argsWithInputs(x, cLast)); + this.weights = weights; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java deleted file mode 100644 index 3e2591062..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMBlockCellConfiguration.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.linalg.util.ArrayUtil; - -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder -@Data -public class LSTMBlockCellConfiguration { - - private boolean peepHole; //IArg(0) - private double forgetBias; //TArg(0) - private double clippingCellValue; //TArg(1) - - private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b; - - public Map toProperties() { - Map ret = new LinkedHashMap<>(); - ret.put("peepHole",peepHole); - ret.put("clippingCellValue",clippingCellValue); - ret.put("forgetBias",forgetBias); - return ret; - } - - public SDVariable[] args() { - return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b}; - } - - - public int[] iArgs() { - return new int[] {ArrayUtil.fromBoolean(peepHole)}; - } - - public double[] tArgs() { - return new double[] {forgetBias,clippingCellValue}; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java index 98dc58876..4cf807765 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMConfiguration.java @@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; import lombok.Builder; import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.util.ArrayUtil; import java.util.LinkedHashMap; import java.util.Map; /** - * LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer} + * LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell} * * @author Alex Black */ @@ -33,29 +35,41 @@ import java.util.Map; @Data public class LSTMConfiguration { + /** + * Whether to provide peephole connections. + */ private boolean peepHole; //IArg(0) - @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) + + /** + * The data format of the input. Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}. + */ + @Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) (only for lstmBlock, not lstmBlockCell) + + /** + * The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training. + */ private double forgetBias; //TArg(0) + + /** + * Clipping value for cell state, if it is not equal to zero, then cell state is clipped. 
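+     * For example, a value of 3.0 would clip the cell state to [-3.0, 3.0] (assuming the
+     * underlying op clips symmetrically), while 0.0 disables clipping.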
+ */ private double clippingCellValue; //TArg(1) - private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b; - - public Map toProperties() { + public Map toProperties(boolean includeDataFormat) { Map ret = new LinkedHashMap<>(); ret.put("peepHole",peepHole); ret.put("clippingCellValue",clippingCellValue); ret.put("forgetBias",forgetBias); - ret.put("dataFormat", dataFormat); + if(includeDataFormat) + ret.put("dataFormat", dataFormat); return ret; } - public SDVariable[] args() { - return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b}; - } - - public int[] iArgs() { - return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()}; + public int[] iArgs(boolean includeDataFormat) { + if(includeDataFormat) { + return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()}; + } else return new int[]{ArrayUtil.fromBoolean(peepHole)}; } public double[] tArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java deleted file mode 100644 index 4b0a39a80..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUCellConfiguration.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; - -@Data -@Builder -public class SRUCellConfiguration { - /** - * - NDArray* xt = INPUT_VARIABLE(0); // input [batchSize x inSize], batchSize - batch size, inSize - number of features - NDArray* ct_1 = INPUT_VARIABLE(1); // previous cell state ct [batchSize x inSize], that is at previous time step t-1 - NDArray* w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - NDArray* b = INPUT_VARIABLE(3); // biases [1 x 2*inSize] - - NDArray* ht = OUTPUT_VARIABLE(0); // current cell output [batchSize x inSize], that is at current time step t - NDArray* ct = OUTPUT_VARIABLE(1); // current cell state [batchSize x inSize], that is at current time step t - - */ - private SDVariable xt,ct_1,w,b,h1,ct; - - - public SDVariable[] args() { - return new SDVariable[] {xt,ct_1,w,b,h1,ct}; - } - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java deleted file mode 100644 index 8bfa90330..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/SRUConfiguration.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; - -@Data -@Builder -public class SRUConfiguration { - /** - * NDArray* input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features - NDArray* weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] - NDArray* bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] - NDArray* init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 - - */ - private SDVariable inputs,weights,bias,init; - - public SDVariable[] args() { - return new SDVariable[] {inputs,weights,bias,init}; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java new file mode 100644 index 000000000..a39a5bcc7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java @@ -0,0 +1,62 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; + +/** + * The outputs of a GRU cell ({@link GRUCell}. + */ +@Getter +public class GRUCellOutputs { + + /** + * Reset gate output [batchSize, numUnits]. + */ + private SDVariable r; + + /** + * Update gate output [batchSize, numUnits]. + */ + private SDVariable u; + + /** + * Cell gate output [batchSize, numUnits]. + */ + private SDVariable c; + + /** + * Current cell output [batchSize, numUnits]. + */ + private SDVariable h; + + public GRUCellOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 4, + "Must have 4 GRU cell outputs, got %s", outputs.length); + + r = outputs[0]; + u = outputs[1]; + c = outputs[2]; + h = outputs[3]; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(r, u, c, h); + } + + /** + * Get h, the output of the cell. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getOutput(){ + return h; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java new file mode 100644 index 000000000..4fec87e8b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java @@ -0,0 +1,88 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; + +/** + * The outputs of a LSTM cell ({@link LSTMBlockCell}. 
+ */ +@Getter +public class LSTMCellOutputs { + + /** + * Output - input modulation gate activations [batchSize, numUnits]. + */ + private SDVariable i; + + /** + * Activations, cell state (pre tanh) [batchSize, numUnits]. + */ + private SDVariable c; + + /** + * Output - forget gate activations [batchSize, numUnits]. + */ + private SDVariable f; + + /** + * Output - output gate activations [batchSize, numUnits]. + */ + private SDVariable o; + + /** + * Output - input gate activations [batchSize, numUnits]. + */ + private SDVariable z; + + /** + * Cell state, post tanh [batchSize, numUnits]. + */ + private SDVariable h; + + /** + * Current cell output [batchSize, numUnits]. + */ + private SDVariable y; + + public LSTMCellOutputs(SDVariable[] outputs){ + Preconditions.checkArgument(outputs.length == 7, + "Must have 7 LSTM cell outputs, got %s", outputs.length); + + i = outputs[0]; + c = outputs[1]; + f = outputs[2]; + o = outputs[3]; + z = outputs[4]; + h = outputs[5]; + y = outputs[6]; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(i, c, f, o, z, h, y); + } + + /** + * Get y, the output of the cell. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getOutput(){ + return y; + } + + /** + * Get c, the cell's state. + * + * Has shape [batchSize, numUnits]. + */ + public SDVariable getState(){ + return c; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java new file mode 100644 index 000000000..a01be219f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -0,0 +1,180 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; + +import java.util.Arrays; +import java.util.List; +import lombok.AccessLevel; +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; + +/** + * The outputs of a LSTM layer ({@link LSTMLayer}. + */ +@Getter +public class LSTMLayerOutputs { + + private RnnDataFormat dataFormat; + + /** + * Output - input modulation gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable i; + + /** + * Activations, cell state (pre tanh). + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable c; + + /** + * Output - forget gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable f; + + /** + * Output - output gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable o; + + /** + * Output - input gate activations. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable z; + + /** + * Cell state, post tanh. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable h; + + /** + * Current cell output. + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + private SDVariable y; + + public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){ + Preconditions.checkArgument(outputs.length == 7, + "Must have 7 LSTM layer outputs, got %s", outputs.length); + + i = outputs[0]; + c = outputs[1]; + f = outputs[2]; + o = outputs[3]; + z = outputs[4]; + h = outputs[5]; + y = outputs[6]; + this.dataFormat = dataFormat; + } + + /** + * Get all outputs returned by the cell. + */ + public List getAllOutputs(){ + return Arrays.asList(i, c, f, o, z, h, y); + } + + /** + * Get y, the output of the cell for all time steps. + * + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+ */ + public SDVariable getOutput(){ + return y; + } + + /** + * Get c, the cell's state for all time steps. + * + * Shape depends on data format (in layer config):
+ * TNS -> [timeSteps, batchSize, numUnits]
+ * NST -> [batchSize, numUnits, timeSteps]
+ * NTS -> [batchSize, timeSteps, numUnits]
+     */
+    public SDVariable getState(){
+        return c;
+    }
+
+    private SDVariable lastOutput = null;
+
+    /**
+     * Get y, the output of the cell, for the last time step.
+     *
+     * Has shape [batchSize, numUnits].
+     */
+    public SDVariable getLastOutput(){
+        if(lastOutput != null)
+            return lastOutput;
+
+        switch (dataFormat){
+            case TNS:
+                lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
+                break;
+            case NST:
+                lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+                break;
+            case NTS:
+                lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
+                break;
+        }
+        return lastOutput;
+    }
+
+    private SDVariable lastState = null;
+
+    /**
+     * Get c, the state of the cell, for the last time step.
+     *
+     * Has shape [batchSize, numUnits].
+     */
+    public SDVariable getLastState(){
+        if(lastState != null)
+            return lastState;
+
+        switch (dataFormat){
+            case TNS:
+                lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
+                break;
+            case NST:
+                lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+                break;
+            case NTS:
+                lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
+                break;
+        }
+        return lastState;
+    }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java
new file mode 100644
index 000000000..d82ad63b1
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java
@@ -0,0 +1,60 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
+
+/**
+ * The outputs of an SRU cell ({@link SRUCell}).
+ */
+@Getter
+public class SRUCellOutputs {
+
+    /**
+     * Current cell output [batchSize, inSize].
+     */
+    private SDVariable h;
+
+    /**
+     * Current cell state [batchSize, inSize].
+     */
+    private SDVariable c;
+
+    public SRUCellOutputs(SDVariable[] outputs){
+        Preconditions.checkArgument(outputs.length == 2,
+                "Must have 2 SRU cell outputs, got %s", outputs.length);
+
+        h = outputs[0];
+        c = outputs[1];
+    }
+
+    /**
+     * Get all outputs returned by the cell.
+     */
+    public List<SDVariable> getAllOutputs(){
+        return Arrays.asList(h, c);
+    }
+
+    /**
+     * Get h, the output of the cell.
+     *
+     * Has shape [batchSize, inSize].
+     */
+    public SDVariable getOutput(){
+        return h;
+    }
+
+    /**
+     * Get c, the state of the cell.
+     *
+     * Has shape [batchSize, inSize].
+     */
+    public SDVariable getState(){
+        return c;
+    }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java
new file mode 100644
index 000000000..281d2cc10
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java
@@ -0,0 +1,92 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
+
+import java.util.Arrays;
+import java.util.List;
+import lombok.AccessLevel;
+import lombok.Getter;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
+
+/**
+ * The outputs of an SRU layer ({@link SRU}).
+ */
+@Getter
+public class SRULayerOutputs {
+
+    /**
+     * Current cell output [batchSize, inSize, timeSeriesLength].
+     */
+    private SDVariable h;
+
+    /**
+     * Current cell state [batchSize, inSize, timeSeriesLength].
+     */
+    private SDVariable c;
+
+    public SRULayerOutputs(SDVariable[] outputs){
+        Preconditions.checkArgument(outputs.length == 2,
+                "Must have 2 SRU layer outputs, got %s", outputs.length);
+
+        h = outputs[0];
+        c = outputs[1];
+    }
+
+    /**
+     * Get all outputs returned by the layer.
+     */
+    public List<SDVariable> getAllOutputs(){
+        return Arrays.asList(h, c);
+    }
+
+    /**
+     * Get h, the output of the layer.
+     *
+     * Has shape [batchSize, inSize, timeSeriesLength].
+     */
+    public SDVariable getOutput(){
+        return h;
+    }
+
+    /**
+     * Get c, the state of the layer.
+     *
+     * Has shape [batchSize, inSize, timeSeriesLength].
+     */
+    public SDVariable getState(){
+        return c;
+    }
+
+    private SDVariable lastOutput = null;
+
+    /**
+     * Get y, the output of the layer, for the last time step.
+     *
+     * Has shape [batchSize, inSize].
+     */
+    public SDVariable getLastOutput(){
+        if(lastOutput != null)
+            return lastOutput;
+
+        lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+        return lastOutput;
+    }
+
+    private SDVariable lastState = null;
+
+    /**
+     * Get c, the state of the layer, for the last time step.
+     *
+     * Has shape [batchSize, inSize].
+     */
+    public SDVariable getLastState(){
+        if(lastState != null)
+            return lastState;
+
+        lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+        return lastState;
+    }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java
new file mode 100644
index 000000000..f95438ae3
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/GRUWeights.java
@@ -0,0 +1,51 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
+
+/**
+ * The weight configuration of a GRU cell. For {@link GRUCell}.
+ *
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+@Builder
+public class GRUWeights extends RNNWeights {
+
+    /**
+     * Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits].
+     *
+     * The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
+     */
+    @NonNull
+    private SDVariable ruWeight;
+
+    /**
+     * Cell gate weights, with a shape of [inSize + numUnits, numUnits].
+     */
+    @NonNull
+    private SDVariable cWeight;
+
+    /**
+     * Reset and Update gate bias, with a shape of [2*numUnits].
+     *
+     * The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
+     */
+    @NonNull
+    private SDVariable ruBias;
+
+    /**
+     * Cell gate bias, with a shape of [numUnits].
+     */
+    @NonNull
+    private SDVariable cBias;
+
+    @Override
+    public SDVariable[] args() {
+        return filterNonNull(ruWeight, cWeight, ruBias, cBias);
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java
new file mode 100644
index 000000000..bf401d66c
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMWeights.java
@@ -0,0 +1,57 @@
+package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
+
+/**
+ * The weight configuration of an LSTM layer. For {@link LSTMLayer} and {@link LSTMBlockCell}.
+ *
+ */
+@EqualsAndHashCode(callSuper = true)
+@Data
+@Builder
+public class LSTMWeights extends RNNWeights {
+
+    /**
+     * Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits].
+     *
+     * Input to hidden and hidden to hidden are concatenated in dimension 0,
+     * so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
+     */
+    @NonNull
+    private SDVariable weights;
+
+    /**
+     * Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
+     */
+    @NonNull
+    private SDVariable inputPeepholeWeights;
+
+    /**
+     * Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
+     */
+    @NonNull
+    private SDVariable forgetPeepholeWeights;
+
+    /**
+     * Cell peephole (t) connections to output gate, with a shape of [numUnits].
+     */
+    @NonNull
+    private SDVariable outputPeepholeWeights;
+
+    /**
+     * Input to hidden and hidden to hidden biases, with shape [4*numUnits].
+ */ + @NonNull + private SDVariable bias; + + @Override + public SDVariable[] args() { + return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java new file mode 100644 index 000000000..62e295d80 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/RNNWeights.java @@ -0,0 +1,35 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import java.util.Arrays; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.util.ArrayUtil; + +public abstract class RNNWeights { + public abstract SDVariable[] args(); + + protected static SDVariable[] filterNonNull(SDVariable... args){ + int count = 0; + for(SDVariable v : args){ + if(v != null){ + count++; + } + } + + SDVariable[] res = new SDVariable[count]; + + int i = 0; + + for(SDVariable v : args){ + if(v != null){ + res[i] = v; + i++; + } + } + + return res; + } + + public SDVariable[] argsWithInputs(SDVariable... inputs){ + return ArrayUtil.combine(inputs, args()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java new file mode 100644 index 000000000..821895f17 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/SRUWeights.java @@ -0,0 +1,37 @@ +package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell; + +/** + * The weight configuration of a SRU layer. For {@link SRU} and {@link SRUCell}. + * + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder +public class SRUWeights extends RNNWeights { + + /** + * Weights, with shape [inSize, 3*inSize]. + */ + @NonNull + private SDVariable weights; + + /** + * Biases, with shape [2*inSize]. 
+ */ + @NonNull + private SDVariable bias; + + @Override + public SDVariable[] args() { + return new SDVariable[]{weights, bias}; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 2aae9eda1..8ecdc4eac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -16,14 +16,19 @@ package org.nd4j.autodiff.opvalidation; +import java.util.Arrays; import lombok.extern.slf4j.Slf4j; import org.junit.Test; +import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut)); double fb = 1.0; - LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() - .xt(x) - .cLast(cLast) - .yLast(yLast) - .W(W) - .Wci(Wci) - .Wcf(Wcf) - .Wco(Wco) - .b(b) + LSTMConfiguration conf = LSTMConfiguration.builder() .peepHole(true) .forgetBias(fb) .clippingCellValue(0.0) .build(); - List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y + LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) + .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); + + LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); - for(SDVariable sdv : v){ + for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.getVarName()); } @@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8)); double fb = 1.0; - LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder() - .xt(x) - .cLast(cLast) - .yLast(yLast) - .W(W) - .Wci(Wci) - .Wcf(Wcf) - .Wco(Wco) - .b(b) + LSTMConfiguration conf = LSTMConfiguration.builder() .peepHole(false) .forgetBias(fb) .clippingCellValue(0.0) .build(); - List v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y + LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b) + .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build(); + + LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y List toExec = new ArrayList<>(); - for(SDVariable sdv : v){ + for(SDVariable sdv : v.getAllOutputs()){ toExec.add(sdv.getVarName()); } @@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation { SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut)); 
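        // Shapes follow the new GRUWeights javadoc: ruWeight [nIn+nOut, 2*nOut], cWeight [nIn+nOut, nOut],
        // ruBias [2*nOut], cBias [nOut]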
double fb = 1.0; - GRUCellConfiguration conf = GRUCellConfiguration.builder() - .xt(x) - .hLast(hLast) - .Wru(Wru) - .Wc(Wc) - .bru(bru) - .bc(bc) + GRUWeights weights = GRUWeights.builder() + .ruWeight(Wru) + .cWeight(Wc) + .ruBias(bru) + .cBias(bc) .build(); - List v = sd.rnn().gru("gru", conf); + List v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs(); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ toExec.add(sdv.getVarName()); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java index 8f5ca3888..31976080d 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/Pair.java @@ -23,6 +23,7 @@ import lombok.NoArgsConstructor; import java.io.Serializable; import java.util.Arrays; +import org.nd4j.base.Preconditions; /** * Simple pair implementation @@ -86,4 +87,10 @@ public class Pair implements Serializable { public static Pair pairOf(T key, E value) { return new Pair(key, value); } + + public static Pair fromArray(T[] arr){ + Preconditions.checkArgument(arr.length == 2, + "Can only create a pair from an array with two values, got %s", arr.length); + return new Pair<>(arr[0], arr[1]); + } } From bb41546ec79981b684b179d2aadd302db6a87479 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 12:26:54 +1000 Subject: [PATCH 04/10] Small DataVec test fix (#239) Signed-off-by: AlexDBlack --- .../java/org/datavec/spark/transform/analysis/TestAnalysis.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java index 55efd7758..05058fea8 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java @@ -35,7 +35,7 @@ import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.AnalyzeSpark; import org.joda.time.DateTimeZone; import org.junit.Test; -import org.nd4j.graph.DataType; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; From 87d873929f4984116057b83148fa9ee192455a73 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 14:25:20 +1000 Subject: [PATCH 05/10] Small LapackTest fix (#240) Signed-off-by: AlexDBlack --- .../src/test/java/org/nd4j/linalg/api/blas/LapackTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java index a7d024058..22e5f3705 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java @@ -92,7 +92,7 @@ public class LapackTest extends BaseNd4jTest { @Test public void testCholeskyU() { - INDArray A = Nd4j.create(new double[] {2, -1, 2, -1, 2, -1, 2, -1, 2,}); + INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); A = A.reshape('f', 3, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); Nd4j.copy(A, O); From 52d279519393b3b2cdc92075e1cd0faa900e7dd3 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 
From bb41546ec79981b684b179d2aadd302db6a87479 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 12:26:54 +1000 Subject: [PATCH 04/10] Small DataVec test fix (#239) Signed-off-by: AlexDBlack --- .../java/org/datavec/spark/transform/analysis/TestAnalysis.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java index 55efd7758..05058fea8 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java @@ -35,7 +35,7 @@ import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.AnalyzeSpark; import org.joda.time.DateTimeZone; import org.junit.Test; -import org.nd4j.graph.DataType; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; From 87d873929f4984116057b83148fa9ee192455a73 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 14:25:20 +1000 Subject: [PATCH 05/10] Small LapackTest fix (#240) Signed-off-by: AlexDBlack --- .../src/test/java/org/nd4j/linalg/api/blas/LapackTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java index a7d024058..22e5f3705 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java @@ -92,7 +92,7 @@ public class LapackTest extends BaseNd4jTest { @Test public void testCholeskyU() { - INDArray A = Nd4j.create(new double[] {2, -1, 2, -1, 2, -1, 2, -1, 2,}); + INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); A = A.reshape('f', 3, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); Nd4j.copy(A, O); From 52d279519393b3b2cdc92075e1cd0faa900e7dd3 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 17:01:47 +1000 Subject: [PATCH 06/10] Another round of small fixes (#241) * Small base spark test fix; ROC toString for empty ROC Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black --- .../spark/dl4j-spark-nlp/pom.xml | 2 +- .../spark/text/BaseSparkTest.java | 31 ++++++++++++++++++- .../spark/BaseSparkKryoTest.java | 29 +++++++++++++++++ .../deeplearning4j/spark/BaseSparkTest.java | 4 ++- .../nd4j/evaluation/classification/ROC.java | 6 +++- 5 files changed, 68 insertions(+), 4 deletions(-) diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index 16c4ac298..a4746be70 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -63,7 +63,7 @@ <dependency> <groupId>com.fasterxml.jackson.module</groupId> <artifactId>jackson-module-scala_2.11</artifactId> - <version>${jackson.version}</version> + <version>2.6.7</version> </dependency> diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 738daa647..152ef4db5 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -23,6 +23,9 @@ import org.junit.After; import org.junit.Before; import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.Map; /** * Created by agibsonccc on 1/23/15. */ @@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable { @After public void after() { - sc.close(); + if(sc != null) { + sc.close(); + } sc = null; } @@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable { public JavaSparkContext getContext() { if (sc != null) return sc; + + //Ensure SPARK_USER environment variable is set for Spark tests + String u = System.getenv("SPARK_USER"); + Map<String, String> env = System.getenv(); + if(u == null || u.isEmpty()) { + try { + Class[] classes = Collections.class.getDeclaredClasses(); + for (Class cl : classes) { + if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) { + Field field = cl.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(env); + Map<String, String> map = (Map<String, String>) obj; + String user = System.getProperty("user.name"); + if (user == null || user.isEmpty()) + user = "user"; + map.put("SPARK_USER", user); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + // set to test mode SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost") .setAppName("sparktest")
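The getContext() change above works around Spark tests failing when no SPARK_USER environment variable is set: System.getenv() returns an unmodifiable view, so the fix reflectively writes into the backing map of java.util.Collections$UnmodifiableMap. A standalone sketch of the same trick follows; the class and method names are illustrative assumptions, JDK 8 era behaviour is assumed, and newer JDKs may reject the reflective access.

    import java.lang.reflect.Field;
    import java.util.Collections;
    import java.util.Map;

    public class EnvironmentHack {
        @SuppressWarnings("unchecked")
        static void setEnv(String key, String value) throws Exception {
            Map<String, String> env = System.getenv();   // unmodifiable view
            for (Class<?> cl : Collections.class.getDeclaredClasses()) {
                if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
                    Field m = cl.getDeclaredField("m");  // the wrapped, mutable map
                    m.setAccessible(true);
                    ((Map<String, String>) m.get(env)).put(key, value);
                }
            }
        }

        public static void main(String[] args) throws Exception {
            setEnv("SPARK_USER", System.getProperty("user.name", "user"));
            System.out.println(System.getenv("SPARK_USER"));   // prints the injected value
        }
    }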
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index bb3a7180e..1c794ebf6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -19,6 +19,10 @@ package org.deeplearning4j.spark; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.Map; + /** * Created by Alex on 04/07/2017. */ @@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest { return sc; } + //Ensure SPARK_USER environment variable is set for Spark Kryo tests + String u = System.getenv("SPARK_USER"); + if(u == null || u.isEmpty()){ + try { + Class[] classes = Collections.class.getDeclaredClasses(); + Map<String, String> env = System.getenv(); + for (Class cl : classes) { + if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) { + Field field = cl.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(env); + Map<String, String> map = (Map<String, String>) obj; + String user = System.getProperty("user.name"); + if(user == null || user.isEmpty()) + user = "user"; + map.put("SPARK_USER", user); + } + } + } catch (Exception e){ + throw new RuntimeException(e); + } + } + + + SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest"); sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index d2a6e08e1..781e3dad2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable { @After public void after() { - sc.close(); + if(sc != null) { + sc.close(); + } sc = null; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index 63a5a012a..c9f5cabdf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @EqualsAndHashCode(callSuper = true, exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"}) @Data -@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"}) @JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"}) @JsonSerialize(using = ROCSerializer.class) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) @@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation<ROC> { return sb.toString(); } + @Override + public String toString(){ + return stats(); + } + public double scoreForMetric(Metric metric){ switch (metric){ case AUROC:
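Dropping the Lombok @ToString annotation and overriding toString() to delegate to stats() means printing a ROC instance now yields the human-readable report rather than a field dump, including for an empty ROC per the commit message. A small sketch; the class name is illustrative and evaluation data is elided.

    import org.nd4j.evaluation.classification.ROC;

    public class RocToStringSketch {
        public static void main(String[] args) {
            ROC roc = new ROC(0);        // 0 threshold steps -> exact AUROC/AUPRC mode
            // With the override above, this prints the same report as roc.stats(),
            // even before any eval() call; previously it dumped internal fields.
            System.out.println(roc);
        }
    }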
From 45017ec914c5286e63825000e4f68264e7a9c6bb Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 17:55:52 +1000 Subject: [PATCH 07/10] Fixes for remote UI (#242) Signed-off-by: AlexDBlack --- .../ui/module/remote/RemoteReceiverModule.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java index 4230d16cf..4709b89ba 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java @@ -19,23 +19,20 @@ package org.deeplearning4j.ui.module.remote; import com.fasterxml.jackson.databind.JsonNode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.api.storage.*; -import org.deeplearning4j.ui.api.FunctionType; import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.i18n.I18NResource; +import play.mvc.Http; import play.mvc.Result; import play.mvc.Results; import javax.xml.bind.DatatypeConverter; -import java.io.File; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import static play.mvc.Http.Context.Implicit.request; - /** * * Used to receive UI updates remotely. @@ -73,7 +70,7 @@ public class RemoteReceiverModule implements UIModule { @Override public List<Route> getRoutes() { - Route r = new Route("/remoteReceive", HttpMethod.POST, FunctionType.Supplier, this::receiveData); + Route r = Route.request0Function("/remoteReceive", HttpMethod.POST, this::receiveData); return Collections.singletonList(r); } @@ -98,7 +95,7 @@ public class RemoteReceiverModule implements UIModule { return Collections.emptyList(); } - private Result receiveData() { + private Result receiveData(Http.Request request) { if (!enabled.get()) { return Results.forbidden( "UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()"); } if (this.statsStorage.get() == null) { return Results.badRequest( "UI Server remote listener: no StatsStorage instance is set/available to store results"); } - JsonNode jn = request().body().asJson(); + JsonNode jn = request.body().asJson(); JsonNode type = jn.get("type"); JsonNode dataClass = jn.get("class"); JsonNode data = jn.get("data"); From 04596327c4ec2d5338089e5000abdcad0250bb99 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 18:30:51 +1000 Subject: [PATCH 08/10] Ensure UI overview page is refreshed when loading saved net data (#243) Signed-off-by: AlexDBlack --- .../resources/deeplearning4jUiAssets/js/train/overview.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/deeplearning4jUiAssets/js/train/overview.js b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/deeplearning4jUiAssets/js/train/overview.js index 107488028..05c1ccace 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/deeplearning4jUiAssets/js/train/overview.js +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/resources/deeplearning4jUiAssets/js/train/overview.js @@ -21,16 +21,16 @@ function selectStdevChart(fieldName) { $("#stdevUpdates").attr("class", "active"); } - renderOverviewPage(false); + renderOverviewPage(true); } /* ---------- Render page ---------- */ var lastUpdateTime = -1; var lastUpdateSession = ""; -function renderOverviewPage(firstLoad) { +function renderOverviewPage(forceupdate) { updateSessionWorkerSelect(); - if(firstLoad || !lastUpdateSession || lastUpdateSession == "" || lastUpdateSession != currSession){ + if(forceupdate ||
!lastUpdateSession || lastUpdateSession == "" || lastUpdateSession != currSession){ executeOverviewUpdate(); } else { //Check last update time first - see if data has actually changed... From a6de3b5d6fd92a5de7c72ea522d36fa9544957a3 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 19:05:40 +1000 Subject: [PATCH 09/10] Small Arbiter UI fix (#244) Signed-off-by: AlexDBlack --- .../arbiter/ui/views/html/ArbiterUI.template.scala | 11 +++++------ .../arbiter/ui/views/ArbiterUI.scala.html | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/scala/org/deeplearning4j/arbiter/ui/views/html/ArbiterUI.template.scala b/arbiter/arbiter-ui/src/main/scala/org/deeplearning4j/arbiter/ui/views/html/ArbiterUI.template.scala index bfbfd9903..57fd8b899 100644 --- a/arbiter/arbiter-ui/src/main/scala/org/deeplearning4j/arbiter/ui/views/html/ArbiterUI.template.scala +++ b/arbiter/arbiter-ui/src/main/scala/org/deeplearning4j/arbiter/ui/views/html/ArbiterUI.template.scala @@ -1,4 +1,3 @@ - /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * @@ -223,12 +222,12 @@ Seq[Any](format.raw/*1.1*/("""