diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java index 3edbf0b28..abb84e965 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java @@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { break; } - Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder() - .config(conf) - .arrayInputs(new INDArray[]{input, epsilon}) - .arrayOutputs(new INDArray[]{gradAtInput}) - .build(); + Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf); Nd4j.exec(d); return new Pair(new DefaultGradient(), gradAtInput); diff --git a/libnd4j/blas/BlasVersionHelper.h b/libnd4j/blas/BlasVersionHelper.h new file mode 100644 index 000000000..93e8d75e3 --- /dev/null +++ b/libnd4j/blas/BlasVersionHelper.h @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_BLASVERSIONHELPER_H +#define SAMEDIFF_BLASVERSIONHELPER_H + +#include +#include +#include + +namespace nd4j { + class ND4J_EXPORT BlasVersionHelper { + public: + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; + + BlasVersionHelper(); + ~BlasVersionHelper() = default; + }; +} + +#endif //DEV_TESTS_BLASVERSIONHELPER_H diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 0480d83d1..257fa44bb 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -253,20 +253,20 @@ if(CUDA_BLAS) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) if (NOT BUILD_TESTS) - CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} + CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp - Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} + Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) else() set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true") - CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} + CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp - Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} + Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) endif() diff --git a/libnd4j/blas/Environment.cpp b/libnd4j/blas/Environment.cpp index 1c3dd2d9e..b4b2db4ae 100644 --- a/libnd4j/blas/Environment.cpp +++ b/libnd4j/blas/Environment.cpp @@ -35,7 +35,7 @@ #include #include - +#include "BlasVersionHelper.h" #endif namespace nd4j { @@ -66,6 +66,13 @@ namespace nd4j { #endif #ifdef __CUDABLAS__ + BlasVersionHelper ver; + _blasMajorVersion = ver._blasMajorVersion; + _blasMinorVersion = ver._blasMinorVersion; + _blasPatchVersion = ver._blasPatchVersion; + printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion); + fflush(stdout); + int devCnt = 0; cudaGetDeviceCount(&devCnt); auto devProperties = new cudaDeviceProp[devCnt]; diff --git a/libnd4j/blas/Environment.h b/libnd4j/blas/Environment.h index 5092b6190..ac4dfa678 100644 --- a/libnd4j/blas/Environment.h +++ b/libnd4j/blas/Environment.h @@ -56,6 +56,13 @@ namespace nd4j{ Environment(); ~Environment(); public: + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; + static Environment* getInstance(); bool isVerbose(); diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 9bca7bb10..ef46e7752 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads); ND4J_EXPORT void setOmpMinThreads(int threads); - +ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build); /** * diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index 6f97bc024..ef45a3e0c 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) auto fName = builder.CreateString(*(var->getName())); auto id = CreateIntPair(builder, var->id(), var->index()); - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); + auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); variables_vector.push_back(fv); arrays++; diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 86bc04fc4..e016d58fe 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers, } } +bool isBlasVersionMatches(int major, int minor, int build) { + return true; +} + /** * * @param opNum diff --git a/libnd4j/blas/cuda/BlasVersionHelper.cu b/libnd4j/blas/cuda/BlasVersionHelper.cu new file mode 100644 index 000000000..1f80a0cc0 --- /dev/null +++ b/libnd4j/blas/cuda/BlasVersionHelper.cu @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../BlasVersionHelper.h" + +namespace nd4j { + BlasVersionHelper::BlasVersionHelper() { + _blasMajorVersion = __CUDACC_VER_MAJOR__; + _blasMinorVersion = __CUDACC_VER_MINOR__; + _blasPatchVersion = __CUDACC_VER_BUILD__; + } +} \ No newline at end of file diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index a29613b61..ec88de2e5 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) { delete ptr; } +bool isBlasVersionMatches(int major, int minor, int build) { + auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion; + + if (!result) { + nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch"); + } + + return result; +} + nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); } diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 8346442eb..2a52ba6f5 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -38,7 +38,7 @@ namespace nd4j { public: static int asInt(DataType type); static DataType fromInt(int dtype); - static DataType fromFlatDataType(nd4j::graph::DataType dtype); + static DataType fromFlatDataType(nd4j::graph::DType dtype); FORCEINLINE static std::string asString(DataType dataType); template diff --git a/libnd4j/include/array/impl/DataTypeUtils.cpp b/libnd4j/include/array/impl/DataTypeUtils.cpp index f0b261039..cdf688b25 100644 --- a/libnd4j/include/array/impl/DataTypeUtils.cpp +++ b/libnd4j/include/array/impl/DataTypeUtils.cpp @@ -27,7 +27,7 @@ namespace nd4j { return (DataType) val; } - DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) { + DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) { return (DataType) dtype; } diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index 5848c0ac4..b581240ad 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) { return EnumNamesByteOrder()[index]; } -enum DataType { - DataType_INHERIT = 0, - DataType_BOOL = 1, - DataType_FLOAT8 = 2, - DataType_HALF = 3, - DataType_HALF2 = 4, - DataType_FLOAT = 5, - DataType_DOUBLE = 6, - DataType_INT8 = 7, - DataType_INT16 = 8, - DataType_INT32 = 9, - DataType_INT64 = 10, - DataType_UINT8 = 11, - DataType_UINT16 = 12, - DataType_UINT32 = 13, - DataType_UINT64 = 14, - DataType_QINT8 = 15, - DataType_QINT16 = 16, - DataType_BFLOAT16 = 17, - DataType_UTF8 = 50, - DataType_MIN = DataType_INHERIT, - DataType_MAX = DataType_UTF8 +enum DType { + DType_INHERIT = 0, + DType_BOOL = 1, + DType_FLOAT8 = 2, + DType_HALF = 3, + DType_HALF2 = 4, + DType_FLOAT = 5, + DType_DOUBLE = 6, + DType_INT8 = 7, + DType_INT16 = 8, + DType_INT32 = 9, + DType_INT64 = 10, + DType_UINT8 = 11, + DType_UINT16 = 12, + DType_UINT32 = 13, + DType_UINT64 = 14, + DType_QINT8 = 15, + DType_QINT16 = 16, + DType_BFLOAT16 = 17, + DType_UTF8 = 50, + DType_MIN = DType_INHERIT, + DType_MAX = DType_UTF8 }; -inline const DataType (&EnumValuesDataType())[19] { - static const DataType values[] = { - DataType_INHERIT, - DataType_BOOL, - DataType_FLOAT8, - DataType_HALF, - DataType_HALF2, - DataType_FLOAT, - DataType_DOUBLE, - DataType_INT8, - DataType_INT16, - DataType_INT32, - DataType_INT64, - DataType_UINT8, - DataType_UINT16, - DataType_UINT32, - DataType_UINT64, - DataType_QINT8, - DataType_QINT16, - DataType_BFLOAT16, - DataType_UTF8 +inline const DType (&EnumValuesDType())[19] { + static const DType values[] = { + DType_INHERIT, + DType_BOOL, + DType_FLOAT8, + DType_HALF, + DType_HALF2, + DType_FLOAT, + DType_DOUBLE, + DType_INT8, + DType_INT16, + DType_INT32, + DType_INT64, + DType_UINT8, + DType_UINT16, + DType_UINT32, + DType_UINT64, + DType_QINT8, + DType_QINT16, + DType_BFLOAT16, + DType_UTF8 }; return values; } -inline const char * const *EnumNamesDataType() { +inline const char * const *EnumNamesDType() { static const char * const names[] = { "INHERIT", "BOOL", @@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() { return names; } -inline const char *EnumNameDataType(DataType e) { +inline const char *EnumNameDType(DType e) { const size_t index = static_cast(e); - return EnumNamesDataType()[index]; + return EnumNamesDType()[index]; } struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *buffer() const { return GetPointer *>(VT_BUFFER); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } ByteOrder byteOrder() const { return static_cast(GetField(VT_BYTEORDER, 0)); @@ -192,7 +192,7 @@ struct FlatArrayBuilder { void add_buffer(flatbuffers::Offset> buffer) { fbb_.AddOffset(FlatArray::VT_BUFFER, buffer); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatArray::VT_DTYPE, static_cast(dtype), 0); } void add_byteOrder(ByteOrder byteOrder) { @@ -214,7 +214,7 @@ inline flatbuffers::Offset CreateFlatArray( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, flatbuffers::Offset> buffer = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { FlatArrayBuilder builder_(_fbb); builder_.add_buffer(buffer); @@ -228,7 +228,7 @@ inline flatbuffers::Offset CreateFlatArrayDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, const std::vector *buffer = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { return nd4j::graph::CreateFlatArray( _fbb, diff --git a/libnd4j/include/graph/generated/array_generated.js b/libnd4j/include/graph/generated/array_generated.js index 8a2b644e6..b98410a9e 100644 --- a/libnd4j/include/graph/generated/array_generated.js +++ b/libnd4j/include/graph/generated/array_generated.js @@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = { /** * @enum */ -nd4j.graph.DataType = { +nd4j.graph.DType = { INHERIT: 0, BOOL: 1, FLOAT8: 2, @@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatArray.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatArray.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs b/libnd4j/include/graph/generated/nd4j/graph/DType.cs similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.cs rename to libnd4j/include/graph/generated/nd4j/graph/DType.cs index 9cd9518c9..00e399b50 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.cs @@ -5,7 +5,7 @@ namespace nd4j.graph { -public enum DataType : sbyte +public enum DType : sbyte { INHERIT = 0, BOOL = 1, diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.java b/libnd4j/include/graph/generated/nd4j/graph/DType.java similarity index 95% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.java rename to libnd4j/include/graph/generated/nd4j/graph/DType.java index 369c1b6ae..20d3d475b 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.java +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2; diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.py b/libnd4j/include/graph/generated/nd4j/graph/DType.py similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.py rename to libnd4j/include/graph/generated/nd4j/graph/DType.py index e07aace5d..24cadf44e 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.py +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.py @@ -2,7 +2,7 @@ # namespace: graph -class DataType(object): +class DType(object): INHERIT = 0 BOOL = 1 FLOAT8 = 2 diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs index a19325fb7..60d836aeb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs @@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject public ArraySegment? GetBufferBytes() { return __p.__vector_as_arraysegment(6); } #endif public sbyte[] GetBufferArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } } public static Offset CreateFlatArray(FlatBufferBuilder builder, VectorOffset shapeOffset = default(VectorOffset), VectorOffset bufferOffset = default(VectorOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, ByteOrder byteOrder = ByteOrder.LE) { builder.StartObject(4); FlatArray.AddBuffer(builder, bufferOffset); @@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); } public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); } public static Offset EndFlatArray(FlatBufferBuilder builder) { int o = builder.EndObject(); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index c1068811d..0810d2e6e 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject public ArraySegment? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); } #endif public byte[] GetOpNameArray() { return __p.__vector_as_array(36); } - public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; } + public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; } public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T public Span GetOutputTypesBytes() { return __p.__vector_as_span(38); } #else public ArraySegment? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); } #endif - public DataType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } + public DType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public static Offset CreateFlatNode(FlatBufferBuilder builder, @@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); } public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); } - public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } - public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void AddScalar(FlatBufferBuilder builder, Offset scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static Offset EndFlatNode(FlatBufferBuilder builder) { diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs index d5f8014f2..9764668a0 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs @@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject public ArraySegment? GetNameBytes() { return __p.__vector_as_arraysegment(6); } #endif public byte[] GetNameArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T @@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject public static Offset CreateFlatVariable(FlatBufferBuilder builder, Offset idOffset = default(Offset), StringOffset nameOffset = default(StringOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, VectorOffset shapeOffset = default(VectorOffset), Offset ndarrayOffset = default(Offset), int device = 0, @@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index a7b2e264f..bd2274dad 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) { /** * @param {number} index - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatNode.prototype.outputTypes = function(index) { var offset = this.bb.__offset(this.bb_pos, 38); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0); + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0); }; /** @@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) { /** * @param {flatbuffers.Builder} builder - * @param {Array.} data + * @param {Array.} data * @returns {flatbuffers.Offset} */ nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) { diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index e441c17dc..ca1a705a0 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *name() const { return GetPointer(VT_NAME); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -106,7 +106,7 @@ struct FlatVariableBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(FlatVariable::VT_NAME, name); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), 0); } void add_shape(flatbuffers::Offset> shape) { @@ -137,7 +137,7 @@ inline flatbuffers::Offset CreateFlatVariable( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, flatbuffers::Offset> shape = 0, flatbuffers::Offset ndarray = 0, int32_t device = 0, @@ -157,7 +157,7 @@ inline flatbuffers::Offset CreateFlatVariableDirect( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, const char *name = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, const std::vector *shape = nullptr, flatbuffers::Offset ndarray = 0, int32_t device = 0, diff --git a/libnd4j/include/graph/generated/variable_generated.js b/libnd4j/include/graph/generated/variable_generated.js index 3f128e4fc..9012af2de 100644 --- a/libnd4j/include/graph/generated/variable_generated.js +++ b/libnd4j/include/graph/generated/variable_generated.js @@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatVariable.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatVariable.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index bc8ff7e33..ec76cb4d2 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -111,7 +111,7 @@ namespace nd4j { auto bo = static_cast(BitwiseUtils::asByteOrder()); - return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); + return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 6dd881f11..e54112783 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -219,7 +219,7 @@ namespace nd4j { throw std::runtime_error("CONSTANT variable must have NDArray bundled"); auto ar = flatVariable->ndarray(); - if (ar->dtype() == DataType_UTF8) { + if (ar->dtype() == DType_UTF8) { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); } else { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); @@ -320,7 +320,7 @@ namespace nd4j { auto fBuffer = builder.CreateVector(array->asByteVector()); // packing array - auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType()); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType()); // packing id/index of this var auto fVid = CreateIntPair(builder, this->_id, this->_index); @@ -331,7 +331,7 @@ namespace nd4j { stringId = builder.CreateString(this->_name); // returning array - return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); + return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); } else { throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); } diff --git a/libnd4j/include/graph/scheme/array.fbs b/libnd4j/include/graph/scheme/array.fbs index f415ffb08..91e338500 100644 --- a/libnd4j/include/graph/scheme/array.fbs +++ b/libnd4j/include/graph/scheme/array.fbs @@ -23,7 +23,7 @@ enum ByteOrder:byte { } // DataType for arrays/buffers -enum DataType:byte { +enum DType:byte { INHERIT, BOOL, FLOAT8, @@ -49,7 +49,7 @@ enum DataType:byte { table FlatArray { shape:[long]; // shape in Nd4j format buffer:[byte]; // byte buffer with data - dtype:DataType; // data type of actual data within buffer + dtype:DType; // data type of actual data within buffer byteOrder:ByteOrder; // byte order of buffer } diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 6117e7125..930702f6d 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -48,7 +48,7 @@ table FlatNode { opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability // output data types (optional) - outputTypes:[DataType]; + outputTypes:[DType]; //Scalar value - used for scalar ops. Should be single value only. scalar:FlatArray; diff --git a/libnd4j/include/graph/scheme/uigraphstatic.fbs b/libnd4j/include/graph/scheme/uigraphstatic.fbs index cce0da4ad..814c28fa5 100644 --- a/libnd4j/include/graph/scheme/uigraphstatic.fbs +++ b/libnd4j/include/graph/scheme/uigraphstatic.fbs @@ -51,7 +51,7 @@ table UIVariable { id:IntPair; //Existing IntPair class name:string; type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER - datatype:DataType; + datatype:DType; shape:[long]; controlDeps:[string]; //Input control dependencies: variable x -> this outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of diff --git a/libnd4j/include/graph/scheme/variable.fbs b/libnd4j/include/graph/scheme/variable.fbs index 43f343c7c..31eafafa7 100644 --- a/libnd4j/include/graph/scheme/variable.fbs +++ b/libnd4j/include/graph/scheme/variable.fbs @@ -30,7 +30,7 @@ enum VarType:byte { table FlatVariable { id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node name:string; // symbolic ID of the Variable (if defined) - dtype:DataType; + dtype:DType; shape:[long]; // shape is absolutely optional. either shape or ndarray might be set ndarray:FlatArray; diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp new file mode 100644 index 000000000..52d01429f --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_and) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_and) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp new file mode 100644 index 000000000..b8469d83a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_or) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_or) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp new file mode 100644 index 000000000..f7f3f479a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_xor) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_xor) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp index 70310f643..e6913dc34 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp @@ -29,21 +29,26 @@ namespace ops { CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { int numOfData = block.width(); // int k = 0; + // checking input data size REQUIRE_TRUE(numOfData % 2 == 0, 0, "dynamic_stitch: The input params should contains" " both indeces and data lists with same length."); + // split input data list on two equal parts numOfData /= 2; + // form input lists to use with helpers - both indices and float data inputs auto output = OUTPUT_VARIABLE(0); std::vector inputs(numOfData); std::vector indices(numOfData); + for (int e = 0; e < numOfData; e++) { auto data = INPUT_VARIABLE(numOfData + e); auto index = INPUT_VARIABLE(e); + inputs[e] = data; indices[e] = index; } - + // run helper return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output); } @@ -59,17 +64,17 @@ namespace ops { numOfData /= 2; // only index part it's needed to review auto restShape = inputShape->at(numOfData); auto firstShape = inputShape->at(0); + // check up inputs to avoid non-int indices and calculate max value from indices to output shape length for(int i = 0; i < numOfData; i++) { auto input = INPUT_VARIABLE(i); REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() ); - // FIXME: we have reduce::Max, cinsider using it instead auto maxV = input->reduceNumber(reduce::Max); if (maxV.e(0) > maxValue) maxValue = maxV.e(0); } - - int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; + // calculate output rank - difference between indices shape and data shape + int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor std::vector outShape(outRank); - + // fill up output shape template: the first to max index, and rests - to vals from the first data input outShape[0] = maxValue + 1; for(int i = 1; i < outRank; ++i) outShape[i] = shape::sizeAt(restShape, i); diff --git a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h index 89e4c385a..d3a4c042d 100644 --- a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h @@ -33,12 +33,13 @@ namespace nd4j { * 0: 1D row-vector (or with shape (1, m)) * 1: 1D integer vector with slice nums * 2: 1D float-point values vector with same shape as above + * 3: 2D float-point matrix with data to search * * Int args: * 0: N - number of slices * * Output: - * 0: 1D vector with edge forces for input and values + * 0: 2D matrix with the same shape and type as the 3th argument */ #if NOT_EXCLUDED(OP_barnes_edge_forces) DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); @@ -52,9 +53,11 @@ namespace nd4j { * 0: 1D int row-vector * 1: 1D int col-vector * 2: 1D float vector with values - * + * * Output: - * 0: symmetric 2D matrix with given values on given places + * 0: 1D int result row-vector + * 1: 1D int result col-vector + * 2: a float-point tensor with shape 1xN, with values from the last input vector */ #if NOT_EXCLUDED(OP_barnes_symmetrized) DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index a6362a73f..cb395b496 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -81,6 +81,39 @@ namespace nd4j { DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); #endif + /** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_and) + DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); + #endif + + /** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_or) + DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); + #endif + + /** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_xor) + DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); + #endif + /** * This operation returns hamming distance based on bits * diff --git a/libnd4j/include/ops/declarable/headers/list.h b/libnd4j/include/ops/declarable/headers/list.h index 01c2d225c..756895a1f 100644 --- a/libnd4j/include/ops/declarable/headers/list.h +++ b/libnd4j/include/ops/declarable/headers/list.h @@ -120,7 +120,7 @@ namespace nd4j { #endif /** - * This operation unstacks given NDArray into NDArrayList + * This operation unstacks given NDArray into NDArrayList by the first dimension */ #if NOT_EXCLUDED(OP_unstack_list) DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index c86f28499..bb7f306bd 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -594,21 +594,46 @@ namespace nd4j { /** + * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation + * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension + * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input + * block size and how the data is moved. + * Input: + * 0 - 4D tensor on given type + * Output: + * 0 - 4D tensor of given type and proper shape * - * - * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } + * 1 ("NCHW"): shape{ batch, channels, height, width } + * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } + * optional (default 0) */ #if NOT_EXCLUDED(OP_depth_to_space) - DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2); + DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); #endif /** + * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor + * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates + * the input block size. * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } + * 1 ("NCHW"): shape{ batch, channels, height, width } + * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } + * optional (default 0) * */ #if NOT_EXCLUDED(OP_space_to_depth) - DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2); + DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); #endif /** @@ -622,13 +647,42 @@ namespace nd4j { #endif /** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op + * outputs a copy of the input tensor where values from the height and width dimensions are moved to the + * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block + * size. * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) * */ #if NOT_EXCLUDED(OP_space_to_batch) DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); #endif + /* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape + * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, + * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch position. Prior to division into + * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ #if NOT_EXCLUDED(OP_space_to_batch_nd) DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); #endif @@ -973,7 +1027,7 @@ namespace nd4j { * return value: * tensor with min values according to indices sets. */ - #if NOT_EXCLUDED(OP_segment_min_bp) + #if NOT_EXCLUDED(OP_segment_min) DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); #endif #if NOT_EXCLUDED(OP_segment_min_bp) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 7d520478e..75b541b72 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -118,19 +118,19 @@ namespace nd4j { PointersManager pm(context, "dynamicPartition"); - if (sourceDimsLen) { + if (sourceDimsLen) { // non-linear case std::vector sourceDims(sourceDimsLen); for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - + //compute tad array for given dimensions auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims); std::vector outBuffers(outSize); std::vector tadShapes(outSize); std::vector tadOffsets(outSize); std::vector numTads(outSize); - + // fill up dimensions array for before kernel for (unsigned int i = 0; i < outSize; i++) { outputs[i].first = outputList[i]; std::vector outDims(outputs[i].first->rankOf() - 1); @@ -151,10 +151,10 @@ namespace nd4j { auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); - + // run kernel on device dynamicPartitionTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); - } else { + } else { // linear case auto numThreads = 256; auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; @@ -169,7 +169,6 @@ namespace nd4j { auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutShapes = reinterpret_cast(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); - dynamicPartitionScalarKernel<<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 87ac417be..2ef9e2309 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_2) { - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32); + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE); + NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE); nd4j::ops::adjust_saturation op; auto results = op.execute({&input}, {10}, {2}); @@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - // result->printIndexedBuffer(); +// result->printIndexedBuffer("Result2"); +// exp.printIndexedBuffer("Expect2"); ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index cf9f2914e..49dd0657d 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fBuffer = builder.CreateVector(array->asByteVector()); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); auto fVid = CreateIntPair(builder, -1); - auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); std::vector outputs1, outputs2, inputs1, inputs2; outputs1.push_back(2); @@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto name1 = builder.CreateString("wow1"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT); std::vector> variables_vector; variables_vector.push_back(fXVar); diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index e31347b0e..fcdd1db3c 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); builder.Finish(flatVar); @@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fVid = CreateIntPair(builder, 37, 12); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); builder.Finish(flatVar); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index ac017beef..621dac941 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -469,7 +469,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { - LocalResponseNormalization lrn = LocalResponseNormalization.builder() + LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder() .inputFunctions(new SDVariable[]{input}) .sameDiff(sameDiff()) .config(lrnConfig) @@ -487,7 +487,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - Conv1D conv1D = Conv1D.builder() + Conv1D conv1D = Conv1D.sameDiffBuilder() .inputFunctions(new SDVariable[]{input, weights}) .sameDiff(sameDiff()) .config(conv1DConfig) @@ -496,6 +496,34 @@ public class DifferentialFunctionFactory { return conv1D.outputVariable(); } + /** + * Conv1d operation. + * + * @param input the inputs to conv1d + * @param weights conv1d weights + * @param bias conv1d bias + * @param conv1DConfig the configuration + * @return + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) { + + SDVariable[] args; + + if(bias == null){ + args = new SDVariable[]{input, weights}; + } else { + args = new SDVariable[]{input, weights, bias}; + } + + Conv1D conv1D = Conv1D.sameDiffBuilder() + .inputFunctions(args) + .sameDiff(sameDiff()) + .config(conv1DConfig) + .build(); + + return conv1D.outputVariable(); + } + /** * Conv2d operation. * @@ -504,7 +532,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - Conv2D conv2D = Conv2D.builder() + Conv2D conv2D = Conv2D.sameDiffBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .config(conv2DConfig) @@ -530,7 +558,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - AvgPooling2D avgPooling2D = AvgPooling2D.builder() + AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder() .input(input) .sameDiff(sameDiff()) .config(pooling2DConfig) @@ -547,7 +575,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - MaxPooling2D maxPooling2D = MaxPooling2D.builder() + MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder() .input(input) .sameDiff(sameDiff()) .config(pooling2DConfig) @@ -590,7 +618,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - SConv2D sconv2D = SConv2D.sBuilder() + SConv2D sconv2D = SConv2D.sameDiffSBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .conv2DConfig(conv2DConfig) @@ -609,7 +637,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { - SConv2D depthWiseConv2D = SConv2D.sBuilder() + SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .conv2DConfig(depthConv2DConfig) @@ -627,7 +655,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { - DeConv2D deconv2D = DeConv2D.builder() + DeConv2D deconv2D = DeConv2D.sameDiffBuilder() .inputs(inputs) .sameDiff(sameDiff()) .config(deconv2DConfig) @@ -654,9 +682,9 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { - Conv3D conv3D = Conv3D.builder() + Conv3D conv3D = Conv3D.sameDiffBuilder() .inputFunctions(inputs) - .conv3DConfig(conv3DConfig) + .config(conv3DConfig) .sameDiff(sameDiff()) .build(); @@ -1260,6 +1288,22 @@ public class DifferentialFunctionFactory { return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); } + public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { + return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ + return new BitwiseAnd(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseOr(SDVariable x, SDVariable y){ + return new BitwiseOr(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseXor(SDVariable x, SDVariable y){ + return new BitwiseXor(sameDiff(), x, y).outputVariable(); + } + public SDVariable eq(SDVariable iX, SDVariable i_y) { return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index e09ceda75..0b5a4c03f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps { */ public final SDImage image = new SDImage(this); + /** + * Op creator object for bitwise operations + */ + public final SDBitwise bitwise = new SDBitwise(this); + /** * Op creator object for math operations */ @@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps { return image; } + /** + * Op creator object for bitwise operations + */ + public SDBitwise bitwise(){ + return bitwise; + } + /** * For import, many times we have variables diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java new file mode 100644 index 000000000..0857b2b42 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -0,0 +1,205 @@ +package org.nd4j.autodiff.samediff.ops; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; + +/** + * + */ +public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * See {@link #leftShift(String, SDVariable, SDVariable)} + */ + public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){ + return leftShift(null, x, y); + } + + /** + * Bitwise left shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable leftShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift", x); + validateInteger("bitwise left shift", y); + + SDVariable ret = f().shift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShift(String, SDVariable, SDVariable)} + */ + public SDVariable rightShift(SDVariable x, SDVariable y){ + return rightShift(null, x, y); + } + + /** + * Bitwise right shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable rightShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift", x); + validateInteger("bitwise right shift", y); + + SDVariable ret = f().rshift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){ + return leftShiftCyclic(null, x, y); + } + + /** + * Bitwise left cyclical shift operation. Supports broadcasting. + * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code leftShiftCyclic(01110000, 2) -> 11000001} + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift (cyclic)", x); + validateInteger("bitwise left shift (cyclic)", y); + + SDVariable ret = f().rotl(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){ + return rightShiftCyclic(null, x, y); + } + + /** + * Bitwise right cyclical shift operation. Supports broadcasting. + * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code rightShiftCyclic(00001110, 2) -> 10000011} + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift (cyclic)", x); + validateInteger("bitwise right shift (cyclic)", y); + + SDVariable ret = f().rotr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)} + */ + public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){ + return bitsHammingDistance(null, x, y); + } + + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1) + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return + */ + public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise hamming distance", x); + validateInteger("bitwise hamming distance", y); + + SDVariable ret = f().bitwiseHammingDist(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #and(String, SDVariable, SDVariable)} + */ + public SDVariable and(SDVariable x, SDVariable y){ + return and(null, x, y); + } + + /** + * Bitwise AND operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise AND array + */ + public SDVariable and(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise AND", x); + validateInteger("bitwise AND", y); + + SDVariable ret = f().bitwiseAnd(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #or(String, SDVariable, SDVariable)} + */ + public SDVariable or(SDVariable x, SDVariable y){ + return or(null, x, y); + } + + /** + * Bitwise OR operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise OR array + */ + public SDVariable or(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise OR", x); + validateInteger("bitwise OR", y); + + SDVariable ret = f().bitwiseOr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #xor(String, SDVariable, SDVariable)} + */ + public SDVariable xor(SDVariable x, SDVariable y){ + return xor(null, x, y); + } + + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise XOR array + */ + public SDVariable xor(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise XOR", x); + validateInteger("bitwise XOR", y); + + SDVariable ret = f().bitwiseXor(x, y); + return updateVariableNameAndReference(ret, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index fab50a937..7b56ca266 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.ops; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; @@ -38,14 +39,9 @@ public class SDCNN extends SDOps { } /** - * 2D Convolution layer operation - average pooling 2d - * - * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration for - * @return Result after applying average pooling on the input + * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}. */ - public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { return avgPooling2d(null, input, pooling2DConfig); } @@ -58,22 +54,16 @@ public class SDCNN extends SDOps { * @param pooling2DConfig the configuration * @return Result after applying average pooling on the input */ - public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { validateFloatingPoint("avgPooling2d", input); SDVariable ret = f().avgPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** - * 3D convolution layer operation - average pooling 3d - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying average pooling on the input + * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}. */ - public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { return avgPooling3d(null, input, pooling3DConfig); } @@ -87,7 +77,7 @@ public class SDCNN extends SDOps { * @param pooling3DConfig the configuration * @return Result after applying average pooling on the input */ - public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { validateFloatingPoint("avgPooling3d", input); SDVariable ret = f().avgPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); @@ -96,7 +86,7 @@ public class SDCNN extends SDOps { /** * @see #batchToSpace(String, SDVariable, int[], int[][]) */ - public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { + public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { return batchToSpace(null, x, blocks, crops); } @@ -111,7 +101,7 @@ public class SDCNN extends SDOps { * @return Output variable * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ - public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { + public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { validateNumerical("batchToSpace", x); SDVariable ret = f().batchToSpace(x, blocks, crops); return updateVariableNameAndReference(ret, name); @@ -119,14 +109,9 @@ public class SDCNN extends SDOps { /** - * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape - * [minibatch, inputChannels, height, width] - * - * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * @param config Convolution configuration for the col2im operation - * @return Col2Im output variable + * See {@link #col2Im(String, SDVariable, Conv2DConfig)}. */ - public SDVariable col2Im(SDVariable in, Conv2DConfig config) { + public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) { return col2Im(null, in, config); } @@ -139,33 +124,22 @@ public class SDCNN extends SDOps { * @param config Convolution configuration for the col2im operation * @return Col2Im output variable */ - public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) { + public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { SDVariable ret = f().col2Im(in, config); return updateVariableNameAndReference(ret, name); } /** - * 1D Convolution layer operation - Conv1d - * - * @param input the input array/activations for the conv1d op - * @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels] - * @param conv1DConfig the configuration - * @return + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. */ - public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - return conv1d(null, input, weights, conv1DConfig); + public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { + return conv1d((String) null, input, weights, conv1DConfig); } /** - * Conv1d operation. - * - * @param name name of the operation in SameDiff - * @param input the inputs to conv1d - * @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels] - * @param conv1DConfig the configuration - * @return + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. */ - public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { + public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { validateFloatingPoint("conv1d", input); validateFloatingPoint("conv1d", weights); SDVariable ret = f().conv1d(input, weights, conv1DConfig); @@ -173,21 +147,55 @@ public class SDCNN extends SDOps { } /** - * 2D Convolution operation (without bias) - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] - * @param config Conv2DConfig configuration - * @return result of conv2d op + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}. */ - public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) { + public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + return conv1d(null, input, weights, bias, conv1DConfig); + } + + /** + * Conv1d operation. + * + * @param name name of the operation in SameDiff + * @param input the inputs to conv1d + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. + * @param conv1DConfig the configuration + * @return + */ + public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + validateFloatingPoint("conv1d", input); + validateFloatingPoint("conv1d", weights); + validateFloatingPoint("conv1d", bias); + SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { return conv2d(layerInput, weights, null, config); } + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { + return conv2d(name, layerInput, weights, null, config); + } + + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { + return conv2d(null, layerInput, weights, bias, config); + } + /** * 2D Convolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] @@ -195,7 +203,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of conv2d op */ - public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { + public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("conv2d", "input", layerInput); validateFloatingPoint("conv2d", "weights", weights); validateFloatingPoint("conv2d", "bias", bias); @@ -204,18 +212,13 @@ public class SDCNN extends SDOps { arr[1] = weights; if (bias != null) arr[2] = bias; - return conv2d(arr, config); + return conv2d(name, arr, config); } /** - * 2D Convolution operation with optional bias - * - * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as - * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param config Conv2DConfig configuration - * @return result of convolution 2d operation + * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}. */ - public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { + public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { return conv2d(null, inputs, config); } @@ -228,7 +231,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of convolution 2d operation */ - public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { + public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { for(SDVariable v : inputs) validateNumerical("conv2d", v); SDVariable ret = f().conv2d(inputs, config); @@ -236,19 +239,26 @@ public class SDCNN extends SDOps { } /** - * Convolution 3D operation without bias - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param conv3DConfig the configuration - * @return Conv3d output variable + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. */ - public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { + public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { return conv3d(null, input, weights, null, conv3DConfig); } + /** + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. + */ + public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { + return conv3d(name, input, weights, null, conv3DConfig); + } + + /** + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}. + */ + public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { + return conv3d(null, input, weights, bias, conv3DConfig); + } + /** * Convolution 3D operation with optional bias * @@ -261,7 +271,7 @@ public class SDCNN extends SDOps { * @param conv3DConfig the configuration * @return Conv3d output variable */ - public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { + public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { validateFloatingPoint("conv3d", "input", input); validateFloatingPoint("conv3d", "weights", weights); validateFloatingPoint("conv3d", "bias", bias); @@ -276,51 +286,30 @@ public class SDCNN extends SDOps { } /** - * Convolution 3D operation with optional bias - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param conv3DConfig the configuration - * @return Conv3d output variable + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. */ - public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, bias, conv3DConfig); - } - - /** - * Convolution 3D operation without bias - * - * @param name Name of the output variable - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param conv3DConfig the configuration - * @return Conv3d output variable - */ - public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { - return conv3d(name, input, weights, null, conv3DConfig); - } - - /** - * 2D deconvolution operation without bias - * - * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. - * @param deconv2DConfig DeConv2DConfig configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { return deconv2d(layerInput, weights, null, deconv2DConfig); } + /** + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. + */ + public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { + return deconv2d(name, layerInput, weights, null, deconv2DConfig); + } + + /** + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}. + */ + public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { + return deconv2d(null, layerInput, weights, bias, deconv2DConfig); + } + /** * 2D deconvolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. @@ -328,7 +317,7 @@ public class SDCNN extends SDOps { * @param deconv2DConfig DeConv2DConfig configuration * @return result of deconv2d op */ - public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { validateFloatingPoint("deconv2d", "input", layerInput); validateFloatingPoint("deconv2d", "weights", weights); validateFloatingPoint("deconv2d", "bias", bias); @@ -337,18 +326,13 @@ public class SDCNN extends SDOps { arr[1] = weights; if (bias != null) arr[2] = bias; - return deconv2d(arr, deconv2DConfig); + return deconv2d(name, arr, deconv2DConfig); } /** - * 2D deconvolution operation with or without optional bias - * - * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) - * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} - * @param deconv2DConfig the configuration - * @return result of deconv2d op + * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}. */ - public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { return deconv2d(null, inputs, deconv2DConfig); } @@ -361,13 +345,34 @@ public class SDCNN extends SDOps { * @param deconv2DConfig the configuration * @return result of deconv2d op */ - public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { for(SDVariable v : inputs) validateNumerical("deconv2d", v); SDVariable ret = f().deconv2d(inputs, deconv2DConfig); return updateVariableNameAndReference(ret, name); } + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. + */ + public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + return deconv3d(input, weights, null, config); + } + + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. + */ + public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + return deconv3d(name, input, weights, null, config); + } + + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}. + */ + public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { + return deconv3d(null, input, weights, bias, config); + } + /** * 3D CNN deconvolution operation with or without optional bias * @@ -377,7 +382,7 @@ public class SDCNN extends SDOps { * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] * @param config Configuration */ - public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { + public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { validateFloatingPoint("conv3d", input); validateFloatingPoint("conv3d", weights); validateFloatingPoint("conv3d", bias); @@ -386,41 +391,9 @@ public class SDCNN extends SDOps { } /** - * 3D CNN deconvolution operation with or without optional bias - * - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] - * @param config Configuration + * See {@link #depthToSpace(String, SDVariable, int, String)}. */ - public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { - return deconv3d(null, input, weights, bias, config); - } - - /** - * 3D CNN deconvolution operation with no bias - * - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param config Configuration - */ - public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) { - return deconv3d(input, weights, null, config); - } - - /** - * Convolution 2d layer batch to space operation on 4d input.
- * Reduces input channels dimension by rearranging data into a larger spatial dimensions
- * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - */ - public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) { + public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { return depthToSpace(null, x, blockSize, dataFormat); } @@ -438,27 +411,36 @@ public class SDCNN extends SDOps { * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ - public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) { + public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** - * Depth-wise 2D convolution operation without bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param config Conv2DConfig configuration - * @return result of conv2d op + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. */ - public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) { + public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { return depthWiseConv2d(layerInput, depthWeights, null, config); } + /** + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { + return depthWiseConv2d(name, layerInput, depthWeights, null, config); + } + + /** + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { + return depthWiseConv2d(null, layerInput, depthWeights, bias, config); + } + /** * Depth-wise 2D convolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] @@ -466,7 +448,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of depthwise conv2d op */ - public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("depthwiseConv2d", "input", layerInput); validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); validateFloatingPoint("depthwiseConv2d", "bias", bias); @@ -475,19 +457,13 @@ public class SDCNN extends SDOps { arr[1] = depthWeights; if (bias != null) arr[2] = bias; - return depthWiseConv2d(arr, config); + return depthWiseConv2d(name, arr, config); } /** - * Depth-wise convolution 2D operation. - * - * @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights) - * or 3 elements (layerInput, depthWeights, bias) as described in - * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param depthConv2DConfig the configuration - * @return result of depthwise conv2d op + * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}. */ - public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { + public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { return depthWiseConv2d(null, inputs, depthConv2DConfig); } @@ -501,7 +477,7 @@ public class SDCNN extends SDOps { * @param depthConv2DConfig the configuration * @return result of depthwise conv2d op */ - public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { for(SDVariable v : inputs) validateFloatingPoint("depthWiseConv2d", v); SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); @@ -509,17 +485,10 @@ public class SDCNN extends SDOps { } /** - * TODO doc string - * - * @param df - * @param weights - * @param strides - * @param rates - * @param isSameMode - * @return + * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}. */ - public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { + public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, + @NonNull int[] rates, @NonNull boolean isSameMode) { return dilation2D(null, df, weights, strides, rates, isSameMode); } @@ -534,8 +503,8 @@ public class SDCNN extends SDOps { * @param isSameMode * @return */ - public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { + public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, + @NonNull int[] rates, @NonNull boolean isSameMode) { SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); return updateVariableNameAndReference(ret, name); } @@ -555,21 +524,16 @@ public class SDCNN extends SDOps { * @param sameMode If true: use same mode padding. If false * @return */ - public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { + public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); return updateVariableNameAndReference(ret, name); } /** - * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape - * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * - * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] - * @param config Convolution configuration for the im2col operation - * @return Im2Col output variable + * See {@link #im2Col(String, SDVariable, Conv2DConfig)}. */ - public SDVariable im2Col(SDVariable in, Conv2DConfig config) { + public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) { return im2Col(null, in, config); } @@ -582,20 +546,16 @@ public class SDCNN extends SDOps { * @param config Convolution configuration for the im2col operation * @return Im2Col output variable */ - public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) { + public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { SDVariable ret = f().im2Col(in, config); return updateVariableNameAndReference(ret, name); } /** - * 2D convolution layer operation - local response normalization - * - * @param inputs the inputs to lrn - * @param lrnConfig the configuration - * @return + * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}. */ - public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) { + public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) { return localResponseNormalization(null, inputs, lrnConfig); } @@ -607,8 +567,8 @@ public class SDCNN extends SDOps { * @param lrnConfig the configuration * @return */ - public SDVariable localResponseNormalization(String name, SDVariable input, - LocalResponseNormalizationConfig lrnConfig) { + public SDVariable localResponseNormalization(String name, @NonNull SDVariable input, + @NonNull LocalResponseNormalizationConfig lrnConfig) { validateFloatingPoint("local response normalization", input); SDVariable ret = f().localResponseNormalization(input, lrnConfig); return updateVariableNameAndReference(ret, name); @@ -616,14 +576,9 @@ public class SDCNN extends SDOps { /** - * 2D Convolution layer operation - max pooling 2d - * - * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying max pooling on the input + * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}. */ - public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { return maxPooling2d(null, input, pooling2DConfig); } @@ -636,22 +591,16 @@ public class SDCNN extends SDOps { * @param pooling2DConfig the configuration * @return Result after applying max pooling on the input */ - public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { validateNumerical("maxPooling2d", input); SDVariable ret = f().maxPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** - * 3D convolution layer operation - max pooling 3d operation. - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying max pooling on the input + * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}. */ - public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { return maxPooling3d(null, input, pooling3DConfig); } @@ -665,7 +614,7 @@ public class SDCNN extends SDOps { * @param pooling3DConfig the configuration * @return Result after applying max pooling on the input */ - public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { validateNumerical("maxPooling3d", input); SDVariable ret = f().maxPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); @@ -673,21 +622,30 @@ public class SDCNN extends SDOps { /** - * Separable 2D convolution operation without bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels] - * May be null - * @param config Conv2DConfig configuration - * @return result of separable convolution 2d operation + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. */ - public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, - Conv2DConfig config) { + public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + @NonNull Conv2DConfig config) { return separableConv2d(layerInput, depthWeights, pointWeights, null, config); } + + /** + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + @NonNull Conv2DConfig config) { + return separableConv2d(layerInput, depthWeights, pointWeights, null, config); + } + + /** + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + SDVariable bias, @NonNull Conv2DConfig config) { + return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config); + } + /** * Separable 2D convolution operation with optional bias * @@ -700,8 +658,8 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of separable convolution 2d operation */ - public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, Conv2DConfig config) { + public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("separableConv2d", "input", layerInput); validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); @@ -712,18 +670,13 @@ public class SDCNN extends SDOps { arr[2] = pointWeights; if (bias != null) arr[3] = bias; - return sconv2d(arr, config); + return sconv2d(name, arr, config); } /** - * Separable 2D convolution operation with/without optional bias - * - * @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights) - * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param conv2DConfig the configuration - * @return result of separable convolution 2d operation + * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}. */ - public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { + public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { return sconv2d(null, inputs, conv2DConfig); } @@ -736,7 +689,7 @@ public class SDCNN extends SDOps { * @param conv2DConfig the configuration * @return result of separable convolution 2d operation */ - public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { + public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { for(SDVariable v : inputs) validateFloatingPoint("sconv2d", v); SDVariable ret = f().sconv2d(inputs, conv2DConfig); @@ -747,7 +700,7 @@ public class SDCNN extends SDOps { /** * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ - public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) { + public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { return spaceToBatch(null, x, blocks, padding); } @@ -762,7 +715,7 @@ public class SDCNN extends SDOps { * @return Output variable * @see #batchToSpace(String, SDVariable, int[], int[][]) */ - public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) { + public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { SDVariable ret = f().spaceToBatch(x, blocks, padding); return updateVariableNameAndReference(ret, name); } @@ -770,7 +723,7 @@ public class SDCNN extends SDOps { /** * @see #spaceToDepth(String, SDVariable, int, String) */ - public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) { + public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { return spaceToDepth(null, x, blockSize, dataFormat); } @@ -788,23 +741,39 @@ public class SDCNN extends SDOps { * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ - public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) { + public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** - * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, + * scale is used for both height and width dimensions. * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * @param scale Scale to upsample in both H and W dimensions - * @return Upsampled input + * @param scale The scale for both height and width dimensions. */ - public SDVariable upsampling2d(SDVariable input, int scale) { + public SDVariable upsampling2d(@NonNull SDVariable input, int scale) { return upsampling2d(null, input, true, scale, scale); } + /** + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, + * scale is used for both height and width dimensions. + * + * @param scale The scale for both height and width dimensions. + */ + public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) { + return upsampling2d(name, input, true, scale, scale); + } + + /** + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}. + */ + public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { + return upsampling2d(null, input, nchw, scaleH, scaleW); + } + /** * 2D Convolution layer operation - Upsampling 2d * @@ -814,33 +783,8 @@ public class SDCNN extends SDOps { * @param scaleW Scale to upsample in width dimension * @return Upsampled input */ - public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) { + public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); return updateVariableNameAndReference(ret, name); } - - /** - * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. - * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * @param scale Scale to upsample in both H and W dimensions - * @return Upsampled input - */ - public SDVariable upsampling2d(String name, SDVariable input, int scale) { - return upsampling2d(name, input, true, scale, scale); - } - - /** - * 2D Convolution layer operation - Upsampling 2d - * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * or NHWC format (shape [minibatch, height, width, channels]) - * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format - * @param scaleH Scale to upsample in height dimension - * @param scaleW Scale to upsample in width dimension - * @return Upsampled input - */ - public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { - return upsampling2d(null, input, nchw, scaleH, scaleW); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 6faf29bfc..cce38cf24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; -import org.nd4j.graph.DataType; +import org.nd4j.graph.DType; import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatProperties; @@ -66,33 +66,33 @@ public class FlatBuffersMapper { public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) { switch (type) { case FLOAT: - return DataType.FLOAT; + return DType.FLOAT; case DOUBLE: - return DataType.DOUBLE; + return DType.DOUBLE; case HALF: - return DataType.HALF; + return DType.HALF; case INT: - return DataType.INT32; + return DType.INT32; case LONG: - return DataType.INT64; + return DType.INT64; case BOOL: - return DataType.BOOL; + return DType.BOOL; case SHORT: - return DataType.INT16; + return DType.INT16; case BYTE: - return DataType.INT8; + return DType.INT8; case UBYTE: - return DataType.UINT8; + return DType.UINT8; case UTF8: - return DataType.UTF8; + return DType.UTF8; case UINT16: - return DataType.UINT16; + return DType.UINT16; case UINT32: - return DataType.UINT32; + return DType.UINT32; case UINT64: - return DataType.UINT64; + return DType.UINT64; case BFLOAT16: - return DataType.BFLOAT16; + return DType.BFLOAT16; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } @@ -102,33 +102,33 @@ public class FlatBuffersMapper { * This method converts enums for DataType */ public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) { - if (val == DataType.FLOAT) { + if (val == DType.FLOAT) { return org.nd4j.linalg.api.buffer.DataType.FLOAT; - } else if (val == DataType.DOUBLE) { + } else if (val == DType.DOUBLE) { return org.nd4j.linalg.api.buffer.DataType.DOUBLE; - } else if (val == DataType.HALF) { + } else if (val == DType.HALF) { return org.nd4j.linalg.api.buffer.DataType.HALF; - } else if (val == DataType.INT32) { + } else if (val == DType.INT32) { return org.nd4j.linalg.api.buffer.DataType.INT; - } else if (val == DataType.INT64) { + } else if (val == DType.INT64) { return org.nd4j.linalg.api.buffer.DataType.LONG; - } else if (val == DataType.INT8) { + } else if (val == DType.INT8) { return org.nd4j.linalg.api.buffer.DataType.BYTE; - } else if (val == DataType.BOOL) { + } else if (val == DType.BOOL) { return org.nd4j.linalg.api.buffer.DataType.BOOL; - } else if (val == DataType.UINT8) { + } else if (val == DType.UINT8) { return org.nd4j.linalg.api.buffer.DataType.UBYTE; - } else if (val == DataType.INT16) { + } else if (val == DType.INT16) { return org.nd4j.linalg.api.buffer.DataType.SHORT; - } else if (val == DataType.UTF8) { + } else if (val == DType.UTF8) { return org.nd4j.linalg.api.buffer.DataType.UTF8; - } else if (val == DataType.UINT16) { + } else if (val == DType.UINT16) { return org.nd4j.linalg.api.buffer.DataType.UINT16; - } else if (val == DataType.UINT32) { + } else if (val == DType.UINT32) { return org.nd4j.linalg.api.buffer.DataType.UINT32; - } else if (val == DataType.UINT64) { + } else if (val == DType.UINT64) { return org.nd4j.linalg.api.buffer.DataType.UINT64; - } else if (val == DataType.BFLOAT16){ + } else if (val == DType.BFLOAT16){ return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; } else { throw new RuntimeException("Unknown datatype: " + val); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java similarity index 95% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java index 17a0752f0..2617ce8f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package org.nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5bfba7a48..19b534a97 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -353,6 +353,10 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index ac642872c..771b74615 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); } - @Override - public void setShape(long[] shape) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty())); - } - - @Override - public void setStride(long[] stride) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty())); - } - @Override public void setShapeAndStride(int[] shape, int[] stride) { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); @@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable { return scalar.getDouble(0); } - /** - * Returns entropy value for this INDArray - * @return - */ @Override public Number entropyNumber() { return entropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns non-normalized Shannon entropy value for this INDArray - * @return - */ @Override public Number shannonEntropyNumber() { return shannonEntropy(Integer.MAX_VALUE).getDouble(0); } - - /** - * Returns log entropy value for this INDArray - * @return - */ @Override public Number logEntropyNumber() { return logEntropy(Integer.MAX_VALUE).getDouble(0); @@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return size(0); } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - Nd4j.getCompressor().autoDecompress(this); - int n = shape.length; - - // FIXME: shapeInfo should be used here - if (shape.length < 1) - return create(Nd4j.createBufferDetached(shape)); - if (offsets.length != n) - throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets)); - if (stride.length != n) - throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride)); - - if (Shape.contentEquals(shape, shapeOf())) { - if (ArrayUtil.isZero(offsets)) { - return this; - } else { - throw new IllegalArgumentException("Invalid subArray offsets"); - } - } - - long[] dotProductOffsets = offsets; - int[] dotProductStride = stride; - - long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets); - if (offset >= data().length()) - offset = ArrayUtil.sumLong(offsets); - - return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering()); - } - protected INDArray create(DataBuffer buffer) { return Nd4j.create(buffer); } @@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new AMin(this, dimension)); } - /** - * Returns the sum along the specified dimension(s) of this ndarray - * - * @param dimension the dimension to getScalar the sum along - * @return the sum along the specified dimension of this ndarray - */ @Override public INDArray sum(int... dimension) { validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, dimension)); } - /** - * Returns the sum along the last dimension of this ndarray - * - * @param dimension the dimension to getScalar the sum along - * @return the sum along the specified dimension of this ndarray - */ @Override public INDArray sum(boolean keepDim, int... dimension) { validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); } - - /** - * Returns entropy along dimension - * @param dimension - * @return - */ @Override public INDArray entropy(int... dimension) { validateNumericalArray("entropy", false); return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); } - /** - * Returns non-normalized Shannon entropy along dimension - * @param dimension - * @return - */ @Override public INDArray shannonEntropy(int... dimension) { validateNumericalArray("shannonEntropy", false); return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); } - /** - * Returns log entropy along dimension - * @param dimension - * @return - */ @Override public INDArray logEntropy(int... dimension) { validateNumericalArray("logEntropy", false); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 1e0772494..11a005f91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { throw new UnsupportedOperationException(); } - @Override - public void setStride(long... stride) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShape(long... shape) { - throw new UnsupportedOperationException(); - } - @Override public INDArray putScalar(long row, long col, double value) { return null; @@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { @Override public void setShapeAndStride(int[] shape, int[] stride) { - } @Override public void setOrder(char order) { - - } - - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - return null; } @Override @@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - /** - * Returns entropy value for this INDArray - * @return - */ @Override public Number entropyNumber() { return entropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns non-normalized Shannon entropy value for this INDArray - * @return - */ @Override public Number shannonEntropyNumber() { return shannonEntropy(Integer.MAX_VALUE).getDouble(0); } - - /** - * Returns log entropy value for this INDArray - * @return - */ @Override public Number logEntropyNumber() { return logEntropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns entropy along dimension - * @param dimension - * @return - */ @Override public INDArray entropy(int... dimension) { return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); } - /** - * Returns non-normalized Shannon entropy along dimension - * @param dimension - * @return - */ @Override public INDArray shannonEntropy(int... dimension) { return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index 116a4b4f7..85a7ec5ce 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return extendedFlags; } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - throw new UnsupportedOperationException(); - } - - - /** * Returns the underlying indices of the element of the given index * such as there really are in the original ndarray @@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return null; } - @Override - public void setStride(long... stride) { - - } - - @Override - public void setShape(long... shape) { - - } - /** * This method returns true if this INDArray is special case: no-value INDArray * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java index 92e59486c..cf2f9fe3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java @@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray { return shapeInformation; } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - throw new UnsupportedOperationException(); - } - @Override public boolean equals(Object o) { //TODO use op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 47e259b94..9288b6d51 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns entropy value for this INDArray - * @return + * @return entropy value */ Number entropyNumber(); /** * Returns non-normalized Shannon entropy value for this INDArray - * @return + * @return non-normalized Shannon entropy */ Number shannonEntropyNumber(); /** * Returns log entropy value for this INDArray - * @return + * @return log entropy value */ Number logEntropyNumber(); /** * Returns entropy value for this INDArray along specified dimension(s) - * @return + * @param dimension specified dimension(s) + * @return entropy value */ INDArray entropy(int... dimension); /** - * Returns entropy value for this INDArray along specified dimension(s) - * @return + * Returns Shannon entropy value for this INDArray along specified dimension(s) + * @param dimension specified dimension(s) + * @return Shannon entropy */ INDArray shannonEntropy(int... dimension); /** - * Returns entropy value for this INDArray along specified dimension(s) - * @return + * Returns log entropy value for this INDArray along specified dimension(s) + * @param dimension specified dimension(s) + * @return log entropy value */ INDArray logEntropy(int... dimension); - - /** - * stride setter - * @param stride - * @deprecated, use {@link #reshape(int...) } - */ - @Deprecated - void setStride(long... stride); - - /** - * Shape setter - * @param shape - * @deprecated, use {@link #reshape(int...) } - */ - - - @Deprecated - void setShape(long... shape); - /** * Shape and stride setter - * @param shape - * @param stride + * @param shape new value for shape + * @param stride new value for stride */ void setShapeAndStride(int[] shape, int[] stride); @@ -1919,15 +1903,7 @@ public interface INDArray extends Serializable, AutoCloseable { * @param order the ordering to set */ void setOrder(char order); - - /** - * @param offsets - * @param shape - * @param stride - * @return - */ - INDArray subArray(long[] offsets, int[] shape, int[] stride); - + /** * Returns the elements at the specified indices * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java index ac13c6224..2f295cc6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp { } - @Builder(builderMethodName = "builder") - public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { - super(null, sameDiff, new SDVariable[]{input}, false); - if (arrayInput != null) { - addInputArgument(arrayInput); - } - if (arrayOutput != null) { - addOutputArgument(arrayOutput); - } + @Builder(builderMethodName = "sameDiffBuilder") + public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { + super(sameDiff, new SDVariable[]{input}); config.setType(Pooling2D.Pooling2DType.AVG); + this.config = config; + addArgs(); + } + + public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + config.setType(Pooling2D.Pooling2DType.AVG); - this.sameDiff = sameDiff; this.config = config; addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 5ae2ac144..2fc814fb3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp { protected Conv1DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public Conv1D(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv1DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + initConfig(config); + } + + public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){ + super(inputs, outputs); + + initConfig(config); + } + + public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + private void initConfig(Conv1DConfig config){ this.config = config; Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputFunctions, this); } protected void addArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 04db5874c..5e077e3fc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp { protected Conv2DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + + initConfig(config); + } + + public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs); + + initConfig(config); + } + + public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + protected void initConfig(Conv2DConfig config){ this.config = config; Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, - INVALID_CONFIGURATION, - config.getSH(), config.getPH(), config.getDW()); + INVALID_CONFIGURATION, + config.getSH(), config.getPH(), config.getDW()); addArgs(); - if(sameDiff != null) { - sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point - sameDiff.addArgsFor(inputFunctions, this); - } } protected void addArgs() { @@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp { Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder() .sameDiff(sameDiff) .config(config) - .outputs(outputArguments()) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .build(); List ret = Arrays.asList(conv2DDerivative.outputVariables()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java index 8ccbd84eb..cd5ab6556 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java @@ -37,8 +37,8 @@ import java.util.List; public class Conv2DDerivative extends Conv2D { @Builder(builderMethodName = "derivativeBuilder") - public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(sameDiff, inputFunctions, inputArrays, outputs, config); + public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) { + super(sameDiff, inputFunctions, config); } public Conv2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 810974103..8c4e40e8a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp { public Conv3D() { } - @Builder(builderMethodName = "builder") - public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, - Conv3DConfig conv3DConfig) { - super(null, sameDiff, inputFunctions, false); - setSameDiff(sameDiff); + @Builder(builderMethodName = "sameDiffBuilder") + public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) { + super(sameDiff, inputFunctions); + initConfig(config); + } - if (inputs != null) - addInputArgument(inputs); - if (outputs != null) - addOutputArgument(outputs); - this.config = conv3DConfig; + public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){ + super(inputs, outputs); + initConfig(config); + } + + public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + private void initConfig(Conv3DConfig config){ + this.config = config; Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, - INVALID_CONFIGURATION, - config.getSW(), config.getPH(), config.getDW()); + INVALID_CONFIGURATION, + config.getSW(), config.getPH(), config.getDW()); addArgs(); - - - //for (val arg: iArgs()) - // System.out.println(getIArgument(arg)); } @@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp { inputs.add(f1.get(0)); Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder() .conv3DConfig(config) - .inputFunctions(args()) - .outputs(outputArguments()) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .sameDiff(sameDiff) .build(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java index ee34fca90..ea6312094 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java @@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D { public Conv3DDerivative() {} @Builder(builderMethodName = "derivativeBuilder") - public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) { - super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig); + public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) { + super(sameDiff, inputFunctions, conv3DConfig); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index 65c0fccc3..c69292dd9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp { protected DeConv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2D(SameDiff sameDiff, SDVariable[] inputs, - INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputs); this.config = config; - if (inputArrays != null) { - addInputArgument(inputArrays); - } - if (outputs != null) { - addOutputArgument(outputs); - } - addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputs, this); + } + + public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java index 174d95ed7..04dc1dd2d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java @@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D { public DeConv2DDerivative() {} @Builder(builderMethodName = "derivativeBuilder") - public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(sameDiff, inputs, inputArrays, outputs, config); + public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) { + super(sameDiff, inputs, config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java index 085f48365..bc4f996b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java @@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp { protected DeConv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2DTF(SameDiff sameDiff, SDVariable[] inputs, - INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputs); + + this.config = config; + addArgs(); + } + + public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){ + super(inputs, outputs); + this.config = config; - - if (inputArrays != null) { - addInputArgument(inputArrays); - } - if (outputs != null) { - addOutputArgument(outputs); - } - addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputs, this); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 20b28da5a..077f6a64b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp { protected DeConv3DConfig config; - public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) { + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { super(sameDiff, toArr(input, weights, bias)); this.config = config; addArgs(); } + public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ if(bias != null){ return new SDVariable[]{input, weights, bias}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index 92a39f188..ec2bb1d3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DepthwiseConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + this.config = config; addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point - sameDiff.addArgsFor(inputFunctions, this); + } + + public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } public DepthwiseConv2D() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index 421598d13..8dfb7131a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp { protected LocalResponseNormalizationConfig config; - @Builder(builderMethodName = "builder") - public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputs, INDArray[] outputs,boolean inPlace, + @Builder(builderMethodName = "sameDiffBuilder") + public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) { super(null,sameDiff, inputFunctions, inPlace); + + this.config = config; + addArgs(); + } + + public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + this.config = config; - if(inputs != null) { - addInputArgument(inputs); - } - if(outputs!= null) { - addOutputArgument(outputs); - } addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java index 2159f87fa..c2e6aad15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java @@ -33,8 +33,8 @@ import java.util.List; @Slf4j public class LocalResponseNormalizationDerivative extends LocalResponseNormalization { @Builder(builderMethodName = "derivativeBuilder") - public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) { - super(sameDiff, inputFunctions, inputs, outputs, inPlace, config); + public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) { + super(sameDiff, inputFunctions, inPlace, config); } public LocalResponseNormalizationDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index b321334a5..09e928d2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp { public MaxPooling2D() { } - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") @SuppressWarnings("Used in lombok") - public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { + public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { super(null, sameDiff, new SDVariable[]{input}, false); - if (arrayInput != null) { - addInputArgument(arrayInput); - } - if (arrayOutput != null) { - addOutputArgument(arrayOutput); - } config.setType(Pooling2D.Pooling2DType.MAX); - this.config = config; - this.sameDiff = sameDiff; - addArgs(); } public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ - super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output}); + super(null, new INDArray[]{input}, wrapOrNull(output)); config.setType(Pooling2D.Pooling2DType.MAX); this.config = config; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java index c45d106e7..ab2984969 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java @@ -16,8 +16,14 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.lang.reflect.Field; -import java.util.*; - /** * Pooling2D operation @@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp { public Pooling2D() {} - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") @SuppressWarnings("Used in lombok") - public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) { - super(null,sameDiff, inputs, false); - if(arrayInputs != null) { - addInputArgument(arrayInputs); - } + public Pooling2D(SameDiff sameDiff, SDVariable[] inputs, + Pooling2DConfig config) { + super(null, sameDiff, inputs, false); - if(arrayOutputs != null) { - addOutputArgument(arrayOutputs); - } + this.config = config; + addArgs(); + } - this.config = config; + public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){ + super(inputs, outputs); + this.config = config; + addArgs(); + } + public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + + this.config = config; addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java index 6fdb40215..aa58603e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -36,8 +37,12 @@ import java.util.List; @Slf4j public class Pooling2DDerivative extends Pooling2D { @Builder(builderMethodName = "derivativeBuilder") - public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) { - super(sameDiff, inputs, arrayInputs, arrayOutputs, config); + public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) { + super(sameDiff, inputs, config); + } + + public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){ + super(new INDArray[]{input, grad}, wrapOrNull(output), config); } public Pooling2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index 745caccba..d4ef84e88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -39,9 +40,17 @@ import java.util.List; @Slf4j public class SConv2D extends Conv2D { - @Builder(builderMethodName = "sBuilder") - public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { - super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); + @Builder(builderMethodName = "sameDiffSBuilder") + public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) { + super(sameDiff, inputFunctions, conv2DConfig); + } + + public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs, config); + } + + public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config); } public SConv2D() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java index e25dae144..a30a58d95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java @@ -38,8 +38,8 @@ import java.util.List; public class SConv2DDerivative extends SConv2D { @Builder(builderMethodName = "sDerviativeBuilder") - public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { - super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); + public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) { + super(sameDiff, inputFunctions, conv2DConfig); } public SConv2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java new file mode 100644 index 000000000..1fa749830 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java @@ -0,0 +1,37 @@ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class BitsHammingDistance extends DynamicCustomOp { + + public BitsHammingDistance(){ } + + public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){ + super(sd, new SDVariable[]{x, y}); + } + + public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "bits_hamming_distance"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes); + Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes); + return Collections.singletonList(DataType.LONG); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java new file mode 100644 index 000000000..d81a72c1f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise AND operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseAnd extends BaseDynamicTransformOp { + + public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseAnd(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseAnd(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseAnd() {} + + @Override + public String opName() { + return "bitwise_and"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_and"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java new file mode 100644 index 000000000..85dd5c31b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise OR operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseOr extends BaseDynamicTransformOp { + + public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseOr(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseOr(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseOr() {} + + @Override + public String opName() { + return "bitwise_or"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_or"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java new file mode 100644 index 000000000..136ca9b62 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise XOR operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseXor extends BaseDynamicTransformOp { + + public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseXor(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseXor(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseXor() {} + + @Override + public String opName() { + return "bitwise_xor"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_xor"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java index 3a9173654..a8b4ebbb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java index 20b6f6955..ea7ae1715 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 4435615f5..3cc03d12b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index 5501324f2..a9eebb14e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index cab411916..b31e6e036 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -235,24 +235,20 @@ public class Convolution { public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, double extra, int virtualHeight, int virtualWidth, INDArray out) { - Pooling2D pooling = Pooling2D.builder() - .arrayInputs(new INDArray[]{img}) - .arrayOutputs(new INDArray[]{out}) - .config(Pooling2DConfig.builder() - .dH(dh) - .dW(dw) - .extra(extra) - .kH(kh) - .kW(kw) - .pH(ph) - .pW(pw) - .isSameMode(isSameMode) - .sH(sy) - .sW(sx) - .type(type) - .divisor(divisor) - .build()) - .build(); + Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder() + .dH(dh) + .dW(dw) + .extra(extra) + .kH(kh) + .kW(kw) + .pH(ph) + .pW(pw) + .isSameMode(isSameMode) + .sH(sy) + .sW(sx) + .type(type) + .divisor(divisor) + .build()); Nd4j.getExecutioner().execAndReturn(pooling); return out; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java index c21993548..40aa692eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java @@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex { return offset(arr.stride(), Indices.offsets(arr.shape(), indices)); } - /** - * Set the shape and stride for - * new axes based dimensions - * @param arr the array to update - * the shape/strides for - * @param indexes the indexes to update based on - */ - public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) { - int numNewAxes = NDArrayIndex.numNewAxis(indexes); - if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) { - List newShape = new ArrayList<>(); - List newStrides = new ArrayList<>(); - int currDimension = 0; - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] instanceof NewAxis) { - newShape.add(1L); - newStrides.add(0L); - } else { - newShape.add(arr.size(currDimension)); - newStrides.add(arr.size(currDimension)); - currDimension++; - } - } - - while (currDimension < arr.rank()) { - newShape.add((long) currDimension); - newStrides.add((long) currDimension); - currDimension++; - } - - long[] newShapeArr = Longs.toArray(newShape); - long[] newStrideArr = Longs.toArray(newStrides); - - // FIXME: this is wrong, it breaks shapeInfo immutability - arr.setShape(newShapeArr); - arr.setStride(newStrideArr); - - - } else { - if (numNewAxes > 0) { - long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape()); - long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride()); - arr.setShape(newShape); - arr.setStride(newStrides); - } - } - - } - - - /** * Compute the offset given an array of offsets. * The offset is computed(for both fortran an d c ordering) as: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index 26d850366..30c68d578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -54,7 +54,7 @@ public class DeallocatorService { deallocatorThreads = new Thread[numThreads]; queues = new ReferenceQueue[numThreads]; for (int e = 0; e < numThreads; e++) { - log.debug("Starting deallocator thread {}", e + 1); + log.trace("Starting deallocator thread {}", e + 1); queues[e] = new ReferenceQueue<>(); int deviceId = e % numDevices; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 576cea78a..e694587b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1151,4 +1151,6 @@ public interface NativeOps { int lastErrorCode(); String lastErrorMessage(); + + boolean isBlasVersionMatches(int major, int minor, int build); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index 98bdb90fa..ae31ea7b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -101,7 +101,7 @@ public class NativeOpsHolder { } //deviceNativeOps.setOmpNumThreads(4); - log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads()); + log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); } catch (Exception | Error e) { throw new RuntimeException( "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html", diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java index 23abf1d40..5de827d1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java @@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas { numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors()); setMaxThreads(numThreads); } - log.info("Number of threads used for BLAS: {}", getMaxThreads()); + + log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java index 28910ae6a..a6a5a45e4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java @@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend { throw new RuntimeException("No CUDA devices were found in system"); } Loader.load(org.bytedeco.cuda.global.cublas.class); + return true; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 9e9dc34b2..73daa679d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); +/* + val major = new int[1]; + val minor = new int[1]; + val build = new int[1]; + org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major); + org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor); + org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build); + + val pew = new int[100]; + org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew); + + nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + */ } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index b06211545..bd6817f62 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.blas.impl.BaseLevel3; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.factory.DataTypeValidation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.CublasPointer; @@ -113,8 +115,13 @@ public class JcublasLevel3 extends BaseLevel3 { @Override protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) { - //A = Shape.toOffsetZero(A); - //B = Shape.toOffsetZero(B); + /* + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + val handle = ctx.getCublasHandle(); + synchronized (handle) { + Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build())); + } + */ Nd4j.getExecutioner().push(); @@ -141,6 +148,7 @@ public class JcublasLevel3 extends BaseLevel3 { } allocator.registerAction(ctx, C, A, B); + OpExecutionerUtil.checkForAny(C); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index f3080f05a..0ddcb6266 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public Environment(Pointer p) { super(p); } + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); + public static native Environment getInstance(); public native @Cast("bool") boolean isVerbose(); @@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads); public native void setOmpMinThreads(int threads); - +public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** * diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 9554a94e9..dabac7001 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public Environment(Pointer p) { super(p); } + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); + public static native Environment getInstance(); public native @Cast("bool") boolean isVerbose(); @@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads); public native void setOmpMinThreads(int threads); - +public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** * @@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_and) + @Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_and position(long position) { + return (bitwise_and)super.position(position); + } + + public bitwise_and() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + + /** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_or) + @Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_or position(long position) { + return (bitwise_or)super.position(position); + } + + public bitwise_or() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + + /** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_xor) + @Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_xor position(long position) { + return (bitwise_xor)super.position(position); + } + + public bitwise_xor() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + /** * This operation returns hamming distance based on bits * diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 057f610bd..539901a41 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); INDArray input = Nd4j.create(inSize); - AvgPooling2D avgPooling2D = AvgPooling2D.builder() - .arrayInput(input) - .config(conf) - .build(); + AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); @@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation { //Test backprop: - Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() - .arrayInputs(new INDArray[]{input, grad}) - .config(conf) - .build(); + Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); assertEquals(1, outSizesBP.size()); @@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); INDArray input = Nd4j.create(inSize); - AvgPooling2D avgPooling2D = AvgPooling2D.builder() - .arrayInput(input) - .config(conf) - .build(); + AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); assertEquals(1, outSizes.size()); @@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation { INDArray grad = Nd4j.create(exp); //Test backprop: - Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() - .arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output) - .arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input - .config(conf) - .build(); + Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); assertEquals(1, outSizesBP.size()); @@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().conv2d(vars, c); + SDVariable out = sd.cnn().conv2d("conv", vars, c); out = sd.nn().tanh("out", out); INDArray outArr = sd.execAndEndResult();