At cpp ops (#378)
* crelu op added * crelu op added Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor fixes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * crelu(bp)+transformOpValidation op Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * added ClipByAvgNorm and DepthwiseConv2DBp Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * ClipByAvgNorm passes forward check Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * EmbeddingLookup draft Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * EmbeddingLookup and DepthwiseConv2dBp finished + tests added Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * ImageResize draft Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * ImageResize passed tests except helper::resizeFunctor:Non implemented Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * replaced ImageResizeMethods enum by codegen Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor fixes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * polished checkpoint (OPValidationSuite passed and mvn install build succesfull after codegen) Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * manually merged LSTMLayerTestCases from master Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * MaximumBp added and tested Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * MergeAddBp draft Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * MergeMaxBp and MergeAvgBP added and tests passed Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor fix * draft LSTMLayerBp (big relative layer in gradient check) * LSTMLayerBp check Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * LSTMLayerBp check v2 Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * requested changes (test passes) Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * LSTMLayer testcases passed gradientcheck Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * small LSTMLayer testcase1 improvement (cLast, yLast) Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * Warnings issue solved Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * Fixes for MKLDNN LSTM layer helper Signed-off-by: Alex Black <blacka101@gmail.com> * stable version Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
3967e039a5
commit
5fbb04531d
|
@ -369,6 +369,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
|
|||
REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
|
||||
REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
|
||||
REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
|
||||
REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided");
|
||||
|
||||
count = 0;
|
||||
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
|
||||
|
@ -498,7 +499,7 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
|
|||
DataType WrType = Wr->dataType();
|
||||
DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32);
|
||||
DataType hIType = hI != nullptr ? hI->dataType() : xType;
|
||||
DataType cIType = cI != nullptr ? hI->dataType() : xType;
|
||||
DataType cIType = cI != nullptr ? cI->dataType() : xType;
|
||||
DataType hType = h != nullptr ? h->dataType() : xType;
|
||||
DataType hLType = hL != nullptr ? hL->dataType() : xType;
|
||||
DataType cLType = cL != nullptr ? cL->dataType() : xType;
|
||||
|
@ -509,7 +510,8 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) {
|
|||
&& !hasSeqLen //Sequence length array not supported in MKL DNN
|
||||
&& dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn]
|
||||
&& directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat
|
||||
&& retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other)
|
||||
&& retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other)
|
||||
&& hasInitH == hasInitC; //Need both or neither initial H and C
|
||||
|
||||
return block.isUseMKLDNN() && featuresSupported && (
|
||||
(xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
|
||||
|
|
|
@ -153,6 +153,7 @@ public abstract class DifferentialFunction {
|
|||
public Map<String,Object> propertiesForFunction() {
|
||||
Map<String,Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
|
||||
Map<String,Object> ret = new LinkedHashMap<>();
|
||||
Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());
|
||||
|
||||
for(val entry : fields.entrySet()) {
|
||||
try {
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.lang.String;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.ImageResizeMethod;
|
||||
|
||||
public class SDImage extends SDOps {
|
||||
public SDImage(SameDiff sameDiff) {
|
||||
|
@ -254,6 +255,98 @@ public class SDImage extends SDOps {
|
|||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
|
||||
* @param antialis Whether to use an anti-aliasing filter when downsampling an image
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable imageResize(SDVariable input, SDVariable size, boolean preserveAspectRatio,
|
||||
boolean antialis, ImageResizeMethod ImageResizeMethod) {
|
||||
SDValidation.validateNumerical("imageResize", "input", input);
|
||||
SDValidation.validateInteger("imageResize", "size", size);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
|
||||
* @param antialis Whether to use an anti-aliasing filter when downsampling an image
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable imageResize(String name, SDVariable input, SDVariable size,
|
||||
boolean preserveAspectRatio, boolean antialis, ImageResizeMethod ImageResizeMethod) {
|
||||
SDValidation.validateNumerical("imageResize", "input", input);
|
||||
SDValidation.validateInteger("imageResize", "size", size);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable imageResize(SDVariable input, SDVariable size,
|
||||
ImageResizeMethod ImageResizeMethod) {
|
||||
SDValidation.validateNumerical("imageResize", "input", input);
|
||||
SDValidation.validateInteger("imageResize", "size", size);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable imageResize(String name, SDVariable input, SDVariable size,
|
||||
ImageResizeMethod ImageResizeMethod) {
|
||||
SDValidation.validateNumerical("imageResize", "input", input);
|
||||
SDValidation.validateInteger("imageResize", "size", size);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score<br>
|
||||
*
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.lang.String;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.PartitionMode;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.indexing.conditions.Condition;
|
||||
|
||||
|
@ -32,6 +33,67 @@ public class SDMath extends SDOps {
|
|||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clips tensor values to a maximum average L2-norm.<br>
|
||||
*
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @param clipValue Value for clipping
|
||||
* @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) {
|
||||
SDValidation.validateNumerical("ClipByAvgNorm", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Clips tensor values to a maximum average L2-norm.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @param clipValue Value for clipping
|
||||
* @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) {
|
||||
SDValidation.validateNumerical("ClipByAvgNorm", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Looks up ids in a list of embedding tensors.<br>
|
||||
*
|
||||
* @param x Input tensor (NUMERIC type)
|
||||
* @param indices A Tensor containing the ids to be looked up. (INT type)
|
||||
* @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
|
||||
* @return output Shifted output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) {
|
||||
SDValidation.validateNumerical("EmbeddingLookup", "x", x);
|
||||
SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
|
||||
return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Looks up ids in a list of embedding tensors.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input tensor (NUMERIC type)
|
||||
* @param indices A Tensor containing the ids to be looked up. (INT type)
|
||||
* @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
|
||||
* @return output Shifted output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices,
|
||||
PartitionMode PartitionMode) {
|
||||
SDValidation.validateNumerical("EmbeddingLookup", "x", x);
|
||||
SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Elementwise absolute value operation: out = abs(x)<br>
|
||||
*
|
||||
|
|
|
@ -30,6 +30,30 @@ public class SDNN extends SDOps {
|
|||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.<br>
|
||||
*
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cReLU(SDVariable x) {
|
||||
SDValidation.validateNumerical("CReLU", "x", x);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cReLU(String name, SDVariable x) {
|
||||
SDValidation.validateNumerical("CReLU", "x", x);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Neural network batch normalization operation.<br>
|
||||
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */
|
||||
public enum ImageResizeMethod {
|
||||
ResizeBilinear,
|
||||
|
||||
ResizeBicubic,
|
||||
|
||||
ResizeNearest,
|
||||
|
||||
ResizeGaussian,
|
||||
|
||||
ResizeLanczos5,
|
||||
|
||||
ResizeMitchelcubic,
|
||||
|
||||
ResizeArea
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* partition_mode == 0 - i.e. 'mod' , 1 - 'div' */
|
||||
public enum PartitionMode {
|
||||
MOD,
|
||||
|
||||
DIV
|
||||
}
|
|
@ -93,6 +93,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.grid.FreeGridOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.CropAndResize.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ImageResize.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
|
||||
|
@ -127,6 +128,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class,
|
||||
|
@ -146,6 +148,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class,
|
||||
|
@ -322,9 +325,12 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.shape.Unstack.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.ZerosLike.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.bp.TileBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat.class,
|
||||
org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather.class,
|
||||
|
@ -354,6 +360,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNormBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue.class,
|
||||
|
@ -365,6 +372,8 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CReluBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
|
||||
|
@ -406,6 +415,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Max.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Min.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.MirrorPad.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention.class,
|
||||
|
@ -492,11 +502,13 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or.class,
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.image;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.ImageResizeMethod;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class ImageResize extends DynamicCustomOp {
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "image_resize";
|
||||
}
|
||||
|
||||
|
||||
public ImageResize(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) {
|
||||
super("image_resize", sameDiff, new SDVariable[]{in, size});
|
||||
addBArgument(preserveAspectRatio, antialias);
|
||||
addIArgument(method.ordinal());
|
||||
}
|
||||
|
||||
public ImageResize(@NonNull INDArray in, @NonNull INDArray size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) {
|
||||
super("image_resize", new INDArray[]{in, size}, null);
|
||||
Preconditions.checkArgument(in.rank()==4,"expected input message in NHWC format i.e [batchSize, height, width, channels]");
|
||||
addBArgument(preserveAspectRatio, antialias);
|
||||
addIArgument(method.ordinal());
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
Preconditions
|
||||
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
|
||||
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
|
||||
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -56,9 +56,11 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
|
||||
protected Conv2DConfig config;
|
||||
|
||||
|
||||
public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input,
|
||||
@NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
|
||||
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
|
@ -71,14 +73,14 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config) {
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) {
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
|
@ -127,7 +129,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public Map<String, Object> propertiesForFunction() {
|
||||
if(config == null && !iArguments.isEmpty()){
|
||||
if (config == null && !iArguments.isEmpty()) {
|
||||
config = Conv2DConfig.builder()
|
||||
.kH(iArguments.get(0))
|
||||
.kW(iArguments.get(1))
|
||||
|
@ -308,7 +310,9 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
throw new UnsupportedOperationException("Not implemented yet");
|
||||
SDVariable bias = args().length==2 ? null : arg(2);
|
||||
return Arrays.asList(new DepthwiseConv2DBp(sameDiff, arg(0), arg(1), bias, f1.get(0), this.config).outputVariables());
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -323,7 +327,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.*;
|
||||
|
||||
|
||||
/**
|
||||
* Backpropagation for Depthwise Conv2D operation
|
||||
*/
|
||||
@Slf4j
|
||||
@Getter
|
||||
@NoArgsConstructor
|
||||
public class DepthwiseConv2DBp extends DynamicCustomOp {
|
||||
|
||||
protected Conv2DConfig config;
|
||||
|
||||
|
||||
public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){
|
||||
super(sameDiff, wrapFilterNull(input, weights, bias, gradO));
|
||||
this.config = config;
|
||||
addArgs();
|
||||
|
||||
}
|
||||
|
||||
public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){
|
||||
super(sameDiff, wrapFilterNull(input, weights, gradO));
|
||||
this.config = config;
|
||||
addArgs();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long[] iArgs() {
|
||||
if (iArguments.size() == 0)
|
||||
addArgs();
|
||||
|
||||
return super.iArgs();
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
addIArgument(config.getKH(),
|
||||
config.getKW(),
|
||||
config.getSH(),
|
||||
config.getSW(),
|
||||
config.getPH(),
|
||||
config.getPW(),
|
||||
config.getDH(),
|
||||
config.getDW(),
|
||||
ArrayUtil.fromBoolean(config.isSameMode()),
|
||||
config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getValue(Field property) {
|
||||
if (config == null) {
|
||||
config = Conv2DConfig.builder().build();
|
||||
}
|
||||
|
||||
try {
|
||||
val t = config.getValue(property);
|
||||
return t;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> propertiesForFunction() {
|
||||
if (config == null && !iArguments.isEmpty()) {
|
||||
config = Conv2DConfig.builder()
|
||||
.kH(iArguments.get(0))
|
||||
.kW(iArguments.get(1))
|
||||
.sH(iArguments.get(2))
|
||||
.sW(iArguments.get(3))
|
||||
.pH(iArguments.get(4))
|
||||
.pW(iArguments.get(5))
|
||||
.dH(iArguments.get(6))
|
||||
.dW(iArguments.get(7))
|
||||
.isSameMode(iArguments.get(8) == 1)
|
||||
.dataFormat(iArguments.get(9) == 1 ? Conv2DConfig.NHWC : Conv2DConfig.NCHW)
|
||||
.build();
|
||||
}
|
||||
return config.toProperties();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean isConfigProperties() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String configFieldName() {
|
||||
return "config";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "depthwise_conv2d_bp";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
int n = args().length;
|
||||
List<DataType> list = new ArrayList<DataType>();
|
||||
for(int i=0;i<n-1;i++){list.add(inputDataTypes.get(0));}
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return list;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumOutputs(){
|
||||
return args().length == 4 ? 3 : 2;}
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
|
||||
import org.nd4j.shade.guava.primitives.Booleans;
|
||||
|
@ -60,6 +62,7 @@ import java.util.Map;
|
|||
* 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat<<br>
|
||||
* 2: cell state at last step cL - same shape as in hL<br>
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class LSTMLayer extends DynamicCustomOp {
|
||||
|
||||
@Getter
|
||||
|
@ -68,14 +71,18 @@ public class LSTMLayer extends DynamicCustomOp {
|
|||
@Getter
|
||||
private LSTMLayerWeights weights;
|
||||
|
||||
private SDVariable cLast;
|
||||
private SDVariable yLast;
|
||||
private SDVariable maxTSLength;
|
||||
|
||||
public LSTMLayer() {
|
||||
}
|
||||
|
||||
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
|
||||
super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
|
||||
this.configuration = configuration;
|
||||
this.weights = weights;
|
||||
this.cLast = cLast;
|
||||
this.yLast = yLast;
|
||||
this.maxTSLength = maxTSLength;
|
||||
addIArgument(iArgs());
|
||||
addTArgument(tArgs());
|
||||
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
|
||||
|
@ -124,7 +131,13 @@ public class LSTMLayer extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grads) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
int i=0;
|
||||
SDVariable grad0 = this.configuration.isRetFullSequence() ? grads.get(i++): null;
|
||||
SDVariable grad1 = this.configuration.isRetLastH() ? grads.get(i++): null;
|
||||
SDVariable grad2 = this.configuration.isRetLastC() ? grads.get(i++): null;
|
||||
|
||||
return Arrays.asList(new LSTMLayerBp(sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength,
|
||||
this.weights, this.configuration, grad0, grad1,grad2).outputVariables());
|
||||
}
|
||||
|
||||
|
||||
|
@ -155,7 +168,7 @@ public class LSTMLayer extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
|
||||
public <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
|
||||
protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
|
||||
return new boolean[]{
|
||||
weights.hasBias(), // hasBiases: B_ARG(0)
|
||||
maxTSLength != null, // hasSeqLen: B_ARG(1)
|
||||
|
@ -169,6 +182,16 @@ public class LSTMLayer extends DynamicCustomOp {
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConfigProperties() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String configFieldName() {
|
||||
return "configuration";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumOutputs(){
|
||||
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
|
||||
import org.nd4j.shade.guava.primitives.Booleans;
|
||||
|
||||
import javax.xml.crypto.Data;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
/**
|
||||
* LSTM layer backpropagation
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class LSTMLayerBp extends DynamicCustomOp {
|
||||
|
||||
@Getter
|
||||
private LSTMLayerConfig configuration;
|
||||
|
||||
@Getter
|
||||
private LSTMLayerWeights weights;
|
||||
|
||||
private SDVariable cLast;
|
||||
private SDVariable yLast;
|
||||
private SDVariable maxTSLength;
|
||||
|
||||
|
||||
public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration,
|
||||
SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) {
|
||||
super("lstmLayer_bp", sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(),
|
||||
maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL));
|
||||
this.configuration = configuration;
|
||||
this.weights = weights;
|
||||
this.cLast = cLast;
|
||||
this.yLast = yLast;
|
||||
this.maxTSLength = maxTSLength;
|
||||
addIArgument(iArgs());
|
||||
addTArgument(tArgs());
|
||||
addBArgument(bArgs(weights, maxTSLength, yLast, cLast));
|
||||
|
||||
|
||||
Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(),
|
||||
"You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
|
||||
DataType dt = inputDataTypes.get(1);
|
||||
Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
|
||||
ArrayList<DataType> list = new ArrayList<>();
|
||||
list.add(dt); // dLdx
|
||||
list.add(dt); // dLdWx
|
||||
list.add(dt); // dLdWr
|
||||
|
||||
if (this.weights.hasBias()) {
|
||||
list.add(dt);
|
||||
} // dLdb
|
||||
|
||||
if (this.maxTSLength != null) {
|
||||
list.add(dt);
|
||||
} // dLdSl
|
||||
if (this.yLast != null) {
|
||||
list.add(dt);
|
||||
} //dLdhI
|
||||
if (this.cLast != null) {
|
||||
list.add(dt);
|
||||
} // dLdcI
|
||||
if (this.weights.hasPH()) {
|
||||
list.add(dt);
|
||||
} // dLdWp
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "lstmLayer_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> propertiesForFunction() {
|
||||
return configuration.toProperties(true, true);
|
||||
}
|
||||
|
||||
|
||||
public long[] iArgs() {
|
||||
return new long[]{
|
||||
configuration.getLstmdataformat().ordinal(),// INT_ARG(0)
|
||||
configuration.getDirectionMode().ordinal(), // INT_ARG(1)
|
||||
configuration.getGateAct().ordinal(), // INT_ARG(2)
|
||||
configuration.getOutAct().ordinal(), // INT_ARG(3)
|
||||
configuration.getCellAct().ordinal() // INT_ARG(4)
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
public double[] tArgs() {
|
||||
return new double[]{this.configuration.getCellClip()}; // T_ARG(0)
|
||||
}
|
||||
|
||||
|
||||
protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
|
||||
return new boolean[]{
|
||||
weights.hasBias(), // hasBiases: B_ARG(0)
|
||||
maxTSLength != null, // hasSeqLen: B_ARG(1)
|
||||
yLast != null, // hasInitH: B_ARG(2)
|
||||
cLast != null, // hasInitC: B_ARG(3)
|
||||
weights.hasPH(), // hasPH: B_ARG(4)
|
||||
configuration.isRetFullSequence(), //retFullSequence: B_ARG(5)
|
||||
configuration.isRetLastH(), // retLastH: B_ARG(6)
|
||||
configuration.isRetLastC() // retLastC: B_ARG(7)
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConfigProperties() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String configFieldName() {
|
||||
return "configuration";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int getNumOutputs() {
|
||||
|
||||
return Booleans.countTrue(
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
weights.hasBias(),
|
||||
this.maxTSLength != null,
|
||||
this.yLast != null,
|
||||
this.cLast != null,
|
||||
weights.hasPH()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -15,8 +15,10 @@
|
|||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||
|
||||
|
@ -26,9 +28,10 @@ import java.util.Map;
|
|||
|
||||
@Builder
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LSTMLayerConfig {
|
||||
|
||||
|
||||
/**
|
||||
* notations <br>
|
||||
* for unidirectional:
|
||||
|
@ -90,23 +93,23 @@ public class LSTMLayerConfig {
|
|||
* Cell clipping value, if it = 0 then do not apply clipping
|
||||
*/
|
||||
@Builder.Default
|
||||
private double cellClip; //T_ARG(0)
|
||||
private double cellClip = 0; //T_ARG(0)
|
||||
|
||||
|
||||
public Map<String, Object> toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) {
|
||||
Map<String, Object> ret = new LinkedHashMap<>();
|
||||
ret.put("gateAct", gateAct.ordinal());
|
||||
ret.put("outAct", outAct.ordinal());
|
||||
ret.put("cellAct", cellAct.ordinal());
|
||||
ret.put("gateAct", gateAct.toString());
|
||||
ret.put("outAct", outAct.toString());
|
||||
ret.put("cellAct", cellAct.toString());
|
||||
ret.put("retFullSequence", retFullSequence);
|
||||
ret.put("retLastH", retLastH);
|
||||
ret.put("retLastC", retLastC);
|
||||
ret.put("cellClip", cellClip);
|
||||
|
||||
if (includeLSTMDataFormat)
|
||||
ret.put("LSTMDataFormat", lstmdataformat.ordinal());
|
||||
ret.put("lstmdataformat", lstmdataformat.toString());
|
||||
if (includeLSTMDirectionMode)
|
||||
ret.put("LSTMDirectionMode", directionMode.ordinal());
|
||||
ret.put("directionMode", directionMode.toString());
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,15 +24,13 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class MergeAvg extends DynamicCustomOp {
|
||||
|
@ -74,12 +72,8 @@ public class MergeAvg extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
int nArgs = args().length;
|
||||
SDVariable gradient = sameDiff.setupFunction(i_v.get(0)).div(nArgs);
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
for (int i = 0; i < args().length; i++)
|
||||
ret.add(gradient);
|
||||
return ret;
|
||||
return Arrays.asList(new MergeAvgBp(sameDiff, args(), i_v.get(0)).outputVariables());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -24,14 +24,12 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class MergeMax extends DynamicCustomOp {
|
||||
|
@ -71,14 +69,8 @@ public class MergeMax extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
SDVariable out = outputVariable();
|
||||
for (int i = 0; i < args().length; i++){
|
||||
SDVariable isMax = out.eq(arg(i)).castTo(arg(i).dataType());
|
||||
ret.add(isMax.mul(gradient));
|
||||
}
|
||||
return ret;
|
||||
return Arrays.asList(new MergeMaxBp(sameDiff, args(), i_v.get(0)).outputVariables());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.shape.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
public class MergeAvgBp extends DynamicCustomOp {
|
||||
|
||||
public MergeAvgBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
|
||||
super("mergeavg_bp", sameDiff, ArrayUtils.add(inputs, gradO));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "mergeavg_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
ArrayList<DataType> list = new ArrayList<DataType>();
|
||||
for (int i = 0; i < args().length - 1; i++) {
|
||||
list.add(inputDataTypes.get(0));
|
||||
}
|
||||
return list;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumOutputs() {
|
||||
return args().length - 1;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.shape.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
public class MergeMaxBp extends DynamicCustomOp {
|
||||
|
||||
public MergeMaxBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
|
||||
super("mergemax_bp", sameDiff, ArrayUtils.add(inputs, gradO));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "mergemax_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
List<DataType> list = new ArrayList<DataType>();
|
||||
for (int i=0; i< args().length-1;i++){
|
||||
list.add(inputDataTypes.get(0));
|
||||
}
|
||||
return list;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumOutputs(){
|
||||
return args().length-1;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.PartitionMode;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class EmbeddingLookup extends DynamicCustomOp {
|
||||
|
||||
public EmbeddingLookup(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable indices, PartitionMode partitionMode) {
|
||||
super("embedding_lookup", sameDiff, new SDVariable[]{in, indices});
|
||||
addIArgument(partitionMode.ordinal());
|
||||
}
|
||||
|
||||
public EmbeddingLookup(@NonNull INDArray in, @NonNull INDArray indices, PartitionMode partitionMode, INDArray output) {
|
||||
super("embedding_lookup", new INDArray[]{in, indices}, wrapOrNull(output));
|
||||
addIArgument(partitionMode.ordinal());
|
||||
|
||||
}
|
||||
|
||||
public EmbeddingLookup(@NonNull INDArray in, INDArray output, PartitionMode partitionMode, @NonNull int... indices) {
|
||||
super("embedding_lookup", new INDArray[]{in, Nd4j.createFromArray(indices)}, wrapOrNull(output));
|
||||
addIArgument(partitionMode.ordinal());
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "embedding_lookup";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
Preconditions
|
||||
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
|
||||
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
|
||||
Preconditions.checkArgument(dataTypes.get(1).isIntType(), "Input datatype must be integer point, got %s", dataTypes);
|
||||
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
public class ClipByAvgNorm extends DynamicCustomOp {
|
||||
|
||||
private double clipValue;
|
||||
|
||||
|
||||
public ClipByAvgNorm(SameDiff sameDiff, SDVariable x, double clipValue, int... dimensions) {
|
||||
super("clipbyavgnorm", sameDiff, new SDVariable[]{x});
|
||||
this.clipValue = clipValue;
|
||||
this.dimensions = dimensions;
|
||||
addIArgument(dimensions);
|
||||
addTArgument(clipValue);
|
||||
}
|
||||
|
||||
public ClipByAvgNorm(INDArray in, double clipValue, int... dimensions){
|
||||
this(in, null, clipValue, dimensions);
|
||||
}
|
||||
|
||||
public ClipByAvgNorm(INDArray in, INDArray out, double clipValue, int... dimensions){
|
||||
super("clipbyavgnorm", new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "clipbyavgnorm";
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
||||
throw new UnsupportedOperationException("Not yet implemented"); }
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||
return inputDataTypes;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class CReLU extends DynamicCustomOp {
|
||||
|
||||
|
||||
public CReLU(SameDiff sd, SDVariable input) {
|
||||
super(sd, new SDVariable[]{input});
|
||||
}
|
||||
|
||||
public CReLU(@NonNull INDArray input) {
|
||||
super(new INDArray[]{input}, null);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "crelu";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
Preconditions
|
||||
.checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes);
|
||||
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
|
||||
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
|
||||
return Collections.singletonList(new CReluBp(sameDiff, arg(), i_v.get(0)).outputVariable());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
public class CReluBp extends DynamicCustomOp {
|
||||
|
||||
public CReluBp(SameDiff sd, SDVariable input, SDVariable epsilonNext) {
|
||||
super(sd, new SDVariable[]{input, epsilonNext});
|
||||
}
|
||||
|
||||
public CReluBp(@NonNull INDArray input, @NonNull INDArray epsilonNext, INDArray output) {
|
||||
super(new INDArray[]{input, epsilonNext}, wrapOrNull(output));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "crelu_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
Preconditions
|
||||
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
|
||||
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
|
||||
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -73,12 +73,7 @@ public class Max extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
//TODO Switch to maximum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp
|
||||
SDVariable max = outputVariables()[0];
|
||||
SDVariable eq1 = sameDiff.eq(larg(), max).castTo(arg(0).dataType());
|
||||
SDVariable eq2 = sameDiff.eq(rarg(), max).castTo(arg(1).dataType());
|
||||
|
||||
return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0)));
|
||||
return Arrays.asList(new MaximumBp(sameDiff, arg(0), arg(1), f1.get(0)).outputVariables());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class MaximumBp extends DynamicCustomOp {
|
||||
|
||||
public MaximumBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y, @NonNull SDVariable gradO) {
|
||||
super("maximum_bp",sameDiff, new SDVariable[]{x,y, gradO});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "maximum_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
List<DataType> list = new ArrayList<DataType>();
|
||||
list.add(inputDataTypes.get(0));
|
||||
list.add(inputDataTypes.get(0));
|
||||
return list;
|
||||
}
|
||||
}
|
|
@ -18,14 +18,19 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -70,11 +75,8 @@ public class MergeAddOp extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
for (int i = 0; i < args().length; i++)
|
||||
ret.add(gradient);
|
||||
return ret;
|
||||
return Arrays.asList(new MergeAddBp(sameDiff, args(), i_v.get(0)).outputVariables());
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class MergeAddBp extends DynamicCustomOp {
|
||||
|
||||
public MergeAddBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) {
|
||||
super("mergeadd_bp", sameDiff, ArrayUtils.add(inputs, gradO));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "mergeadd_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
ArrayList<DataType> list = new ArrayList<DataType>();
|
||||
for (int i=0; i< args().length-1;i++){list.add(inputDataTypes.get(0));}
|
||||
return list;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumOutputs(){
|
||||
return args().length-1;
|
||||
}
|
||||
}
|
|
@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
|
|||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.ImageResizeMethod;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.NDValidation;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -134,6 +135,49 @@ public class NDImage {
|
|||
return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
|
||||
* @param antialis Whether to use an anti-aliasing filter when downsampling an image
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public INDArray imageResize(INDArray input, INDArray size, boolean preserveAspectRatio,
|
||||
boolean antialis, ImageResizeMethod ImageResizeMethod) {
|
||||
NDValidation.validateNumerical("imageResize", "input", input);
|
||||
NDValidation.validateInteger("imageResize", "size", size);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, preserveAspectRatio, antialis, ImageResizeMethod))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize images to size using the specified method.<br>
|
||||
*
|
||||
* @param input 4D image [NHWC] (NUMERIC type)
|
||||
* @param size new height and width (INT type)
|
||||
* @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp.
|
||||
* @return output Output image (NUMERIC type)
|
||||
*/
|
||||
public INDArray imageResize(INDArray input, INDArray size, ImageResizeMethod ImageResizeMethod) {
|
||||
NDValidation.validateNumerical("imageResize", "input", input);
|
||||
NDValidation.validateInteger("imageResize", "size", size);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, false, false, ImageResizeMethod))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score<br>
|
||||
*
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops;
|
|||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.enums.PartitionMode;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.NDValidation;
|
||||
|
@ -31,6 +32,34 @@ public class NDMath {
|
|||
public NDMath() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Clips tensor values to a maximum average L2-norm.<br>
|
||||
*
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @param clipValue Value for clipping
|
||||
* @param dimensions Dimensions to reduce over (Size: AtLeast(min=0))
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public INDArray clipByAvgNorm(INDArray x, double clipValue, int... dimensions) {
|
||||
NDValidation.validateNumerical("ClipByAvgNorm", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(x, clipValue, dimensions))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Looks up ids in a list of embedding tensors.<br>
|
||||
*
|
||||
* @param x Input tensor (NUMERIC type)
|
||||
* @param indices A Tensor containing the ids to be looked up. (INT type)
|
||||
* @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
|
||||
* @return output Shifted output (NUMERIC type)
|
||||
*/
|
||||
public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) {
|
||||
NDValidation.validateNumerical("EmbeddingLookup", "x", x);
|
||||
NDValidation.validateInteger("EmbeddingLookup", "indices", indices);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Elementwise absolute value operation: out = abs(x)<br>
|
||||
*
|
||||
|
|
|
@ -29,6 +29,17 @@ public class NDNN {
|
|||
public NDNN() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.<br>
|
||||
*
|
||||
* @param x Input variable (NUMERIC type)
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public INDArray cReLU(INDArray x) {
|
||||
NDValidation.validateNumerical("CReLU", "x", x);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Neural network batch normalization operation.<br>
|
||||
* For details, see <a href="https://arxiv.org/abs/1502.03167">https://arxiv.org/abs/1502.03167</a><br>
|
||||
|
|
|
@ -20,7 +20,6 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Ignore;
|
||||
|
@ -35,6 +34,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||
|
@ -265,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW);
|
||||
inSize = inSizeNCHW;
|
||||
in = sd.var("in", inSize);
|
||||
out = sd.cnn().upsampling2d(in, 2, 2, true);
|
||||
out = sd.cnn().upsampling2d(in, 2, 2, true);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
|
@ -1469,6 +1469,43 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDepthwiseConv2D(){
|
||||
|
||||
int bS = 10;
|
||||
|
||||
int kernelHeight = 2;
|
||||
int kernelWidth = 2;
|
||||
int strideHeight = 2;
|
||||
int strideWidth = 2;
|
||||
int inChannels = 2;
|
||||
int outChannels = 3;
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", Nd4j.rand(bS, inChannels, 5,5));
|
||||
SDVariable weights = sd.var("weights", Nd4j.rand(DataType.DOUBLE, kernelHeight, kernelWidth, inChannels, outChannels));
|
||||
SDVariable bias = sd.var("bias", Nd4j.rand(DataType.DOUBLE, inChannels*outChannels));
|
||||
Conv2DConfig config = Conv2DConfig.builder()
|
||||
.kH(kernelHeight)
|
||||
.kW(kernelWidth)
|
||||
.sH(strideHeight)
|
||||
.sW(strideWidth)
|
||||
.dataFormat("NCHW")
|
||||
.build();
|
||||
|
||||
SDVariable out = sd.cnn.depthWiseConv2d(in, weights, bias, config);
|
||||
SDVariable loss = sd.standardDeviation("loss", out, true);
|
||||
loss.markAsLoss();
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(true)
|
||||
);
|
||||
assertNull(err);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void LSTMLayerTestCase1() {
|
||||
|
@ -1476,9 +1513,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
int bS = 5;
|
||||
int nIn = 3;
|
||||
int numUnits = 7;
|
||||
int sL = 10; //small just for test
|
||||
int sL = 3; //small just for test
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
// notations:
|
||||
// bS - batch size, numExamples
|
||||
|
@ -1492,50 +1528,66 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
|
||||
|
||||
|
||||
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL));
|
||||
for (boolean useCLast : new boolean[]{false, true}) {
|
||||
for (boolean useYLast : new boolean[]{false, true}) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", Nd4j.randn(DataType.DOUBLE, bS, nIn, sL));
|
||||
|
||||
|
||||
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
|
||||
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
|
||||
SDVariable cLast = useCLast ? sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null;
|
||||
SDVariable yLast = useYLast ? sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null;
|
||||
|
||||
LSTMLayerConfig c = LSTMLayerConfig.builder()
|
||||
.lstmdataformat(LSTMDataFormat.NST)
|
||||
.directionMode(LSTMDirectionMode.FWD)
|
||||
.gateAct(LSTMActivations.SIGMOID)
|
||||
.cellAct(LSTMActivations.TANH)
|
||||
.outAct(LSTMActivations.TANH)
|
||||
.retFullSequence(true)
|
||||
.retLastC(true)
|
||||
.retLastH(true)
|
||||
.build();
|
||||
|
||||
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
|
||||
in, cLast, yLast, null,
|
||||
LSTMLayerWeights.builder()
|
||||
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
|
||||
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits)))
|
||||
.bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(),
|
||||
c), c);
|
||||
LSTMLayerConfig c = LSTMLayerConfig.builder()
|
||||
.lstmdataformat(LSTMDataFormat.NST)
|
||||
.directionMode(LSTMDirectionMode.FWD)
|
||||
.gateAct(LSTMActivations.SIGMOID)
|
||||
.cellAct(LSTMActivations.TANH)
|
||||
.outAct(LSTMActivations.TANH)
|
||||
.retFullSequence(true)
|
||||
.retLastC(true)
|
||||
.retLastH(true)
|
||||
.build();
|
||||
|
||||
long[] out = new long[]{bS, numUnits, sL};
|
||||
long[] hL = new long[]{bS, numUnits};
|
||||
long[] cL = new long[]{bS, numUnits};
|
||||
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
|
||||
in, cLast, yLast, null,
|
||||
LSTMLayerWeights.builder()
|
||||
.weights(sd.var("weights", Nd4j.randn(DataType.DOUBLE, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.randn(DataType.DOUBLE, numUnits, 4 * numUnits)))
|
||||
.peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.randn(DataType.DOUBLE, 3 * numUnits)))
|
||||
.bias(sd.var("bias", Nd4j.rand(DataType.DOUBLE, 4 * numUnits))).build(),
|
||||
c), c);
|
||||
|
||||
assertArrayEquals(out, outputs.getOutput().eval().shape());
|
||||
assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape());
|
||||
assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape());
|
||||
long[] out = new long[]{bS, numUnits, sL};
|
||||
long[] hL = new long[]{bS, numUnits};
|
||||
long[] cL = new long[]{bS, numUnits};
|
||||
|
||||
assertArrayEquals(out, outputs.getOutput().eval().shape());
|
||||
assertArrayEquals(hL, outputs.getLastOutput().eval().shape());
|
||||
assertArrayEquals(cL, outputs.getLastState().eval().shape());
|
||||
|
||||
sd.setLossVariables(outputs.getOutput(), outputs.getLastTimeStepOutput(), outputs.getTimeSeriesOutput());
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(true)
|
||||
.testName("cLast=" + cLast + ", yLast=" + yLast)
|
||||
);
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
|
||||
@Test
|
||||
public void LSTMLayerTestCase2() {
|
||||
int bS = 5;
|
||||
int nIn = 3;
|
||||
int numUnits = 7;
|
||||
int sL = 10; //small just for test
|
||||
int sL = 3; //small just for test
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
|
@ -1549,11 +1601,11 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
// NTS: shape [numExamples, timeLength, inOutSize]<br>
|
||||
// for bidirectional:
|
||||
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
|
||||
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn));
|
||||
SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, sL, bS, nIn));
|
||||
|
||||
|
||||
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
|
||||
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits));
|
||||
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits));
|
||||
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits));
|
||||
|
||||
LSTMLayerConfig c = LSTMLayerConfig.builder()
|
||||
.lstmdataformat(LSTMDataFormat.TNS)
|
||||
|
@ -1569,8 +1621,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(
|
||||
in, cLast, yLast, null,
|
||||
LSTMLayerWeights.builder()
|
||||
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits)))
|
||||
.weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, numUnits, 4 * numUnits)))
|
||||
.build(),
|
||||
c), c);
|
||||
|
||||
|
@ -1578,14 +1630,22 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
long[] out = new long[]{sL, bS, numUnits};
|
||||
assertArrayEquals(out, outputs.getOutput().eval().shape());
|
||||
|
||||
sd.setLossVariables(outputs.getOutput());
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(true)
|
||||
);
|
||||
|
||||
assertNull(err);
|
||||
|
||||
}
|
||||
|
||||
@Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824
|
||||
@Test
|
||||
public void LSTMLayerTestCase3() {
|
||||
int bS = 5;
|
||||
int nIn = 3;
|
||||
int numUnits = 7;
|
||||
int sL = 10; //small just for test
|
||||
int sL = 3; //small just for test
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
|
@ -1599,14 +1659,14 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
// NTS: shape [numExamples, timeLength, inOutSize]<br>
|
||||
// for bidirectional:
|
||||
// T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)
|
||||
SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn));
|
||||
SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, bS, sL, nIn));
|
||||
|
||||
|
||||
// when directionMode >= 2 (BIDIR_CONCAT=3)
|
||||
// Wx, Wr [2, nIn, 4*nOut]
|
||||
// hI, cI [2, bS, nOut]
|
||||
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
|
||||
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits));
|
||||
SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits));
|
||||
SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits));
|
||||
|
||||
LSTMLayerConfig c = LSTMLayerConfig.builder()
|
||||
.lstmdataformat(LSTMDataFormat.NTS)
|
||||
|
@ -1622,8 +1682,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"},
|
||||
in, cLast, yLast, null,
|
||||
LSTMLayerWeights.builder()
|
||||
.weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits)))
|
||||
.weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, 2, nIn, 4 * numUnits)))
|
||||
.rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, 2, numUnits, 4 * numUnits)))
|
||||
.build(),
|
||||
c), c);
|
||||
|
||||
|
@ -1631,5 +1691,17 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
long[] out = new long[]{bS, sL, 2 * numUnits};
|
||||
|
||||
assertArrayEquals(out, outputs.getOutput().eval().shape());
|
||||
|
||||
sd.setLossVariables(outputs.getOutput());
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(true)
|
||||
);
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue