At cpp ops (#378)

* crelu op added

* crelu op added

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* minor fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* crelu(bp)+transformOpValidation op

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* added ClipByAvgNorm and DepthwiseConv2DBp

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* ClipByAvgNorm passes forward check

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* EmbeddingLookup draft

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* DepthwiseConv2DB gradient check

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* EmbeddingLookup and DepthwiseConv2dBp finished + tests added

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* ImageResize draft

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* DepthwiseConv2DB gradient check

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* ImageResize passed tests except helper::resizeFunctor:Non implemented

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* replaced ImageResizeMethods enum by codegen

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* minor fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* polished checkpoint (OPValidationSuite passed and mvn install build successful after codegen)

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* manually merged LSTMLayerTestCases from master
Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* MaximumBp added and tested

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* MergeAddBp draft

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* MergeMaxBp and MergeAvgBP added and tests passed

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* minor fix

* draft LSTMLayerBp (big relative layer in gradient check)

* LSTMLayerBp check

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* LSTMLayerBp check v2

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* requested changes (test passes)

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* LSTMLayer testcases passed gradientcheck

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* small LSTMLayer testcase1 improvement (cLast, yLast)

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Warnings issue solved

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Fixes for MKLDNN LSTM layer helper

Signed-off-by: Alex Black <blacka101@gmail.com>

* stable version

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
master
Andrii T 2020-04-17 08:16:14 +03:00 committed by GitHub
parent 3967e039a5
commit 5fbb04531d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1794 additions and 258 deletions

View File

@ -369,6 +369,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided");
count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
@ -498,7 +499,7 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
DataType WrType = Wr->dataType();
DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32);
DataType hIType = hI != nullptr ? hI->dataType() : xType;
DataType cIType = cI != nullptr ? hI->dataType() : xType;
DataType cIType = cI != nullptr ? cI->dataType() : xType;
DataType hType = h != nullptr ? h->dataType() : xType;
DataType hLType = hL != nullptr ? hL->dataType() : xType;
DataType cLType = cL != nullptr ? cL->dataType() : xType;
@ -509,7 +510,8 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
&& !hasSeqLen //Sequence length array not supported in MKL DNN
&& dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn]
&& directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
&& retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other)
&& retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other)
&& hasInitH == hasInitC; //Need both or neither initial H and C
return block.isUseMKLDNN() && featuresSupported && (
(xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||

View File

@ -153,6 +153,7 @@ public abstract class DifferentialFunction {
public Map<String,Object> propertiesForFunction() {
Map<String,Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
Map<String,Object> ret = new LinkedHashMap<>();
Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());
for(val entry : fields.entrySet()) {
try {

View File

@ -24,6 +24,7 @@ import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.ImageResizeMethod;
public class SDImage extends SDOps {
public SDImage(SameDiff sameDiff) {
@ -254,6 +255,98 @@ public class SDImage extends SDOps {
return sd.updateVariableNameAndReference(out, name);
}
/**
 * Resize images to size using the specified method.<br>
 *
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
 * @param antialias Whether to use an anti-aliasing filter when downsampling an image
 * @param method Resize algorithm to use:<br>
 * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.<br>
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.<br>
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.<br>
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.<br>
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.<br>
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.<br>
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public SDVariable imageResize(SDVariable input, SDVariable size, boolean preserveAspectRatio,
    boolean antialias, ImageResizeMethod method) {
  SDValidation.validateNumerical("imageResize", "input", input);
  SDValidation.validateInteger("imageResize", "size", size);
  // Parameter renames only (antialis -> antialias, ImageResizeMethod -> method); behavior unchanged.
  return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd, input, size, preserveAspectRatio, antialias, method).outputVariable();
}
/**
 * Resize images to size using the specified method.<br>
 *
 * @param name name May be null. Name for the output variable
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
 * @param antialias Whether to use an anti-aliasing filter when downsampling an image
 * @param method Resize algorithm to use:<br>
 * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.<br>
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.<br>
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.<br>
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.<br>
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.<br>
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.<br>
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public SDVariable imageResize(String name, SDVariable input, SDVariable size,
    boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) {
  SDValidation.validateNumerical("imageResize", "input", input);
  SDValidation.validateInteger("imageResize", "size", size);
  // Parameter renames only (antialis -> antialias, ImageResizeMethod -> method); behavior unchanged.
  SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd, input, size, preserveAspectRatio, antialias, method).outputVariable();
  return sd.updateVariableNameAndReference(out, name);
}
/**
 * Resize images to size using the specified method.<br>
 * Equivalent to the full overload with preserveAspectRatio=false and antialias=false.<br>
 *
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param method Resize algorithm to use:<br>
 * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.<br>
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.<br>
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.<br>
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.<br>
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.<br>
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.<br>
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public SDVariable imageResize(SDVariable input, SDVariable size,
    ImageResizeMethod method) {
  SDValidation.validateNumerical("imageResize", "input", input);
  SDValidation.validateInteger("imageResize", "size", size);
  // Same defaults as the original (preserveAspectRatio=false, antialias=false).
  return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd, input, size, false, false, method).outputVariable();
}
/**
 * Resize images to size using the specified method.<br>
 * Equivalent to the full overload with preserveAspectRatio=false and antialias=false.<br>
 *
 * @param name name May be null. Name for the output variable
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param method Resize algorithm to use:<br>
 * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.<br>
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.<br>
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.<br>
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.<br>
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.<br>
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.<br>
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public SDVariable imageResize(String name, SDVariable input, SDVariable size,
    ImageResizeMethod method) {
  SDValidation.validateNumerical("imageResize", "input", input);
  SDValidation.validateInteger("imageResize", "size", size);
  // Same defaults as the original (preserveAspectRatio=false, antialias=false).
  SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd, input, size, false, false, method).outputVariable();
  return sd.updateVariableNameAndReference(out, name);
}
/**
* Greedily selects a subset of bounding boxes in descending order of score<br>
*

View File

@ -24,6 +24,7 @@ import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.indexing.conditions.Condition;
@ -32,6 +33,67 @@ public class SDMath extends SDOps {
super(sameDiff);
}
/**
 * Clips tensor values to a maximum average L2-norm.<br>
 *
 * @param x Input variable (NUMERIC type)
 * @param clipValue Value for clipping
 * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
 * @return output Output variable (NUMERIC type)
 */
public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) {
  SDValidation.validateNumerical("ClipByAvgNorm", "x", x);
  Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
  // Build the op, then pull its single output variable.
  org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm clipOp =
      new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, clipValue, dimensions);
  return clipOp.outputVariable();
}
/**
 * Clips tensor values to a maximum average L2-norm.<br>
 *
 * @param name name May be null. Name for the output variable
 * @param x Input variable (NUMERIC type)
 * @param clipValue Value for clipping
 * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
 * @return output Output variable (NUMERIC type)
 */
public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) {
  SDValidation.validateNumerical("ClipByAvgNorm", "x", x);
  Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
  // Build the op, fetch its output, then apply the requested variable name.
  org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm clipOp =
      new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, clipValue, dimensions);
  SDVariable result = clipOp.outputVariable();
  return sd.updateVariableNameAndReference(result, name);
}
/**
 * Looks up ids in a list of embedding tensors.<br>
 *
 * @param x Input tensor (NUMERIC type)
 * @param indices A Tensor containing the ids to be looked up. (INT type)
 * @param partitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
 * @return output Shifted output (NUMERIC type)
 */
public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode partitionMode) {
  SDValidation.validateNumerical("EmbeddingLookup", "x", x);
  SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
  // Parameter renamed (PartitionMode -> partitionMode) to stop shadowing the type name; behavior unchanged.
  return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, indices, partitionMode).outputVariable();
}
/**
 * Looks up ids in a list of embedding tensors.<br>
 *
 * @param name name May be null. Name for the output variable
 * @param x Input tensor (NUMERIC type)
 * @param indices A Tensor containing the ids to be looked up. (INT type)
 * @param partitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
 * @return output Shifted output (NUMERIC type)
 */
public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices,
    PartitionMode partitionMode) {
  SDValidation.validateNumerical("EmbeddingLookup", "x", x);
  SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
  // Parameter renamed (PartitionMode -> partitionMode) to stop shadowing the type name; behavior unchanged.
  SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, indices, partitionMode).outputVariable();
  return sd.updateVariableNameAndReference(out, name);
}
/**
* Elementwise absolute value operation: out = abs(x)<br>
*

View File

@ -30,6 +30,30 @@ public class SDNN extends SDOps {
super(sameDiff);
}
/**
 * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.<br>
 *
 * @param x Input variable (NUMERIC type)
 * @return output Output variable (NUMERIC type)
 */
