Map C++ ops to Java (#392)
* MergeMaxIndex, ReverseBp, Tri, Triu and TriuBp added Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* Upsamling3d draft Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* minor fix (upsampling3dBp inputDatatype.size=2) Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* polished testcases Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* matching of Upsampling3d input format according to cpp iArg Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* ops generated from codegen Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* requested changes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* added super() for Triu Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* everything passes except TriuOp Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* Tri op dtype arg (output datatype config support) + default float32 Signed-off-by: Alex Black <blacka101@gmail.com>
* Small fixes Signed-off-by: Alex Black <blacka101@gmail.com>
* temporary commit with manually edited sd/nd ops Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* Cannot use 'val' here because initializer expression does not have a representable type: Type cannot be resolved Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* all tests passed Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* few requested changes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* Small fixes Signed-off-by: Alex Black <blacka101@gmail.com>
* Ignore reverse_bp test due to logged issue Signed-off-by: Alex Black <blacka101@gmail.com>
* Fix reverse op Signed-off-by: Alex Black <blacka101@gmail.com>
* Fix MergeMaxIndex dtype -> iarg Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
parent bd376ca993
commit 58b11bfecc
@@ -51,7 +51,9 @@ DECLARE_SHAPE_FN(tri) {
     const int rows = INT_ARG(0);
     const int cols = block.numI() > 1 ? INT_ARG(1) : rows;
 
-    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', {rows, cols}));
+    auto dtype = block.numD() ? D_ARG(0) : DataType::FLOAT32;
+
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', {rows, cols}));
 }
 
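The shape function now reads the output type from the op's data-type argument (D-arg) when one is supplied, and falls back to float32 otherwise. A minimal sketch of the effect from the Java side, using the Tri wrapper added later in this commit (illustrative only, not part of the diff):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.Tri;
import org.nd4j.linalg.factory.Nd4j;

public class TriDtypeSketch {
    public static void main(String[] args) {
        // No D-arg: the shape function falls back to FLOAT32
        INDArray defaultTri = Nd4j.exec(new Tri(3, 4, 0))[0];
        // D-arg supplied: the output follows the requested type
        INDArray doubleTri = Nd4j.exec(new Tri(DataType.DOUBLE, 3, 4, 0))[0];
        System.out.println(defaultTri.dataType() + " / " + doubleTri.dataType()); // FLOAT / DOUBLE
    }
}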
@@ -85,7 +85,7 @@ namespace ops {
     // check the consistency of input dimensions to reverse along
     shape::checkDimensions(input->rankOf(), axis);
     // we just reverse back original array
-    helpers::reverse(block.launchContext(), eps, output, &axis, true);
+    helpers::reverse(block.launchContext(), eps, output, &axis, false);
 }
 
 return Status::OK();
@@ -447,7 +447,7 @@ public abstract class DifferentialFunction {
         this.sameDiff = sameDiff;
         this.inPlace = inPlace;
         setInstanceId();
-        if(sameDiff != null) {
+        if(sameDiff != null && args != null) {
             sameDiff.addArgsFor(args, this);
         }
     }
@@ -1052,4 +1052,38 @@ public class SDCNN extends SDOps {
     SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }
 
+  /**
+   * 3D Convolution layer operation - Upsampling 3d <br>
+   *
+   * @param input Input in NCHW format (NUMERIC type)
+   * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format
+   * @param scaleD Scale to upsample in depth dimension
+   * @param scaleH Scale to upsample in height dimension
+   * @param scaleW Scale to upsample in width dimension
+   * @return output Upsampled input (NUMERIC type)
+   */
+  public SDVariable upsampling3d(SDVariable input, boolean ncdhw, int scaleD, int scaleH,
+      int scaleW) {
+    SDValidation.validateNumerical("upsampling3d", "input", input);
+    return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(sd,input, ncdhw, scaleD, scaleH, scaleW).outputVariable();
+  }
+
+  /**
+   * 3D Convolution layer operation - Upsampling 3d <br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input Input in NCHW format (NUMERIC type)
+   * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format
+   * @param scaleD Scale to upsample in depth dimension
+   * @param scaleH Scale to upsample in height dimension
+   * @param scaleW Scale to upsample in width dimension
+   * @return output Upsampled input (NUMERIC type)
+   */
+  public SDVariable upsampling3d(String name, SDVariable input, boolean ncdhw, int scaleD,
+      int scaleH, int scaleW) {
+    SDValidation.validateNumerical("upsampling3d", "input", input);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(sd,input, ncdhw, scaleD, scaleH, scaleW).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
 }
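A minimal usage sketch for the new SameDiff method (illustrative only; the sd.cnn() namespace accessor is assumed here, it is not part of this diff):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class Upsampling3dExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // NCDHW input: [minibatch, channels, depth, height, width]
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3, 4, 4, 4);
        // upsample depth/height/width by 2 -> [1, 3, 8, 8, 8]
        SDVariable up = sd.cnn().upsampling3d(in, true, 2, 2, 2);
    }
}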
@@ -258,7 +258,7 @@ public class SDImage extends SDOps {
   /**
    * Resize images to size using the specified method.<br>
    *
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
    * @param antialis Whether to use an anti-aliasing filter when downsampling an image
@@ -282,7 +282,7 @@ public class SDImage extends SDOps {
    * Resize images to size using the specified method.<br>
    *
    * @param name name May be null. Name for the output variable
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
    * @param antialis Whether to use an anti-aliasing filter when downsampling an image
@@ -306,7 +306,7 @@ public class SDImage extends SDOps {
   /**
    * Resize images to size using the specified method.<br>
    *
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
    * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
@@ -328,7 +328,7 @@ public class SDImage extends SDOps {
    * Resize images to size using the specified method.<br>
    *
    * @param name name May be null. Name for the output variable
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
    * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
@@ -23,6 +23,7 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
 import java.lang.String;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
 
 public class SDLinalg extends SDOps {
   public SDLinalg(SameDiff sameDiff) {
@@ -558,4 +559,106 @@ public class SDLinalg extends SDOps {
     SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }
 
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param dataType Data type
+   * @param row
+   * @param column
+   * @param diagonal
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable tri(DataType dataType, int row, int column, int diagonal) {
+    return new org.nd4j.linalg.api.ops.custom.Tri(sd,dataType, row, column, diagonal).outputVariable();
+  }
+
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param dataType Data type
+   * @param row
+   * @param column
+   * @param diagonal
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable tri(String name, DataType dataType, int row, int column, int diagonal) {
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Tri(sd,dataType, row, column, diagonal).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param row
+   * @param column
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable tri(int row, int column) {
+    return new org.nd4j.linalg.api.ops.custom.Tri(sd,DataType.FLOAT, row, column, 0).outputVariable();
+  }
+
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param row
+   * @param column
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable tri(String name, int row, int column) {
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Tri(sd,DataType.FLOAT, row, column, 0).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param input (NUMERIC type)
+   * @param diag
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable triu(SDVariable input, int diag) {
+    SDValidation.validateNumerical("triu", "input", input);
+    return new org.nd4j.linalg.api.ops.custom.Triu(sd,input, diag).outputVariable();
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input (NUMERIC type)
+   * @param diag
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable triu(String name, SDVariable input, int diag) {
+    SDValidation.validateNumerical("triu", "input", input);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Triu(sd,input, diag).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param input (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable triu(SDVariable input) {
+    SDValidation.validateNumerical("triu", "input", input);
+    return new org.nd4j.linalg.api.ops.custom.Triu(sd,input, 0).outputVariable();
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param input (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public SDVariable triu(String name, SDVariable input) {
+    SDValidation.validateNumerical("triu", "input", input);
+    SDVariable out = new org.nd4j.linalg.api.ops.custom.Triu(sd,input, 0).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
 }
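A short usage sketch for the new SDLinalg helpers (illustrative only; reaching the namespace via sd.linalg() is an assumption, it is not shown in this diff):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class TriTriuExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // 3x4 lower-triangular mask of ones (default FLOAT, diagonal 0)
        SDVariable mask = sd.linalg().tri(3, 4);
        // zero everything below the first super-diagonal of an input matrix
        SDVariable x = sd.var("x", DataType.FLOAT, 4, 4);
        SDVariable upper = sd.linalg().triu(x, 1);
    }
}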
@@ -67,13 +67,13 @@ public class SDMath extends SDOps {
    * Looks up ids in a list of embedding tensors.<br>
    *
    * @param x Input tensor (NUMERIC type)
-   * @param indices A Tensor containing the ids to be looked up. (NUMERIC type)
+   * @param indices A Tensor containing the ids to be looked up. (INT type)
    * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
    * @return output Shifted output (NUMERIC type)
    */
   public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) {
     SDValidation.validateNumerical("EmbeddingLookup", "x", x);
-    SDValidation.validateNumerical("EmbeddingLookup", "indices", indices);
+    SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
     return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable();
   }
 
@@ -82,18 +82,72 @@
    *
    * @param name name May be null. Name for the output variable
    * @param x Input tensor (NUMERIC type)
-   * @param indices A Tensor containing the ids to be looked up. (NUMERIC type)
+   * @param indices A Tensor containing the ids to be looked up. (INT type)
    * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
    * @return output Shifted output (NUMERIC type)
    */
   public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices,
       PartitionMode PartitionMode) {
     SDValidation.validateNumerical("EmbeddingLookup", "x", x);
-    SDValidation.validateNumerical("EmbeddingLookup", "indices", indices);
+    SDValidation.validateInteger("EmbeddingLookup", "indices", indices);
     SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }
 
+  /**
+   * Return array of max elements indices with along tensor dimensions <br>
+   *
+   * @param x Input tensor (NUMERIC type)
+   * @param dataType Data type
+   * @return output Array max elements indices with along dimensions. (INT type)
+   */
+  public SDVariable mergeMaxIndex(SDVariable[] x, DataType dataType) {
+    SDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable();
+  }
+
+  /**
+   * Return array of max elements indices with along tensor dimensions <br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input tensor (NUMERIC type)
+   * @param dataType Data type
+   * @return output Array max elements indices with along dimensions. (INT type)
+   */
+  public SDVariable mergeMaxIndex(String name, SDVariable[] x, DataType dataType) {
+    SDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Return array of max elements indices with along tensor dimensions <br>
+   *
+   * @param x Input tensor (NUMERIC type)
+   * @return output Array max elements indices with along dimensions. (INT type)
+   */
+  public SDVariable mergeMaxIndex(SDVariable... x) {
+    SDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable();
+  }
+
+  /**
+   * Return array of max elements indices with along tensor dimensions <br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input tensor (NUMERIC type)
+   * @return output Array max elements indices with along dimensions. (INT type)
+   */
+  public SDVariable mergeMaxIndex(String name, SDVariable... x) {
+    SDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
   /**
    * Elementwise absolute value operation: out = abs(x)<br>
    *
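A brief sketch of the new mergeMaxIndex entry point and the tightened embeddingLookup indices contract (illustrative only; the sd.math() accessor is assumed):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class MergeMaxIndexExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.var("a", DataType.FLOAT, 2, 3);
        SDVariable b = sd.var("b", DataType.FLOAT, 2, 3);
        // index (0 or 1 here) of the input holding the max value at each position
        SDVariable idx = sd.math().mergeMaxIndex(a, b);
        // embeddingLookup indices must now be an integer-typed variable
        SDVariable ids = sd.constant("ids", Nd4j.createFromArray(0, 2, 1));
    }
}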
@@ -143,6 +143,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2DDerivative.class,
             org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth.class,
             org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d.class,
+            org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d.class,
+            org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3dBp.class,
             org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative.class,
             org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU.class,
             org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUBp.class,
@@ -301,6 +303,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.shape.Linspace.class,
             org.nd4j.linalg.api.ops.impl.shape.MergeAvg.class,
             org.nd4j.linalg.api.ops.impl.shape.MergeMax.class,
+            org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex.class,
             org.nd4j.linalg.api.ops.impl.shape.MergeSum.class,
             org.nd4j.linalg.api.ops.impl.shape.MeshGrid.class,
             org.nd4j.linalg.api.ops.impl.shape.OneHot.class,
@@ -426,6 +429,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.custom.ParallelConcat.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.Pow.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseV2.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
@@ -642,6 +646,9 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.RandomCrop.class,
             org.nd4j.linalg.api.ops.custom.Roll.class,
             org.nd4j.linalg.api.ops.custom.ToggleBits.class,
+            org.nd4j.linalg.api.ops.custom.Tri.class,
+            org.nd4j.linalg.api.ops.custom.Triu.class,
+            org.nd4j.linalg.api.ops.custom.TriuBp.class,
             org.nd4j.linalg.api.ops.custom.Igamma.class,
             org.nd4j.linalg.api.ops.custom.Igammac.class,
             org.nd4j.linalg.api.ops.custom.Digamma.class,
@@ -85,7 +85,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     }
 
     public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) {
-        this(sameDiff, new SDVariable[]{arg});
+        this(sameDiff, wrapOrNull(arg));
     }
 
     public DynamicCustomOp(SameDiff sameDiff, SDVariable[] args) {
@@ -655,6 +655,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
         outputArguments.clear();
     }
 
+    protected static SDVariable[] wrapOrNull(SDVariable in){
+        return in == null ? null : new SDVariable[]{in};
+    }
+
     protected static INDArray[] wrapOrNull(INDArray in){
         return in == null ? null : new INDArray[]{in};
     }
@@ -0,0 +1,76 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class Tri extends DynamicCustomOp {
+
+    private DataType dataType = DataType.FLOAT;
+
+    public Tri(SameDiff sameDiff, int row, int column, int diag) {
+        super(sameDiff, new SDVariable[]{});
+        addIArgument(row,column,diag);
+    }
+
+    public Tri(SameDiff sameDiff, DataType dataType, int row, int column, int diag) {
+        super(sameDiff, new SDVariable[]{});
+        addIArgument(row,column,diag);
+        addDArgument(dataType);
+        this.dataType = dataType;
+    }
+
+    public Tri(int row, int column, int diag) {
+        super(new INDArray[]{}, null);
+        addIArgument(row,column,diag);
+    }
+
+    public Tri(DataType dataType, int row, int column, int diag) {
+        super(new INDArray[]{}, null);
+        addIArgument(row,column,diag);
+        addDArgument(dataType);
+        this.dataType = dataType;
+    }
+
+    @Override
+    public String opName() {
+        return "tri";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        return Collections.singletonList(this.dataType);
+    }
+}
@@ -0,0 +1,71 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class Triu extends DynamicCustomOp {
+
+    private int diag = 0;
+
+    public Triu(SameDiff sameDiff, SDVariable in, int diag) {
+        super(sameDiff, new SDVariable[]{in});
+        addIArgument(diag);
+        this.diag=diag;
+    }
+
+    public Triu(SameDiff sameDiff, SDVariable in) {
+        super(sameDiff, new SDVariable[]{in});
+    }
+
+    public Triu(INDArray input, int diag) {
+        super(new INDArray[]{input}, null);
+        addIArgument(diag);
+        this.diag=diag;
+    }
+
+    @Override
+    public String opName() {
+        return "triu";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        return Collections.singletonList(arg(0).dataType());
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        return new TriuBp(sameDiff, arg(0), f1.get(0), diag).outputs();
+    }
+}
@@ -0,0 +1,55 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseBp;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class TriuBp extends DynamicCustomOp {
+
+    public TriuBp(SameDiff sameDiff, SDVariable in, SDVariable grad, int diag) {
+        super(sameDiff, new SDVariable[]{in, grad});
+        addIArgument(diag);
+    }
+
+    public TriuBp(SameDiff sameDiff, SDVariable in, SDVariable grad) {
+        super(sameDiff, new SDVariable[]{in, grad});
+    }
+
+    @Override
+    public String opName() {
+        return "triu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        return Collections.singletonList(arg(0).dataType());
+    }
+}
@@ -0,0 +1,99 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.layers.convolution;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Upsampling3d operation
+ */
+@Slf4j
+@Getter
+@NoArgsConstructor
+public class Upsampling3d extends DynamicCustomOp {
+
+    protected boolean ncdhw;
+    protected int scaleH;
+    protected int scaleW;
+    protected int scaleD;
+
+    public Upsampling3d(SameDiff sameDiff, SDVariable input, boolean ncdhw, int scaleD, int scaleH, int scaleW) {
+        super("upsampling3d",sameDiff, new SDVariable[]{input});
+        this.ncdhw = ncdhw;
+
+        this.scaleD = scaleD;
+        this.scaleH = scaleH;
+        this.scaleW = scaleW;
+
+        addIArgument(scaleD);
+        addIArgument(scaleH);
+        addIArgument(scaleW);
+        addIArgument(scaleD);
+        addIArgument(ncdhw ? 1 : 0);
+    }
+
+    public Upsampling3d(INDArray input, boolean ncdhw, int scaleH, int scaleW, int scaleD) {
+        super(new INDArray[]{input}, null);
+        this.ncdhw = ncdhw;
+
+        this.scaleD = scaleD;
+        this.scaleH = scaleH;
+        this.scaleW = scaleW;
+
+        addIArgument(scaleD);
+        addIArgument(scaleH);
+        addIArgument(scaleW);
+        addIArgument(scaleD);
+        addIArgument(ncdhw ? 0 : 1);
+    }
+
+    @Override
+    public String opName() {
+        return "upsampling3d";
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        return Arrays.asList(new Upsampling3dBp(sameDiff, arg(0), f1.get(0), this.ncdhw).outputVariables());
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
@@ -0,0 +1,54 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.layers.convolution;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class Upsampling3dBp extends DynamicCustomOp {
+
+    public Upsampling3dBp(SameDiff sameDiff, SDVariable input, SDVariable grad0, boolean ncdhw) {
+        super("upsampling3d_bp",sameDiff, new SDVariable[]{input, grad0});
+        addIArgument(ncdhw ? 1 : 0);
+    }
+
+    @Override
+    public String opName() {
+        return "upsampling3d_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input data type for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
@@ -0,0 +1,85 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.shape;
+
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
+import org.nd4j.linalg.util.ArrayUtil;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class MergeMaxIndex extends DynamicCustomOp {
+
+    private DataType dataType = DataType.INT32;
+
+    public MergeMaxIndex(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputs) {
+        super("mergemaxindex", sameDiff, inputs);
+        addIArgument(dataType.toInt());
+    }
+
+    public MergeMaxIndex(@NonNull INDArray... inputs) {
+        super("mergemaxindex", inputs, null);
+        Preconditions.checkArgument(areEqualShapes(inputs), "All inputs have to be equal shapes");
+        addIArgument(dataType.toInt());
+    }
+
+    public MergeMaxIndex(@NonNull SameDiff sd, @NonNull SDVariable[] x, @NonNull DataType dataType) {
+        super("mergemaxindex", sd, x);
+        this.dataType = dataType;
+        addIArgument(dataType.toInt());
+    }
+
+    public MergeMaxIndex(@NonNull INDArray[] x, @NonNull DataType dataType) {
+        super(x, null);
+        Preconditions.checkArgument(areEqualShapes(x), "All inputs have to be equal shapes");
+        this.dataType = dataType;
+        addIArgument(dataType.toInt());
+    }
+
+    protected static boolean areEqualShapes(INDArray... inputs) {
+        for (INDArray input : inputs) {
+            if (!inputs[0].equalShapes(input)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    @Override
+    public String opName() {
+        return "mergemaxindex";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        return Collections.singletonList(this.dataType);
+    }
+}
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
 
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -23,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -30,8 +32,8 @@ import java.util.List;
 
 public class Reverse extends DynamicCustomOp {
 
-    public Reverse(SameDiff sameDiff, SDVariable i_v, int... dimensions) {
-        super(null, sameDiff, new SDVariable[]{i_v}, false);
+    public Reverse(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, @NonNull int... dimensions) {
+        super(sameDiff, new SDVariable[]{i_v});
         this.dimensions = dimensions;
         addIArgument(dimensions);
     }
@@ -56,6 +58,7 @@ public class Reverse extends DynamicCustomOp {
     public Reverse(INDArray x, int... axis){
         super(new INDArray[]{x}, new INDArray[0]);
         this.inPlace = false;
+        this.dimensions = axis;
         addIArgument(axis);
     }
 
@@ -67,6 +70,7 @@ public class Reverse extends DynamicCustomOp {
     public Reverse(INDArray x, INDArray z, int... axis){
         super(new INDArray[]{x}, new INDArray[] {z});
         this.inPlace = false;
+        this.dimensions = axis;
         addIArgument(axis);
     }
 
@@ -100,8 +104,7 @@ public class Reverse extends DynamicCustomOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable ret = sameDiff.reverse(f1.get(0), dimensions);
-        return Collections.singletonList(ret);
+        return new ReverseBp(sameDiff, arg(0), f1.get(0), dimensions).outputs();
     }
 
     @Override
@@ -0,0 +1,47 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class ReverseBp extends DynamicCustomOp {
+    public ReverseBp(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, @NonNull SDVariable grad, @NonNull int... dimensions) {
+        super(sameDiff, new SDVariable[]{i_v, grad});
+        addIArgument(dimensions);
+    }
+
+    @Override
+    public String opName() {
+        return "reverse_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
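For context, reverse_bp (and the C++ fix above that now passes false to helpers::reverse) amounts to reversing the incoming gradient along the same dimensions. A hedged sketch at the NDArray level, not part of the diff:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.factory.Nd4j;

public class ReverseGradSketch {
    public static void main(String[] args) {
        // The gradient of reverse(x, dims) applied to an upstream gradient is just the
        // gradient reversed along the same dims, which is what reverse_bp computes.
        INDArray grad = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        INDArray gradIn = Nd4j.exec(new Reverse(grad, 1))[0]; // each row reversed
    }
}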
@@ -508,4 +508,19 @@ public class NDCNN {
     NDValidation.validateNumerical("upsampling2d", "input", input);
     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scaleH, scaleW, nchw))[0];
   }
+
+  /**
+   * 3D Convolution layer operation - Upsampling 3d <br>
+   *
+   * @param input Input in NCHW format (NUMERIC type)
+   * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format
+   * @param scaleD Scale to upsample in depth dimension
+   * @param scaleH Scale to upsample in height dimension
+   * @param scaleW Scale to upsample in width dimension
+   * @return output Upsampled input (NUMERIC type)
+   */
+  public INDArray upsampling3d(INDArray input, boolean ncdhw, int scaleD, int scaleH, int scaleW) {
+    NDValidation.validateNumerical("upsampling3d", "input", input);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(input, ncdhw, scaleD, scaleH, scaleW))[0];
+  }
 }
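An NDArray-level sketch (illustrative only; it calls the op directly with Nd4j.exec rather than through an NDCNN instance, and uses a uniform scale of 2 so the depth/height/width ordering of the INDArray constructor does not matter):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d;
import org.nd4j.linalg.factory.Nd4j;

public class NdUpsampling3dSketch {
    public static void main(String[] args) {
        // NCDHW input [1, 2, 2, 2, 2]; scale 2 in each spatial dim -> [1, 2, 4, 4, 4]
        INDArray in = Nd4j.create(DataType.FLOAT, 1, 2, 2, 2, 2);
        INDArray up = Nd4j.exec(new Upsampling3d(in, true, 2, 2, 2))[0];
    }
}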
@@ -138,7 +138,7 @@ public class NDImage {
   /**
    * Resize images to size using the specified method.<br>
    *
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False.
    * @param antialis Whether to use an anti-aliasing filter when downsampling an image
@@ -161,7 +161,7 @@ public class NDImage {
   /**
    * Resize images to size using the specified method.<br>
    *
-   * @param input 4D image [NCHW] (NUMERIC type)
+   * @param input 4D image [NHWC] (NUMERIC type)
    * @param size new height and width (INT type)
    * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
    * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
@@ -20,6 +20,7 @@ package org.nd4j.linalg.factory.ops;
 
 import static org.nd4j.linalg.factory.NDValidation.isSameType;
 
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.NDValidation;
 import org.nd4j.linalg.factory.Nd4j;
@@ -271,4 +272,51 @@ public class NDLinalg {
     NDValidation.validateNumerical("svd", "input", input);
     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, 16))[0];
   }
+
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param dataType Data type
+   * @param row
+   * @param column
+   * @param diagonal
+   * @return output (FLOATING_POINT type)
+   */
+  public INDArray tri(DataType dataType, int row, int column, int diagonal) {
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Tri(dataType, row, column, diagonal))[0];
+  }
+
+  /**
+   * An array with ones at and below the given diagonal and zeros elsewhere.<br>
+   *
+   * @param row
+   * @param column
+   * @return output (FLOATING_POINT type)
+   */
+  public INDArray tri(int row, int column) {
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Tri(DataType.FLOAT, row, column, 0))[0];
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param input (NUMERIC type)
+   * @param diag
+   * @return output (FLOATING_POINT type)
+   */
+  public INDArray triu(INDArray input, int diag) {
+    NDValidation.validateNumerical("triu", "input", input);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Triu(input, diag))[0];
+  }
+
+  /**
+   * Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.<br>
+   *
+   * @param input (NUMERIC type)
+   * @return output (FLOATING_POINT type)
+   */
+  public INDArray triu(INDArray input) {
+    NDValidation.validateNumerical("triu", "input", input);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Triu(input, 0))[0];
+  }
 }
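A small NDArray-level sketch of the numpy-style tri/triu semantics, mirroring what the generated NDLinalg methods do internally (illustrative only, not part of the diff):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.Tri;
import org.nd4j.linalg.api.ops.custom.Triu;
import org.nd4j.linalg.factory.Nd4j;

public class NdTriTriuSketch {
    public static void main(String[] args) {
        // 3x3 matrix with ones at and below the main diagonal, zeros above
        INDArray lower = Nd4j.exec(new Tri(DataType.FLOAT, 3, 3, 0))[0];
        // zero out everything below the main diagonal of an input matrix
        INDArray m = Nd4j.ones(3, 3);
        INDArray upper = Nd4j.exec(new Triu(m, 0))[0];
    }
}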
@@ -50,16 +50,41 @@ public class NDMath {
   * Looks up ids in a list of embedding tensors.<br>
   *
   * @param x Input tensor (NUMERIC type)
-  * @param indices A Tensor containing the ids to be looked up. (NUMERIC type)
+  * @param indices A Tensor containing the ids to be looked up. (INT type)
   * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div'
   * @return output Shifted output (NUMERIC type)
   */
  public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) {
    NDValidation.validateNumerical("EmbeddingLookup", "x", x);
-    NDValidation.validateNumerical("EmbeddingLookup", "indices", indices);
+    NDValidation.validateInteger("EmbeddingLookup", "indices", indices);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0];
  }

+  /**
+   * Return an array containing, for each position, the index of the input tensor that holds the maximum element.<br>
+   *
+   * @param x Input tensors (NUMERIC type)
+   * @param dataType Data type of the output indices
+   * @return output Indices of the max elements across the inputs (INT type)
+   */
+  public INDArray mergeMaxIndex(INDArray[] x, DataType dataType) {
+    NDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(x, dataType))[0];
+  }
+
+  /**
+   * Return an array containing, for each position, the index of the input tensor that holds the maximum element.<br>
+   *
+   * @param x Input tensors (NUMERIC type)
+   * @return output Indices of the max elements across the inputs (INT type)
+   */
+  public INDArray mergeMaxIndex(INDArray... x) {
+    NDValidation.validateNumerical("MergeMaxIndex", "x", x);
+    Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(x, DataType.INT))[0];
+  }
+
  /**
   * Elementwise absolute value operation: out = abs(x)<br>
   *
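For reference, a minimal usage sketch of the new mergeMaxIndex helpers follows. It is not part of the commit and assumes the Nd4j.math() namespace accessor; the inputs mirror the values used in the tests later in this commit.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MergeMaxIndexSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.createFromArray(new float[]{1, 0, 0});
        INDArray b = Nd4j.createFromArray(new float[]{0, 1, 0});
        INDArray c = Nd4j.createFromArray(new float[]{0, 0, 1});

        // For each position, the index of the input array holding the maximum value: expected [0, 1, 2]
        INDArray idx = Nd4j.math().mergeMaxIndex(a, b, c);

        // Same result with an explicit integer output type
        INDArray idx32 = Nd4j.math().mergeMaxIndex(new INDArray[]{a, b, c}, DataType.INT32);

        System.out.println(idx);
        System.out.println(idx32);
    }
}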
@@ -33,6 +33,8 @@ import org.nd4j.autodiff.validation.TestCase;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.custom.Tri;
+import org.nd4j.linalg.api.ops.custom.Triu;
 import org.nd4j.linalg.api.ops.impl.shape.*;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
@@ -2525,4 +2527,49 @@ public class ShapeOpValidation extends BaseOpValidation {
             assertArrayEquals(exp, out.shape());
         }
     }
+
+    @Test
+    public void testMergeMaxIndex() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable inputX = sd.var(Nd4j.createFromArray(new float[]{1, 0, 0}));
+        SDVariable inputY = sd.var(Nd4j.createFromArray(new float[]{0, 1, 0}));
+        SDVariable inputZ = sd.var(Nd4j.createFromArray(new float[]{0, 0, 1}));
+        SDVariable out = new MergeMaxIndex(sd, new SDVariable[]{inputX, inputY, inputZ}, DataType.INT32).outputVariable();
+        INDArray expected = Nd4j.createFromArray(0, 1, 2);
+        String err = OpValidation.validate(new TestCase(sd)
+                .expectedOutput("mergemaxindex", expected)
+                .gradientCheck(false));
+        assertNull(err);
+    }
+
+    @Test
+    public void testTriOp() {
+        SameDiff sd = SameDiff.create();
+        SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
+        INDArray expected = Nd4j.createFromArray(new int[][]{{1, 1, 1, 0, 0}, {1, 1, 1, 1, 0}, {1, 1, 1, 1, 1}});
+        String err = OpValidation.validate(new TestCase(sd)
+                .expectedOutput("tri", expected)
+                .gradientCheck(false));
+        assertNull(err);
+    }
+
+    @Test
+    public void testTriuOp() {
+        SameDiff sd = SameDiff.create();
+        SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}));
+        SDVariable out = new Triu(sd, input, -1).outputVariable();
+        out.markAsLoss();
+        INDArray expected = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}, {0, 8, 9}, {0, 0, 12}});
+        String err = OpValidation.validate(new TestCase(sd)
+                .expectedOutput("triu", expected)
+                .gradientCheck(true));
+        assertNull(err);
+    }
 }
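The ops exercised above can also be run eagerly, without building a SameDiff graph. A minimal sketch, not part of the commit, using only the Tri and Triu constructors from this change set and Nd4j.exec:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.Tri;
import org.nd4j.linalg.api.ops.custom.Triu;
import org.nd4j.linalg.factory.Nd4j;

public class EagerTriTriuSketch {
    public static void main(String[] args) {
        // 3x5 INT32 mask with ones at and below the diagonal shifted up by two,
        // matching the expected output of testTriOp above
        INDArray tri = Nd4j.exec(new Tri(DataType.INT32, 3, 5, 2))[0];

        // Zero the elements below the first sub-diagonal, matching testTriuOp above
        INDArray in = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}});
        INDArray triu = Nd4j.exec(new Triu(in, -1))[0];

        System.out.println(tri);
        System.out.println(triu);
    }
}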
@@ -40,11 +40,13 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.image.ImageResize;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d;
 import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
 import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;
 import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
 import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
+import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;
 import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm;
@@ -2126,7 +2128,7 @@ public class TransformOpValidation extends BaseOpValidation {
         };


-        SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable();
+        SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);

         String err = OpValidation.validate(new TestCase(sd)
                 .gradientCheck(false)
@@ -2150,7 +2152,7 @@ public class TransformOpValidation extends BaseOpValidation {
         SDVariable inputY = sd.var(Nd4j.rand(2, 3));


-        SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable();
+        SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable().std(true);
         String err = OpValidation.validate(new TestCase(sd)
                 .gradientCheck(true));
         assertNull(err);
@@ -2166,7 +2168,7 @@ public class TransformOpValidation extends BaseOpValidation {
         SDVariable inputX = sd.var(Nd4j.rand(2, 3));
         SDVariable inputY = sd.var(Nd4j.rand(2, 3));
         SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
-        SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable().std(true);
         out.markAsLoss();
         String err = OpValidation.validate(new TestCase(sd)
                 .gradientCheck(true));
@@ -2183,7 +2185,7 @@ public class TransformOpValidation extends BaseOpValidation {
         SDVariable inputX = sd.var(Nd4j.rand(2, 3));
         SDVariable inputY = sd.var(Nd4j.rand(2, 3));
         SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
-        SDVariable out = new MergeMax(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        SDVariable out = new MergeMax(sd, inputX, inputY, inputZ).outputVariable().std(true);
         out.markAsLoss();
         String err = OpValidation.validate(new TestCase(sd)
                 .gradientCheck(true));
@@ -2201,7 +2203,7 @@ public class TransformOpValidation extends BaseOpValidation {
         SDVariable inputX = sd.var(Nd4j.rand(2, 3));
         SDVariable inputY = sd.var(Nd4j.rand(2, 3));
         SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
-        SDVariable out = new MergeAvg(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        SDVariable out = new MergeAvg(sd, inputX, inputY, inputZ).outputVariable().std(true);
         out.markAsLoss();
         String err = OpValidation.validate(new TestCase(sd)
                 .gradientCheck(true));
@@ -2210,6 +2212,44 @@ public class TransformOpValidation extends BaseOpValidation {

     }

+    @Test
+    public void testReverseBp() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{2, 7}, {3, 5}, {4, 5}}));
+        SDVariable out = new Reverse(sd, input, 0).outputVariable();
+        SDVariable loss = out.std(true);
+        loss.markAsLoss();
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true));
+        assertNull(err);
+    }
+
+    @Test
+    public void testUpsampling3dBp() {
+        Nd4j.getRandom().setSeed(12345);
+        for (boolean dataformat : new boolean[]{true, false}) {
+            SameDiff sd = SameDiff.create();
+
+            // NCDHW input
+            SDVariable input = dataformat ? sd.var(Nd4j.rand(DataType.DOUBLE, 2, 1, 5, 5, 5)) : sd.var(Nd4j.rand(DataType.DOUBLE, 2, 5, 5, 5, 1));
+            int scaleD = 2;
+            int scaleH = 2;
+            int scaleW = 2;
+            SDVariable out = new Upsampling3d(sd, input, true, scaleD, scaleH, scaleW).outputVariable().std(true);
+            out.markAsLoss();
+            String err = OpValidation.validate(new TestCase(sd)
+                    .gradientCheck(true));
+            assertNull(err);
+        }
+    }
 }
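As a quick illustration of the forward behaviour gradient-checked in testReverseBp, a minimal eager sketch follows. It is not part of the commit; it builds the same graph with the Reverse constructor used above and evaluates it with SDVariable.eval().

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.factory.Nd4j;

public class ReverseSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.createFromArray(new double[][]{{2, 7}, {3, 5}, {4, 5}}));

        // Reverse along dimension 0: rows come back in the opposite order
        SDVariable reversed = new Reverse(sd, in, 0).outputVariable();

        INDArray result = reversed.eval();   // expected [[4, 5], [3, 5], [2, 7]]
        System.out.println(result);
    }
}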