public SDVariable cReLU(SDVariable x) {
  SDValidation.validateNumerical("CReLU", "x", x);
  // Instantiate the CReLU op and return its single output.
  org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU creluOp =
      new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, x);
  return creluOp.outputVariable();
}
/**
 * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.<br>
 *
 * @param name name May be null. Name for the output variable
 * @param x Input variable (NUMERIC type)
 * @return output Output variable (NUMERIC type)
 */
public SDVariable cReLU(String name, SDVariable x) {
  SDValidation.validateNumerical("CReLU", "x", x);
  // Instantiate the CReLU op, fetch its output, then apply the requested name.
  org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU creluOp =
      new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, x);
  SDVariable result = creluOp.outputVariable();
  return sd.updateVariableNameAndReference(result, name);
}
/**
* Neural network batch normalization operation.<br>
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */
public enum ImageResizeMethod {
// NOTE: declaration order is significant - ImageResize passes method.ordinal() to the
// native op as an integer argument, so these constants must NOT be reordered.
ResizeBilinear,
// Cubic interpolant of Keys (Catmull-Rom kernel)
ResizeBicubic,
// Nearest neighbor; 'antialias' has no effect for this method
ResizeNearest,
// Gaussian kernel with radius 3, sigma = 1.5 / 3.0
ResizeGaussian,
// Lanczos kernel with radius 5; very high quality, may have stronger ringing
ResizeLanczos5,
// Mitchell-Netravali cubic non-interpolating filter
ResizeMitchelcubic,
// Area interpolation; always anti-aliases
ResizeArea
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* partition_mode == 0 - i.e. 'mod' , 1 - 'div' */
public enum PartitionMode {
// partition_mode == 0 - i.e. 'mod'
// NOTE(review): the ordinal is likely passed to the native op as an integer -
// confirm before reordering these constants.
MOD,
// partition_mode == 1 - i.e. 'div'
DIV
}

View File

@ -93,6 +93,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.grid.FreeGridOp.class,
org.nd4j.linalg.api.ops.impl.image.CropAndResize.class,
org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class,
org.nd4j.linalg.api.ops.impl.image.ImageResize.class,
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
@ -127,6 +128,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class,
@ -146,6 +148,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class,
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class,
@ -322,9 +325,12 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.shape.Unstack.class,
org.nd4j.linalg.api.ops.impl.shape.ZerosLike.class,
org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp.class,
org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp.class,
org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp.class,
org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp.class,
org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp.class,
org.nd4j.linalg.api.ops.impl.shape.bp.TileBp.class,
org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup.class,
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray.class,
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat.class,
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather.class,
@ -354,6 +360,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class,
org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class,
org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform.class,
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm.class,
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm.class,
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNormBp.class,
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue.class,
@ -365,6 +372,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
@ -406,6 +415,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Max.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Min.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.MirrorPad.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention.class,
@ -492,11 +502,13 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or.class,

View File

@ -0,0 +1,67 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.ImageResizeMethod;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class ImageResize extends DynamicCustomOp {

    @Override
    public String opName() {
        return "image_resize";
    }

    /**
     * Graph-mode (SameDiff) constructor.
     *
     * @param sameDiff            SameDiff instance
     * @param in                  4D input image in NHWC format, i.e. [batchSize, height, width, channels]
     * @param size                new [height, width] to resize to (integer values)
     * @param preserveAspectRatio whether to preserve the aspect ratio when resizing
     * @param antialias           whether to apply an anti-aliasing filter when downsampling
     * @param method              interpolation method; method.ordinal() is passed to the native op
     *                            as an integer argument, so the enum order is significant
     */
    public ImageResize(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) {
        super("image_resize", sameDiff, new SDVariable[]{in, size});
        addBArgument(preserveAspectRatio, antialias);
        addIArgument(method.ordinal());
    }

    /**
     * Eager-mode (NDArray) constructor.
     *
     * @param in                  4D input image in NHWC format, i.e. [batchSize, height, width, channels]
     * @param size                new [height, width] to resize to (integer values)
     * @param preserveAspectRatio whether to preserve the aspect ratio when resizing
     * @param antialias           whether to apply an anti-aliasing filter when downsampling
     * @param method              interpolation method; method.ordinal() is passed to the native op
     */
    public ImageResize(@NonNull INDArray in, @NonNull INDArray size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) {
        super("image_resize", new INDArray[]{in, size}, null);
        // Fixed message typo: "input message" -> "input image"
        Preconditions.checkArgument(in.rank() == 4, "expected input image in NHWC format i.e [batchSize, height, width, channels]");
        addBArgument(preserveAspectRatio, antialias);
        addIArgument(method.ordinal());
    }

    /**
     * Output datatype is the (floating point) datatype of the image input.
     * The size input's datatype is not validated here.
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2,
                "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }
}

View File

@ -56,9 +56,11 @@ public class DepthwiseConv2D extends DynamicCustomOp {
protected Conv2DConfig config;
public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input,
@NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
}
@Builder(builderMethodName = "sameDiffBuilder")
@ -71,14 +73,14 @@ public class DepthwiseConv2D extends DynamicCustomOp {
addArgs();
}
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config) {
super(inputs, outputs);
this.config = config;
addArgs();
}
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) {
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
@ -127,7 +129,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
@Override
public Map<String, Object> propertiesForFunction() {
if(config == null && !iArguments.isEmpty()){
if (config == null && !iArguments.isEmpty()) {
config = Conv2DConfig.builder()
.kH(iArguments.get(0))
.kW(iArguments.get(1))
@ -308,7 +310,9 @@ public class DepthwiseConv2D extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not implemented yet");
SDVariable bias = args().length==2 ? null : arg(2);
return Arrays.asList(new DepthwiseConv2DBp(sameDiff, arg(0), arg(1), bias, f1.get(0), this.config).outputVariables());
}
@ -323,7 +327,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));

View File

@ -0,0 +1,150 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import java.lang.reflect.Field;
import java.util.*;
/**
 * Backpropagation for the Depthwise Conv2D operation.
 * Takes the forward op's inputs (input activations, depthwise weights, optional bias)
 * plus the gradient w.r.t. the forward output, and produces the gradient w.r.t. each input.
 */
@Slf4j
@Getter
@NoArgsConstructor
public class DepthwiseConv2DBp extends DynamicCustomOp {

    // Convolution configuration: kernel/stride/padding/dilation sizes, padding mode and data format
    protected Conv2DConfig config;

    /**
     * @param sameDiff SameDiff instance
     * @param input    Input activations to the forward op
     * @param weights  Depthwise convolution weights
     * @param bias     Bias of the forward op (may be null)
     * @param gradO    Gradient w.r.t. the forward op's output
     * @param config   Convolution configuration
     */
    public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull SDVariable gradO, @NonNull Conv2DConfig config) {
        super(sameDiff, wrapFilterNull(input, weights, bias, gradO));
        this.config = config;
        addArgs();
    }

    /** As above, for a forward op without a bias input. */
    public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull SDVariable gradO, @NonNull Conv2DConfig config) {
        super(sameDiff, wrapFilterNull(input, weights, gradO));
        this.config = config;
        addArgs();
    }

    @Override
    public long[] iArgs() {
        // Lazily populate int args if this instance was created via the no-arg constructor
        if (iArguments.size() == 0)
            addArgs();
        return super.iArgs();
    }

    // Push the config values into the op's int arguments, in the order the native op expects
    protected void addArgs() {
        addIArgument(config.getKH(),
                config.getKW(),
                config.getSH(),
                config.getSW(),
                config.getPH(),
                config.getPW(),
                config.getDH(),
                config.getDW(),
                ArrayUtil.fromBoolean(config.isSameMode()),
                config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1);
    }

    @Override
    public Object getValue(Field property) {
        if (config == null) {
            config = Conv2DConfig.builder().build();
        }
        try {
            val t = config.getValue(property);
            return t;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Map<String, Object> propertiesForFunction() {
        // Rebuild the config from the int args if it was not set (e.g. after deserialization)
        if (config == null && !iArguments.isEmpty()) {
            config = Conv2DConfig.builder()
                    .kH(iArguments.get(0))
                    .kW(iArguments.get(1))
                    .sH(iArguments.get(2))
                    .sW(iArguments.get(3))
                    .pH(iArguments.get(4))
                    .pW(iArguments.get(5))
                    .dH(iArguments.get(6))
                    .dW(iArguments.get(7))
                    .isSameMode(iArguments.get(8) == 1)
                    .dataFormat(iArguments.get(9) == 1 ? Conv2DConfig.NHWC : Conv2DConfig.NCHW)
                    .build();
        }
        return config.toProperties();
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "config";
    }

    @Override
    public String opName() {
        return "depthwise_conv2d_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        int n = args().length;
        // Validate BEFORE dereferencing - the original checked after building the list,
        // so a null input would NPE instead of producing the intended error message
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n,
                "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
        // One output gradient per forward input (the last arg is gradO), all of the input's type
        List<DataType> list = new ArrayList<>(n - 1);
        for (int i = 0; i < n - 1; i++) {
            list.add(inputDataTypes.get(0));
        }
        return list;
    }

    @Override
    public int getNumOutputs() {
        // 4 args => input, weights, bias, gradO => 3 gradients; otherwise no bias => 2 gradients
        return args().length == 4 ? 3 : 2;
    }
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;
@ -60,6 +62,7 @@ import java.util.Map;
* 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat<br>
* 2: cell state at last step cL - same shape as in hL<br>
*/
@NoArgsConstructor
public class LSTMLayer extends DynamicCustomOp {
@Getter
@ -68,14 +71,18 @@ public class LSTMLayer extends DynamicCustomOp {
@Getter
private LSTMLayerWeights weights;
private SDVariable cLast;
private SDVariable yLast;
private SDVariable maxTSLength;
public LSTMLayer() {
}
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
this.configuration = configuration;
this.weights = weights;
this.cLast = cLast;
this.yLast = yLast;
this.maxTSLength = maxTSLength;
addIArgument(iArgs());
addTArgument(tArgs());
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
@ -124,7 +131,13 @@ public class LSTMLayer extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> grads) {
throw new UnsupportedOperationException("Not yet implemented");
int i=0;
SDVariable grad0 = this.configuration.isRetFullSequence() ? grads.get(i++): null;
SDVariable grad1 = this.configuration.isRetLastH() ? grads.get(i++): null;
SDVariable grad2 = this.configuration.isRetLastC() ? grads.get(i++): null;
return Arrays.asList(new LSTMLayerBp(sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength,
this.weights, this.configuration, grad0, grad1,grad2).outputVariables());
}
@ -155,7 +168,7 @@ public class LSTMLayer extends DynamicCustomOp {
}
public <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
return new boolean[]{
weights.hasBias(), // hasBiases: B_ARG(0)
maxTSLength != null, // hasSeqLen: B_ARG(1)
@ -169,6 +182,16 @@ public class LSTMLayer extends DynamicCustomOp {
}
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName() {
return "configuration";
}
@Override
public int getNumOutputs(){

View File

@ -0,0 +1,176 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;
import javax.xml.crypto.Data;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* LSTM layer backpropagation
*/
@NoArgsConstructor
public class LSTMLayerBp extends DynamicCustomOp {
@Getter
private LSTMLayerConfig configuration;
@Getter
private LSTMLayerWeights weights;
private SDVariable cLast;
private SDVariable yLast;
private SDVariable maxTSLength;
public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration,
SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) {
super("lstmLayer_bp", sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(),
maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL));
this.configuration = configuration;
this.weights = weights;
this.cLast = cLast;
this.yLast = yLast;
this.maxTSLength = maxTSLength;
addIArgument(iArgs());
addTArgument(tArgs());
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
DataType dt = inputDataTypes.get(1);
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
ArrayList<DataType> list = new ArrayList<>();
list.add(dt); // dLdx
list.add(dt); // dLdWx
list.add(dt); // dLdWr
if (this.weights.hasBias()) {
list.add(dt);
} // dLdb
if (this.maxTSLength != null) {
list.add(dt);
} // dLdSl
if (this.yLast != null) {
list.add(dt);
} //dLdhI
if (this.cLast != null) {
list.add(dt);
} // dLdcI
if (this.weights.hasPH()) {
list.add(dt);
} // dLdWp
return list;
}
@Override
public String opName() {
return "lstmLayer_bp";
}
@Override
public Map<String, Object> propertiesForFunction() {
return configuration.toProperties(true, true);
}
public long[] iArgs() {
return new long[]{
configuration.getLstmdataformat().ordinal(),// INT_ARG(0)
configuration.getDirectionMode().ordinal(), // INT_ARG(1)
configuration.getGateAct().ordinal(), // INT_ARG(2)
configuration.getOutAct().ordinal(), // INT_ARG(3)
configuration.getCellAct().ordinal() // INT_ARG(4)
};
}
public double[] tArgs() {
return new double[]{this.configuration.getCellClip()}; // T_ARG(0)
}
protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
return new boolean[]{
weights.hasBias(), // hasBiases: B_ARG(0)
maxTSLength != null, // hasSeqLen: B_ARG(1)
yLast != null, // hasInitH: B_ARG(2)
cLast != null, // hasInitC: B_ARG(3)
weights.hasPH(), // hasPH: B_ARG(4)
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
configuration.isRetLastH(), // retLastH: B_ARG(6)
configuration.isRetLastC() // retLastC: B_ARG(7)
};
}
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName() {
return "configuration";
}
@Override
public int getNumOutputs() {
return Booleans.countTrue(
true,
true,
true,
weights.hasBias(),
this.maxTSLength != null,
this.yLast != null,
this.cLast != null,
weights.hasPH()
);
}
}

View File

@ -15,8 +15,10 @@
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
@ -26,9 +28,10 @@ import java.util.Map;
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class LSTMLayerConfig {
/**
* notations <br>
* for unidirectional:
@ -90,23 +93,23 @@ public class LSTMLayerConfig {
* Cell clipping value, if it = 0 then do not apply clipping
*/
@Builder.Default
private double cellClip; //T_ARG(0)
private double cellClip = 0; //T_ARG(0)
public Map<String, Object> toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) {
Map<String, Object> ret = new LinkedHashMap<>();
ret.put("gateAct", gateAct.ordinal());
ret.put("outAct", outAct.ordinal());
ret.put("cellAct", cellAct.ordinal());
ret.put("gateAct", gateAct.toString());
ret.put("outAct", outAct.toString());
ret.put("cellAct", cellAct.toString());
ret.put("retFullSequence", retFullSequence);
ret.put("retLastH", retLastH);
ret.put("retLastC", retLastC);
ret.put("cellClip", cellClip);
if (includeLSTMDataFormat)
ret.put("LSTMDataFormat", lstmdataformat.ordinal());
ret.put("lstmdataformat", lstmdataformat.toString());
if (includeLSTMDirectionMode)
ret.put("LSTMDirectionMode", directionMode.ordinal());
ret.put("directionMode", directionMode.toString());
return ret;
}

View File

@ -24,15 +24,13 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
@Slf4j
public class MergeAvg extends DynamicCustomOp {
@ -74,12 +72,8 @@ public class MergeAvg extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
int nArgs = args().length;
SDVariable gradient = sameDiff.setupFunction(i_v.get(0)).div(nArgs);
List<SDVariable> ret = new ArrayList<>();
for (int i = 0; i < args().length; i++)
ret.add(gradient);
return ret;
return Arrays.asList(new MergeAvgBp(sameDiff, args(), i_v.get(0)).outputVariables());
}
@Override

View File

@ -24,14 +24,12 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
@Slf4j
public class MergeMax extends DynamicCustomOp {
@ -71,14 +69,8 @@ public class MergeMax extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
List<SDVariable> ret = new ArrayList<>();
SDVariable out = outputVariable();
for (int i = 0; i < args().length; i++){
SDVariable isMax = out.eq(arg(i)).castTo(arg(i).dataType());
ret.add(isMax.mul(gradient));
}
return ret;
return Arrays.asList(new MergeMaxBp(sameDiff, args(), i_v.get(0)).outputVariables());
}
@Override

View File

@ -0,0 +1,57 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList;
import java.util.List;
/**
 * Backpropagation for the mergeavg op: produces one gradient per merged forward input.
 */
@NoArgsConstructor
public class MergeAvgBp extends DynamicCustomOp {

    /**
     * @param sameDiff SameDiff instance
     * @param inputs   Inputs of the forward mergeavg op
     * @param gradO    Gradient w.r.t. the forward op's output
     */
    public MergeAvgBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
        super("mergeavg_bp", sameDiff, ArrayUtils.add(inputs, gradO));
    }

    @Override
    public String opName() {
        return "mergeavg_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        // One gradient per forward input (the final arg is gradO), all of the first input's type
        int numGrads = args().length - 1;
        DataType gradType = inputDataTypes.get(0);
        List<DataType> out = new ArrayList<>(numGrads);
        for (int i = 0; i < numGrads; i++) {
            out.add(gradType);
        }
        return out;
    }

    @Override
    public int getNumOutputs() {
        return args().length - 1;
    }
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList;
import java.util.List;
/**
 * Backpropagation for the mergemax op: produces one gradient per merged forward input.
 */
@NoArgsConstructor
public class MergeMaxBp extends DynamicCustomOp {

    /**
     * @param sameDiff SameDiff instance
     * @param inputs   Inputs of the forward mergemax op
     * @param gradO    Gradient w.r.t. the forward op's output
     */
    public MergeMaxBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
        super("mergemax_bp", sameDiff, ArrayUtils.add(inputs, gradO));
    }

    @Override
    public String opName() {
        return "mergemax_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        // Outputs: one gradient per forward input (the last arg is gradO), each of the first input's type
        int numInputs = args().length - 1;
        List<DataType> result = new ArrayList<>(numInputs);
        while (result.size() < numInputs) {
            result.add(inputDataTypes.get(0));
        }
        return result;
    }

    @Override
    public int getNumOutputs() {
        return args().length - 1;
    }
}

View File

@ -0,0 +1,71 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
/**
 * Looks up embeddings (rows of a parameter tensor) by integer indices,
 * equivalent to TensorFlow's embedding_lookup.
 */
@NoArgsConstructor
public class EmbeddingLookup extends DynamicCustomOp {

    /**
     * @param sameDiff      SameDiff instance
     * @param in            Embedding table (floating point)
     * @param indices       Integer indices to look up
     * @param partitionMode Partitioning strategy; its ordinal is passed to the native op
     */
    public EmbeddingLookup(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable indices, @NonNull PartitionMode partitionMode) {
        super("embedding_lookup", sameDiff, new SDVariable[]{in, indices});
        addIArgument(partitionMode.ordinal());
    }

    public EmbeddingLookup(@NonNull INDArray in, @NonNull INDArray indices, @NonNull PartitionMode partitionMode, INDArray output) {
        super("embedding_lookup", new INDArray[]{in, indices}, wrapOrNull(output));
        addIArgument(partitionMode.ordinal());
    }

    /** Convenience overload taking the indices as an int varargs array. */
    public EmbeddingLookup(@NonNull INDArray in, INDArray output, @NonNull PartitionMode partitionMode, @NonNull int... indices) {
        super("embedding_lookup", new INDArray[]{in, Nd4j.createFromArray(indices)}, wrapOrNull(output));
        addIArgument(partitionMode.ordinal());
    }

    @Override
    public String opName() {
        return "embedding_lookup";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions
                .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
        // Message fixed: was the garbled "must be integer point"
        Preconditions.checkArgument(dataTypes.get(1).isIntType(), "Indices datatype must be an integer type, got %s", dataTypes);
        // Output has the embedding table's type
        return Collections.singletonList(dataTypes.get(0));
    }
}

View File

@ -0,0 +1,71 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
/**
 * Clips a tensor so that its average L2 norm along the given dimensions
 * does not exceed {@code clipValue}.
 */
@NoArgsConstructor
public class ClipByAvgNorm extends DynamicCustomOp {

    // Maximum allowed average norm; also passed to the native op as T_ARG(0)
    private double clipValue;

    /**
     * @param sameDiff   SameDiff instance
     * @param x          Input variable
     * @param clipValue  Maximum average norm
     * @param dimensions Dimensions over which the norm is computed
     */
    public ClipByAvgNorm(SameDiff sameDiff, SDVariable x, double clipValue, int... dimensions) {
        super("clipbyavgnorm", sameDiff, new SDVariable[]{x});
        this.clipValue = clipValue;
        this.dimensions = dimensions;
        addIArgument(dimensions);
        addTArgument(clipValue);
    }

    public ClipByAvgNorm(INDArray in, double clipValue, int... dimensions) {
        this(in, null, clipValue, dimensions);
    }

    public ClipByAvgNorm(INDArray in, INDArray out, double clipValue, int... dimensions) {
        super("clipbyavgnorm", new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions);
        // Keep the field consistent with the t-arg passed to super (the original left it at 0.0)
        this.clipValue = clipValue;
    }

    @Override
    public String opName() {
        return "clipbyavgnorm";
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> grad) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
        // Output type matches the input type
        return inputDataTypes;
    }
}

View File

@ -0,0 +1,65 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.base.Preconditions;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
/**
 * Concatenated ReLU activation: concatenates ReLU applied to the input and to its negation.
 */
@NoArgsConstructor
public class CReLU extends DynamicCustomOp {

    /**
     * @param sd    SameDiff instance
     * @param input Input variable
     */
    public CReLU(SameDiff sd, SDVariable input) {
        super(sd, new SDVariable[]{input});
    }

    public CReLU(@NonNull INDArray input) {
        super(new INDArray[]{input}, null);
    }

    @Override
    public String opName() {
        return "crelu";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 1,
                "Expected exactly 1 input datatypes, got %s", dataTypes);
        DataType inType = dataTypes.get(0);
        Preconditions.checkArgument(inType.isFPType(), "Input datatype must be floating point, got %s", dataTypes);
        return Collections.singletonList(inType);
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        // Delegate the gradient computation to the dedicated crelu_bp op
        SDVariable gradAtOutput = i_v.get(0);
        CReluBp bp = new CReluBp(sameDiff, arg(), gradAtOutput);
        return Collections.singletonList(bp.outputVariable());
    }
}

View File

@ -0,0 +1,59 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.base.Preconditions;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
/**
 * Backpropagation for the CReLU activation.
 */
@NoArgsConstructor
public class CReluBp extends DynamicCustomOp {

    /**
     * @param sd          SameDiff instance
     * @param input       Input of the forward CReLU op
     * @param epsilonNext Gradient w.r.t. the forward op's output
     */
    public CReluBp(SameDiff sd, SDVariable input, SDVariable epsilonNext) {
        super(sd, new SDVariable[]{input, epsilonNext});
    }

    public CReluBp(@NonNull INDArray input, @NonNull INDArray epsilonNext, INDArray output) {
        super(new INDArray[]{input, epsilonNext}, wrapOrNull(output));
    }

    @Override
    public String opName() {
        return "crelu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2,
                "Expected exactly 2 input datatypes, got %s", dataTypes);
        DataType inputType = dataTypes.get(0);
        Preconditions.checkArgument(inputType.isFPType(), "Input datatype must be floating point, got %s", dataTypes);
        return Collections.singletonList(inputType);
    }
}

View File

@ -73,12 +73,7 @@ public class Max extends BaseDynamicTransformOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
//TODO Switch to maximum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp
SDVariable max = outputVariables()[0];
SDVariable eq1 = sameDiff.eq(larg(), max).castTo(arg(0).dataType());
SDVariable eq2 = sameDiff.eq(rarg(), max).castTo(arg(1).dataType());
return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0)));
return Arrays.asList(new MaximumBp(sameDiff, arg(0), arg(1), f1.get(0)).outputVariables());
}
@Override

View File

@ -0,0 +1,48 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList;
import java.util.List;
/**
 * Backpropagation for the element-wise maximum op: produces the gradients w.r.t. x and y.
 */
@NoArgsConstructor
public class MaximumBp extends DynamicCustomOp {

    /**
     * @param sameDiff SameDiff instance
     * @param x        First input of the forward maximum op
     * @param y        Second input of the forward maximum op
     * @param gradO    Gradient w.r.t. the forward op's output
     */
    public MaximumBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y, @NonNull SDVariable gradO) {
        super("maximum_bp", sameDiff, new SDVariable[]{x, y, gradO});
    }

    @Override
    public String opName() {
        return "maximum_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3,
                "Expected exactly 3 input data types for %s, got %s", getClass(), inputDataTypes);
        // dL/dx has x's type, dL/dy has y's type
        // (the original returned x's type for both, which is wrong when x and y differ)
        List<DataType> list = new ArrayList<>(2);
        list.add(inputDataTypes.get(0));
        list.add(inputDataTypes.get(1));
        return list;
    }
}

View File

@ -18,14 +18,19 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -70,11 +75,8 @@ public class MergeAddOp extends BaseDynamicTransformOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
List<SDVariable> ret = new ArrayList<>();
for (int i = 0; i < args().length; i++)
ret.add(gradient);
return ret;
return Arrays.asList(new MergeAddBp(sameDiff, args(), i_v.get(0)).outputVariables());
}

View File

@ -0,0 +1,54 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Backpropagation for the mergeadd op: produces one gradient per merged forward input.
 */
@NoArgsConstructor
public class MergeAddBp extends DynamicCustomOp {

    /**
     * @param sameDiff SameDiff instance
     * @param inputs   Inputs of the forward mergeadd op
     * @param gradO    Gradient w.r.t. the forward op's output
     */
    public MergeAddBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
        super("mergeadd_bp", sameDiff, ArrayUtils.add(inputs, gradO));
    }

    @Override
    public String opName() {
        return "mergeadd_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        // One gradient per forward input (the final arg is gradO), all of the first input's type
        int numGrads = args().length - 1;
        DataType gradType = inputDataTypes.get(0);
        List<DataType> types = new ArrayList<>(numGrads);
        for (int i = numGrads; i > 0; i--) {
            types.add(gradType);
        }
        return types;
    }

    @Override
    public int getNumOutputs() {
        return args().length - 1;
    }
}

View File

@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.ImageResizeMethod;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
@ -134,6 +135,49 @@ public class NDImage {
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0];
}
/**
 * Resize images to size using the specified method.<br>
 *
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
 * @param antialias Whether to use an anti-aliasing filter when downsampling an image
 * @param method ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public INDArray imageResize(INDArray input, INDArray size, boolean preserveAspectRatio,
    boolean antialias, ImageResizeMethod method) {
  NDValidation.validateNumerical("imageResize", "input", input);
  NDValidation.validateInteger("imageResize", "size", size);
  return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, preserveAspectRatio, antialias, method))[0];
}
/**
 * Resize images to size using the specified method.<br>
 * Convenience overload: aspect ratio is not preserved and no anti-aliasing is applied.<br>
 *
 * @param input 4D image [NHWC] (NUMERIC type)
 * @param size new height and width (INT type)
 * @param method ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
 * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
 * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
 * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
 * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
 * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
 * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
 * @return output Output image (NUMERIC type)
 */
public INDArray imageResize(INDArray input, INDArray size, ImageResizeMethod method) {
  NDValidation.validateNumerical("imageResize", "input", input);
  NDValidation.validateInteger("imageResize", "size", size);
  // Defaults: preserveAspectRatio = false, antialias = false.
  return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, false, false, method))[0];
}
/**
* Greedily selects a subset of bounding boxes in descending order of score<br>
*

View File

@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.base.Preconditions;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
@ -31,6 +32,34 @@ public class NDMath {
public NDMath() {
}
/**
 * Clips tensor values to a maximum average L2-norm.<br>
 *
 * @param x Input variable (NUMERIC type)
 * @param clipValue Value for clipping
 * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
 * @return output Output variable (NUMERIC type)
 */
public INDArray clipByAvgNorm(INDArray x, double clipValue, int... dimensions) {
  NDValidation.validateNumerical("ClipByAvgNorm", "x", x);
  // No length precondition needed: a varargs array length is never negative,
  // and AtLeast(min=0) permits an empty dimensions list (full reduction).
  return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(x, clipValue, dimensions))[0];
}
/**
 * Looks up ids in a list of embedding tensors.<br>
 *
 * @param x Input tensor (NUMERIC type)
 * @param indices A Tensor containing the ids to be looked up. (INT type)
 * @param partitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
 * @return output Shifted output (NUMERIC type)
 */
public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode partitionMode) {
  NDValidation.validateNumerical("EmbeddingLookup", "x", x);
  NDValidation.validateInteger("EmbeddingLookup", "indices", indices);
  return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, partitionMode))[0];
}
/**
* Elementwise absolute value operation: out = abs(x)<br>
*

View File

@ -29,6 +29,17 @@ public class NDNN {
public NDNN() {
}
/**
 * Concatenated ReLU (CReLU): concatenates a ReLU over the positive part of the
 * activation with a ReLU over the negative part. As a consequence this
 * non-linearity doubles the depth of the activations.<br>
 *
 * @param x Input variable (NUMERIC type)
 * @return output Output variable (NUMERIC type)
 */
public INDArray cReLU(INDArray x) {
  NDValidation.validateNumerical("CReLU", "x", x);
  INDArray[] results = Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x));
  return results[0];
}
/**
* Neural network batch normalization operation.<br>
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>

View File

@ -20,7 +20,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Ignore;
@ -35,6 +34,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
@ -265,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation {
msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW);
inSize = inSizeNCHW;
in = sd.var("in", inSize);
out = sd.cnn().upsampling2d(in, 2, 2, true);
out = sd.cnn().upsampling2d(in, 2, 2, true);
break;
default:
throw new RuntimeException();
@ -1469,6 +1469,43 @@ public class LayerOpValidation extends BaseOpValidation {
}
}
@Test
public void testDepthwiseConv2D(){

    int bS = 10;

    int kernelHeight = 2;
    int kernelWidth = 2;
    int strideHeight = 2;
    int strideWidth = 2;
    int inChannels = 2;
    int outChannels = 3;

    Nd4j.getRandom().setSeed(12345);
    SameDiff sd = SameDiff.create();
    // Use DOUBLE throughout: gradient checks need double precision, and the
    // weights/bias below are already explicitly DOUBLE (the input previously
    // used the default dtype, which may be FLOAT).
    SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, bS, inChannels, 5, 5));
    // Depthwise weights: [kH, kW, inChannels, depthMultiplier]
    SDVariable weights = sd.var("weights", Nd4j.rand(DataType.DOUBLE, kernelHeight, kernelWidth, inChannels, outChannels));
    // Bias length = inChannels * depthMultiplier (one per output channel)
    SDVariable bias = sd.var("bias", Nd4j.rand(DataType.DOUBLE, inChannels*outChannels));
    Conv2DConfig config = Conv2DConfig.builder()
            .kH(kernelHeight)
            .kW(kernelWidth)
            .sH(strideHeight)
            .sW(strideWidth)
            .dataFormat("NCHW")
            .build();

    SDVariable out = sd.cnn.depthWiseConv2d(in, weights, bias, config);
    SDVariable loss = sd.standardDeviation("loss", out, true);
    loss.markAsLoss();

    String err = OpValidation.validate(new TestCase(sd)
            .gradientCheck(true)
    );
    assertNull(err);
}
@Test
public void LSTMLayerTestCase1() {
@ -1476,9 +1513,8 @@ public class LayerOpValidation extends BaseOpValidation {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
int sL = 3; //small just for test
SameDiff sd = SameDiff.create();
// notations:
// bS - batch size, numExamples
@ -1492,50 +1528,66 @@ public class LayerOpValidation extends BaseOpValidation {
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL));
for (boolean useCLast : new boolean[]{false, true}) {
for (boolean useYLast : new boolean[]{false, true}) {
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.randn(DataType.DOUBLE, bS, nIn, sL));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable cLast = useCLast ? sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null;
SDVariable yLast = useYLast ? sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null;
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NST)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(true)
.retLastH(true)
.build();
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
.bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(),
c), c);
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NST)
.directionMode(LSTMDirectionMode.FWD)
.gateAct(LSTMActivations.SIGMOID)
.cellAct(LSTMActivations.TANH)
.outAct(LSTMActivations.TANH)
.retFullSequence(true)
.retLastC(true)
.retLastH(true)
.build();
long[] out = new long[]{bS, numUnits, sL};
long[] hL = new long[]{bS, numUnits};
long[] cL = new long[]{bS, numUnits};
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.randn(DataType.DOUBLE, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.randn(DataType.DOUBLE, numUnits, 4 * numUnits)))
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.randn(DataType.DOUBLE, 3 * numUnits)))
.bias(sd.var("bias", Nd4j.rand(DataType.DOUBLE, 4 * numUnits))).build(),
c), c);
assertArrayEquals(out, outputs.getOutput().eval().shape());
assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape());
assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape());
long[] out = new long[]{bS, numUnits, sL};
long[] hL = new long[]{bS, numUnits};
long[] cL = new long[]{bS, numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
assertArrayEquals(hL, outputs.getLastOutput().eval().shape());
assertArrayEquals(cL, outputs.getLastState().eval().shape());
sd.setLossVariables(outputs.getOutput(), outputs.getLastTimeStepOutput(), outputs.getTimeSeriesOutput());
String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(true)
.testName("cLast=" + cLast + ", yLast=" + yLast)
);
assertNull(err);
}
}
}
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
@Test
public void LSTMLayerTestCase2() {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
int sL = 3; //small just for test
SameDiff sd = SameDiff.create();
@ -1549,11 +1601,11 @@ public class LayerOpValidation extends BaseOpValidation {
// NTS: shape [numExamples, timeLength, inOutSize]<br>
// for bidirectional:
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn));
SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, sL, bS, nIn));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.TNS)
@ -1569,8 +1621,8 @@ public class LayerOpValidation extends BaseOpValidation {
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
.weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, numUnits, 4 * numUnits)))
.build(),
c), c);
@ -1578,14 +1630,22 @@ public class LayerOpValidation extends BaseOpValidation {
long[] out = new long[]{sL, bS, numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
sd.setLossVariables(outputs.getOutput());
String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(true)
);
assertNull(err);
}
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
@Test
public void LSTMLayerTestCase3() {
int bS = 5;
int nIn = 3;
int numUnits = 7;
int sL = 10; //small just for test
int sL = 3; //small just for test
SameDiff sd = SameDiff.create();
@ -1599,14 +1659,14 @@ public class LayerOpValidation extends BaseOpValidation {
// NTS: shape [numExamples, timeLength, inOutSize]<br>
// for bidirectional:
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn));
SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, bS, sL, nIn));
// when directionMode >= 2 (BIDIR_CONCAT=3)
// Wx, Wr [2, nIn, 4*nOut]
// hI, cI [2, bS, nOut]
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits));
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits));
LSTMLayerConfig c = LSTMLayerConfig.builder()
.lstmdataformat(LSTMDataFormat.NTS)
@ -1622,8 +1682,8 @@ public class LayerOpValidation extends BaseOpValidation {
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"},
in, cLast, yLast, null,
LSTMLayerWeights.builder()
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits)))
.weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, 2, nIn, 4 * numUnits)))
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, 2, numUnits, 4 * numUnits)))
.build(),
c), c);
@ -1631,5 +1691,17 @@ public class LayerOpValidation extends BaseOpValidation {
long[] out = new long[]{bS, sL, 2 * numUnits};
assertArrayEquals(out, outputs.getOutput().eval().shape());
sd.setLossVariables(outputs.getOutput());
String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(true)
);
assertNull(err);
}
}