Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
AlexDBlack 2019-09-05 00:53:49 +10:00
commit b7226bdd7a
91 changed files with 1622 additions and 864 deletions

View File

@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
break; break;
} }
Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder() Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf);
.config(conf)
.arrayInputs(new INDArray[]{input, epsilon})
.arrayOutputs(new INDArray[]{gradAtInput})
.build();
Nd4j.exec(d); Nd4j.exec(d);
return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput); return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput);

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_BLASVERSIONHELPER_H
#define SAMEDIFF_BLASVERSIONHELPER_H
#include <dll.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace nd4j {
class ND4J_EXPORT BlasVersionHelper {
public:
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
BlasVersionHelper();
~BlasVersionHelper() = default;
};
}
#endif //DEV_TESTS_BLASVERSIONHELPER_H

View File

@ -253,20 +253,20 @@ if(CUDA_BLAS)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS) if (NOT BUILD_TESTS)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
else() else()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
endif() endif()

View File

@ -35,7 +35,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "BlasVersionHelper.h"
#endif #endif
namespace nd4j { namespace nd4j {
@ -66,6 +66,13 @@ namespace nd4j {
#endif #endif
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
BlasVersionHelper ver;
_blasMajorVersion = ver._blasMajorVersion;
_blasMinorVersion = ver._blasMinorVersion;
_blasPatchVersion = ver._blasPatchVersion;
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
fflush(stdout);
int devCnt = 0; int devCnt = 0;
cudaGetDeviceCount(&devCnt); cudaGetDeviceCount(&devCnt);
auto devProperties = new cudaDeviceProp[devCnt]; auto devProperties = new cudaDeviceProp[devCnt];

View File

@ -56,6 +56,13 @@ namespace nd4j{
Environment(); Environment();
~Environment(); ~Environment();
public: public:
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
static Environment* getInstance(); static Environment* getInstance();
bool isVerbose(); bool isVerbose();

View File

@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
ND4J_EXPORT void setOmpMinThreads(int threads); ND4J_EXPORT void setOmpMinThreads(int threads);
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
/** /**
* *

View File

@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
auto fName = builder.CreateString(*(var->getName())); auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index()); auto id = CreateIntPair(builder, var->id(), var->index());
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray); auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
variables_vector.push_back(fv); variables_vector.push_back(fv);
arrays++; arrays++;

View File

@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
} }
} }
bool isBlasVersionMatches(int major, int minor, int build) {
return true;
}
/** /**
* *
* @param opNum * @param opNum

View File

@ -0,0 +1,29 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../BlasVersionHelper.h"
namespace nd4j {
BlasVersionHelper::BlasVersionHelper() {
_blasMajorVersion = __CUDACC_VER_MAJOR__;
_blasMinorVersion = __CUDACC_VER_MINOR__;
_blasPatchVersion = __CUDACC_VER_BUILD__;
}
}

View File

@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
delete ptr; delete ptr;
} }
bool isBlasVersionMatches(int major, int minor, int build) {
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
if (!result) {
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
}
return result;
}
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
} }

View File

@ -38,7 +38,7 @@ namespace nd4j {
public: public:
static int asInt(DataType type); static int asInt(DataType type);
static DataType fromInt(int dtype); static DataType fromInt(int dtype);
static DataType fromFlatDataType(nd4j::graph::DataType dtype); static DataType fromFlatDataType(nd4j::graph::DType dtype);
FORCEINLINE static std::string asString(DataType dataType); FORCEINLINE static std::string asString(DataType dataType);
template <typename T> template <typename T>

View File

@ -27,7 +27,7 @@ namespace nd4j {
return (DataType) val; return (DataType) val;
} }
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) { DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
return (DataType) dtype; return (DataType) dtype;
} }

View File

@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
return EnumNamesByteOrder()[index]; return EnumNamesByteOrder()[index];
} }
enum DataType { enum DType {
DataType_INHERIT = 0, DType_INHERIT = 0,
DataType_BOOL = 1, DType_BOOL = 1,
DataType_FLOAT8 = 2, DType_FLOAT8 = 2,
DataType_HALF = 3, DType_HALF = 3,
DataType_HALF2 = 4, DType_HALF2 = 4,
DataType_FLOAT = 5, DType_FLOAT = 5,
DataType_DOUBLE = 6, DType_DOUBLE = 6,
DataType_INT8 = 7, DType_INT8 = 7,
DataType_INT16 = 8, DType_INT16 = 8,
DataType_INT32 = 9, DType_INT32 = 9,
DataType_INT64 = 10, DType_INT64 = 10,
DataType_UINT8 = 11, DType_UINT8 = 11,
DataType_UINT16 = 12, DType_UINT16 = 12,
DataType_UINT32 = 13, DType_UINT32 = 13,
DataType_UINT64 = 14, DType_UINT64 = 14,
DataType_QINT8 = 15, DType_QINT8 = 15,
DataType_QINT16 = 16, DType_QINT16 = 16,
DataType_BFLOAT16 = 17, DType_BFLOAT16 = 17,
DataType_UTF8 = 50, DType_UTF8 = 50,
DataType_MIN = DataType_INHERIT, DType_MIN = DType_INHERIT,
DataType_MAX = DataType_UTF8 DType_MAX = DType_UTF8
}; };
inline const DataType (&EnumValuesDataType())[19] { inline const DType (&EnumValuesDType())[19] {
static const DataType values[] = { static const DType values[] = {
DataType_INHERIT, DType_INHERIT,
DataType_BOOL, DType_BOOL,
DataType_FLOAT8, DType_FLOAT8,
DataType_HALF, DType_HALF,
DataType_HALF2, DType_HALF2,
DataType_FLOAT, DType_FLOAT,
DataType_DOUBLE, DType_DOUBLE,
DataType_INT8, DType_INT8,
DataType_INT16, DType_INT16,
DataType_INT32, DType_INT32,
DataType_INT64, DType_INT64,
DataType_UINT8, DType_UINT8,
DataType_UINT16, DType_UINT16,
DataType_UINT32, DType_UINT32,
DataType_UINT64, DType_UINT64,
DataType_QINT8, DType_QINT8,
DataType_QINT16, DType_QINT16,
DataType_BFLOAT16, DType_BFLOAT16,
DataType_UTF8 DType_UTF8
}; };
return values; return values;
} }
inline const char * const *EnumNamesDataType() { inline const char * const *EnumNamesDType() {
static const char * const names[] = { static const char * const names[] = {
"INHERIT", "INHERIT",
"BOOL", "BOOL",
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
return names; return names;
} }
inline const char *EnumNameDataType(DataType e) { inline const char *EnumNameDType(DType e) {
const size_t index = static_cast<int>(e); const size_t index = static_cast<int>(e);
return EnumNamesDataType()[index]; return EnumNamesDType()[index];
} }
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<int8_t> *buffer() const { const flatbuffers::Vector<int8_t> *buffer() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER); return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
} }
DataType dtype() const { DType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0)); return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
} }
ByteOrder byteOrder() const { ByteOrder byteOrder() const {
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0)); return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) { void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer); fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
} }
void add_dtype(DataType dtype) { void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0); fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
} }
void add_byteOrder(ByteOrder byteOrder) { void add_byteOrder(ByteOrder byteOrder) {
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0, flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0, flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) { ByteOrder byteOrder = ByteOrder_LE) {
FlatArrayBuilder builder_(_fbb); FlatArrayBuilder builder_(_fbb);
builder_.add_buffer(buffer); builder_.add_buffer(buffer);
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<int64_t> *shape = nullptr, const std::vector<int64_t> *shape = nullptr,
const std::vector<int8_t> *buffer = nullptr, const std::vector<int8_t> *buffer = nullptr,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) { ByteOrder byteOrder = ByteOrder_LE) {
return nd4j::graph::CreateFlatArray( return nd4j::graph::CreateFlatArray(
_fbb, _fbb,

View File

@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
/** /**
* @enum * @enum
*/ */
nd4j.graph.DataType = { nd4j.graph.DType = {
INHERIT: 0, INHERIT: 0,
BOOL: 1, BOOL: 1,
FLOAT8: 2, FLOAT8: 2,
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
}; };
/** /**
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatArray.prototype.dtype = function() { nd4j.graph.FlatArray.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8); var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
}; };
/** /**
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype * @param {nd4j.graph.DType} dtype
*/ */
nd4j.graph.FlatArray.addDtype = function(builder, dtype) { nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
}; };
/** /**

View File

@ -5,7 +5,7 @@
namespace nd4j.graph namespace nd4j.graph
{ {
public enum DataType : sbyte public enum DType : sbyte
{ {
INHERIT = 0, INHERIT = 0,
BOOL = 1, BOOL = 1,

View File

@ -2,8 +2,8 @@
package nd4j.graph; package nd4j.graph;
public final class DataType { public final class DType {
private DataType() { } private DType() { }
public static final byte INHERIT = 0; public static final byte INHERIT = 0;
public static final byte BOOL = 1; public static final byte BOOL = 1;
public static final byte FLOAT8 = 2; public static final byte FLOAT8 = 2;

View File

@ -2,7 +2,7 @@
# namespace: graph # namespace: graph
class DataType(object): class DType(object):
INHERIT = 0 INHERIT = 0
BOOL = 1 BOOL = 1
FLOAT8 = 2 FLOAT8 = 2

View File

@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); } public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
#endif #endif
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); } public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } } public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder, public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
VectorOffset bufferOffset = default(VectorOffset), VectorOffset bufferOffset = default(VectorOffset),
DataType dtype = DataType.INHERIT, DType dtype = DType.INHERIT,
ByteOrder byteOrder = ByteOrder.LE) { ByteOrder byteOrder = ByteOrder.LE) {
builder.StartObject(4); builder.StartObject(4);
FlatArray.AddBuffer(builder, bufferOffset); FlatArray.AddBuffer(builder, bufferOffset);
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); } public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); } public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) { public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
int o = builder.EndObject(); int o = builder.EndObject();

View File

@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); } public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
#endif #endif
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); } public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; } public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } } public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T #if ENABLE_SPAN_T
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); } public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
#else #else
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); } public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
#endif #endif
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); } public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder, public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); } public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); } public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) { public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {

View File

@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); } public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
#endif #endif
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); } public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } } public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T #if ENABLE_SPAN_T
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder, public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
Offset<IntPair> idOffset = default(Offset<IntPair>), Offset<IntPair> idOffset = default(Offset<IntPair>),
StringOffset nameOffset = default(StringOffset), StringOffset nameOffset = default(StringOffset),
DataType dtype = DataType.INHERIT, DType dtype = DType.INHERIT,
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>), Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
int device = 0, int device = 0,
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
/** /**
* @param {number} index * @param {number} index
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatNode.prototype.outputTypes = function(index) { nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
var offset = this.bb.__offset(this.bb_pos, 38); var offset = this.bb.__offset(this.bb_pos, 38);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0); return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
}; };
/** /**
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {Array.<nd4j.graph.DataType>} data * @param {Array.<nd4j.graph.DType>} data
* @returns {flatbuffers.Offset} * @returns {flatbuffers.Offset}
*/ */
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) { nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {

View File

@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::String *name() const { const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(VT_NAME); return GetPointer<const flatbuffers::String *>(VT_NAME);
} }
DataType dtype() const { DType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0)); return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
} }
const flatbuffers::Vector<int64_t> *shape() const { const flatbuffers::Vector<int64_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE); return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
void add_name(flatbuffers::Offset<flatbuffers::String> name) { void add_name(flatbuffers::Offset<flatbuffers::String> name) {
fbb_.AddOffset(FlatVariable::VT_NAME, name); fbb_.AddOffset(FlatVariable::VT_NAME, name);
} }
void add_dtype(DataType dtype) { void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0); fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
} }
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) { void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<flatbuffers::String> name = 0,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0, flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<FlatArray> ndarray = 0, flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0, int32_t device = 0,
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0, flatbuffers::Offset<IntPair> id = 0,
const char *name = nullptr, const char *name = nullptr,
DataType dtype = DataType_INHERIT, DType dtype = DType_INHERIT,
const std::vector<int64_t> *shape = nullptr, const std::vector<int64_t> *shape = nullptr,
flatbuffers::Offset<FlatArray> ndarray = 0, flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0, int32_t device = 0,

View File

@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
}; };
/** /**
* @returns {nd4j.graph.DataType} * @returns {nd4j.graph.DType}
*/ */
nd4j.graph.FlatVariable.prototype.dtype = function() { nd4j.graph.FlatVariable.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8); var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
}; };
/** /**
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
/** /**
* @param {flatbuffers.Builder} builder * @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype * @param {nd4j.graph.DType} dtype
*/ */
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) { nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
}; };
/** /**

View File

@ -111,7 +111,7 @@ namespace nd4j {
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder()); auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo); return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
} }
} }
} }

View File

@ -219,7 +219,7 @@ namespace nd4j {
throw std::runtime_error("CONSTANT variable must have NDArray bundled"); throw std::runtime_error("CONSTANT variable must have NDArray bundled");
auto ar = flatVariable->ndarray(); auto ar = flatVariable->ndarray();
if (ar->dtype() == DataType_UTF8) { if (ar->dtype() == DType_UTF8) {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
} else { } else {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
@ -320,7 +320,7 @@ namespace nd4j {
auto fBuffer = builder.CreateVector(array->asByteVector()); auto fBuffer = builder.CreateVector(array->asByteVector());
// packing array // packing array
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType()); auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
// packing id/index of this var // packing id/index of this var
auto fVid = CreateIntPair(builder, this->_id, this->_index); auto fVid = CreateIntPair(builder, this->_id, this->_index);
@ -331,7 +331,7 @@ namespace nd4j {
stringId = builder.CreateString(this->_name); stringId = builder.CreateString(this->_name);
// returning array // returning array
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray); return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
} else { } else {
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
} }

View File

@ -23,7 +23,7 @@ enum ByteOrder:byte {
} }
// DataType for arrays/buffers // DataType for arrays/buffers
enum DataType:byte { enum DType:byte {
INHERIT, INHERIT,
BOOL, BOOL,
FLOAT8, FLOAT8,
@ -49,7 +49,7 @@ enum DataType:byte {
table FlatArray { table FlatArray {
shape:[long]; // shape in Nd4j format shape:[long]; // shape in Nd4j format
buffer:[byte]; // byte buffer with data buffer:[byte]; // byte buffer with data
dtype:DataType; // data type of actual data within buffer dtype:DType; // data type of actual data within buffer
byteOrder:ByteOrder; // byte order of buffer byteOrder:ByteOrder; // byte order of buffer
} }

View File

@ -48,7 +48,7 @@ table FlatNode {
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
// output data types (optional) // output data types (optional)
outputTypes:[DataType]; outputTypes:[DType];
//Scalar value - used for scalar ops. Should be single value only. //Scalar value - used for scalar ops. Should be single value only.
scalar:FlatArray; scalar:FlatArray;

View File

@ -51,7 +51,7 @@ table UIVariable {
id:IntPair; //Existing IntPair class id:IntPair; //Existing IntPair class
name:string; name:string;
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
datatype:DataType; datatype:DType;
shape:[long]; shape:[long];
controlDeps:[string]; //Input control dependencies: variable x -> this controlDeps:[string]; //Input control dependencies: variable x -> this
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of

View File

@ -30,7 +30,7 @@ enum VarType:byte {
table FlatVariable { table FlatVariable {
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
name:string; // symbolic ID of the Variable (if defined) name:string; // symbolic ID of the Variable (if defined)
dtype:DataType; dtype:DType;
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
ndarray:FlatArray; ndarray:FlatArray;

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_and)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_and) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_or)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_or) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_xor)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_xor) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -29,21 +29,26 @@ namespace ops {
CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) {
int numOfData = block.width(); int numOfData = block.width();
// int k = 0; // int k = 0;
// checking input data size
REQUIRE_TRUE(numOfData % 2 == 0, 0, REQUIRE_TRUE(numOfData % 2 == 0, 0,
"dynamic_stitch: The input params should contains" "dynamic_stitch: The input params should contains"
" both indeces and data lists with same length."); " both indeces and data lists with same length.");
// split input data list on two equal parts
numOfData /= 2; numOfData /= 2;
// form input lists to use with helpers - both indices and float data inputs
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
std::vector<NDArray*> inputs(numOfData); std::vector<NDArray*> inputs(numOfData);
std::vector<NDArray*> indices(numOfData); std::vector<NDArray*> indices(numOfData);
for (int e = 0; e < numOfData; e++) { for (int e = 0; e < numOfData; e++) {
auto data = INPUT_VARIABLE(numOfData + e); auto data = INPUT_VARIABLE(numOfData + e);
auto index = INPUT_VARIABLE(e); auto index = INPUT_VARIABLE(e);
inputs[e] = data; inputs[e] = data;
indices[e] = index; indices[e] = index;
} }
// run helper
return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output); return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output);
} }
@ -59,17 +64,17 @@ namespace ops {
numOfData /= 2; // only index part it's needed to review numOfData /= 2; // only index part it's needed to review
auto restShape = inputShape->at(numOfData); auto restShape = inputShape->at(numOfData);
auto firstShape = inputShape->at(0); auto firstShape = inputShape->at(0);
// check up inputs to avoid non-int indices and calculate max value from indices to output shape length
for(int i = 0; i < numOfData; i++) { for(int i = 0; i < numOfData; i++) {
auto input = INPUT_VARIABLE(i); auto input = INPUT_VARIABLE(i);
REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() ); REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() );
// FIXME: we have reduce::Max, cinsider using it instead
auto maxV = input->reduceNumber(reduce::Max); auto maxV = input->reduceNumber(reduce::Max);
if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0); if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0);
} }
// calculate output rank - difference between indices shape and data shape
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor
std::vector<Nd4jLong> outShape(outRank); std::vector<Nd4jLong> outShape(outRank);
// fill up output shape template: the first to max index, and rests - to vals from the first data input
outShape[0] = maxValue + 1; outShape[0] = maxValue + 1;
for(int i = 1; i < outRank; ++i) for(int i = 1; i < outRank; ++i)
outShape[i] = shape::sizeAt(restShape, i); outShape[i] = shape::sizeAt(restShape, i);

View File

@ -33,12 +33,13 @@ namespace nd4j {
* 0: 1D row-vector (or with shape (1, m)) * 0: 1D row-vector (or with shape (1, m))
* 1: 1D integer vector with slice nums * 1: 1D integer vector with slice nums
* 2: 1D float-point values vector with same shape as above * 2: 1D float-point values vector with same shape as above
* 3: 2D float-point matrix with data to search
* *
* Int args: * Int args:
* 0: N - number of slices * 0: N - number of slices
* *
* Output: * Output:
* 0: 1D vector with edge forces for input and values * 0: 2D matrix with the same shape and type as the 3th argument
*/ */
#if NOT_EXCLUDED(OP_barnes_edge_forces) #if NOT_EXCLUDED(OP_barnes_edge_forces)
DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1);
@ -54,7 +55,9 @@ namespace nd4j {
* 2: 1D float vector with values * 2: 1D float vector with values
* *
* Output: * Output:
* 0: symmetric 2D matrix with given values on given places * 0: 1D int result row-vector
* 1: 1D int result col-vector
* 2: a float-point tensor with shape 1xN, with values from the last input vector
*/ */
#if NOT_EXCLUDED(OP_barnes_symmetrized) #if NOT_EXCLUDED(OP_barnes_symmetrized)
DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1);

View File

@ -81,6 +81,39 @@ namespace nd4j {
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
#endif #endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_and)
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
#endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_or)
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
#endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_xor)
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
#endif
/** /**
* This operation returns hamming distance based on bits * This operation returns hamming distance based on bits
* *

View File

@ -120,7 +120,7 @@ namespace nd4j {
#endif #endif
/** /**
* This operation unstacks given NDArray into NDArrayList * This operation unstacks given NDArray into NDArrayList by the first dimension
*/ */
#if NOT_EXCLUDED(OP_unstack_list) #if NOT_EXCLUDED(OP_unstack_list)
DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0);

View File

@ -594,21 +594,46 @@ namespace nd4j {
/** /**
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
* of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension
* are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input
* block size and how the data is moved.
* Input:
* 0 - 4D tensor on given type
* Output:
* 0 - 4D tensor of given type and proper shape
* *
* * Int arguments:
* * 0 - block size
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
* 1 ("NCHW"): shape{ batch, channels, height, width }
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
* optional (default 0)
*/ */
#if NOT_EXCLUDED(OP_depth_to_space) #if NOT_EXCLUDED(OP_depth_to_space)
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2); DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1);
#endif #endif
/** /**
* This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor
* where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates
* the input block size.
* *
* Input:
* - 4D tensor of given type
* Output:
* - 4D tensor
* *
* Int arguments:
* 0 - block size
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
* 1 ("NCHW"): shape{ batch, channels, height, width }
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
* optional (default 0)
* *
*/ */
#if NOT_EXCLUDED(OP_space_to_depth) #if NOT_EXCLUDED(OP_space_to_depth)
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2); DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1);
#endif #endif
/** /**
@ -622,13 +647,42 @@ namespace nd4j {
#endif #endif
/** /**
* Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op
* outputs a copy of the input tensor where values from the height and width dimensions are moved to the
* batch dimension. After the zero-padding, both height and width of the input must be divisible by the block
* size.
* *
* Inputs:
* 0 - input tensor
* 1 - 2D paddings tensor (shape {M, 2})
*
* Output:
* - result tensor
*
* Int args:
* 0 - block size (M)
* *
*/ */
#if NOT_EXCLUDED(OP_space_to_batch) #if NOT_EXCLUDED(OP_space_to_batch)
DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
#endif #endif
/*
* This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape
* block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output,
* the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension
* combines both the position within a spatial block and the original batch position. Prior to division into
* blocks, the spatial dimensions of the input are optionally zero padded according to paddings.
*
* Inputs:
* 0 - input (N-D tensor)
* 1 - block_shape - int 1D tensor with M length
* 2 - paddings - int 2D tensor with shape {M, 2}
*
* Output:
* - N-D tensor with the same type as input 0.
*
* */
#if NOT_EXCLUDED(OP_space_to_batch_nd) #if NOT_EXCLUDED(OP_space_to_batch_nd)
DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0);
#endif #endif
@ -973,7 +1027,7 @@ namespace nd4j {
* return value: * return value:
* tensor with min values according to indices sets. * tensor with min values according to indices sets.
*/ */
#if NOT_EXCLUDED(OP_segment_min_bp) #if NOT_EXCLUDED(OP_segment_min)
DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0);
#endif #endif
#if NOT_EXCLUDED(OP_segment_min_bp) #if NOT_EXCLUDED(OP_segment_min_bp)

View File

@ -118,19 +118,19 @@ namespace nd4j {
PointersManager pm(context, "dynamicPartition"); PointersManager pm(context, "dynamicPartition");
if (sourceDimsLen) { if (sourceDimsLen) { // non-linear case
std::vector<int> sourceDims(sourceDimsLen); std::vector<int> sourceDims(sourceDimsLen);
for (int i = sourceDimsLen; i > 0; i--) for (int i = sourceDimsLen; i > 0; i--)
sourceDims[sourceDimsLen - i] = input->rankOf() - i; sourceDims[sourceDimsLen - i] = input->rankOf() - i;
//compute tad array for given dimensions
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims);
std::vector<void *> outBuffers(outSize); std::vector<void *> outBuffers(outSize);
std::vector<Nd4jLong *> tadShapes(outSize); std::vector<Nd4jLong *> tadShapes(outSize);
std::vector<Nd4jLong *> tadOffsets(outSize); std::vector<Nd4jLong *> tadOffsets(outSize);
std::vector<Nd4jLong> numTads(outSize); std::vector<Nd4jLong> numTads(outSize);
// fill up dimensions array for before kernel
for (unsigned int i = 0; i < outSize; i++) { for (unsigned int i = 0; i < outSize; i++) {
outputs[i].first = outputList[i]; outputs[i].first = outputList[i];
std::vector<int> outDims(outputs[i].first->rankOf() - 1); std::vector<int> outDims(outputs[i].first->rankOf() - 1);
@ -151,10 +151,10 @@ namespace nd4j {
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
// run kernel on device
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
} else { } else { // linear case
auto numThreads = 256; auto numThreads = 256;
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
@ -169,7 +169,6 @@ namespace nd4j {
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *)));
dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize); dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize);
} }

View File

@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_2) { TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE);
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
nd4j::ops::adjust_saturation op; nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {10}, {2}); auto results = op.execute({&input}, {10}, {2});
@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0); auto result = results->at(0);
// result->printIndexedBuffer(); // result->printIndexedBuffer("Result2");
// exp.printIndexedBuffer("Expect2");
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));

View File

@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto fBuffer = builder.CreateVector(array->asByteVector()); auto fBuffer = builder.CreateVector(array->asByteVector());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto fVid = CreateIntPair(builder, -1); auto fVid = CreateIntPair(builder, -1);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
std::vector<int> outputs1, outputs2, inputs1, inputs2; std::vector<int> outputs1, outputs2, inputs1, inputs2;
outputs1.push_back(2); outputs1.push_back(2);
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
auto name1 = builder.CreateString("wow1"); auto name1 = builder.CreateString("wow1");
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT); auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector; std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
variables_vector.push_back(fXVar); variables_vector.push_back(fXVar);

View File

@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
auto fBuffer = builder.CreateVector(vec); auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12); auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar); builder.Finish(flatVar);
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
auto fVid = CreateIntPair(builder, 37, 12); auto fVid = CreateIntPair(builder, 37, 12);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
builder.Finish(flatVar); builder.Finish(flatVar);

View File

@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
LocalResponseNormalization lrn = LocalResponseNormalization.builder() LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input}) .inputFunctions(new SDVariable[]{input})
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(lrnConfig) .config(lrnConfig)
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
Conv1D conv1D = Conv1D.builder() Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input, weights}) .inputFunctions(new SDVariable[]{input, weights})
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(conv1DConfig) .config(conv1DConfig)
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
return conv1D.outputVariable(); return conv1D.outputVariable();
} }
/**
* Conv1d operation.
*
* @param input the inputs to conv1d
* @param weights conv1d weights
* @param bias conv1d bias
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
SDVariable[] args;
if(bias == null){
args = new SDVariable[]{input, weights};
} else {
args = new SDVariable[]{input, weights, bias};
}
Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(args)
.sameDiff(sameDiff())
.config(conv1DConfig)
.build();
return conv1D.outputVariable();
}
/** /**
* Conv2d operation. * Conv2d operation.
* *
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
Conv2D conv2D = Conv2D.builder() Conv2D conv2D = Conv2D.sameDiffBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(conv2DConfig) .config(conv2DConfig)
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
.input(input) .input(input)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(pooling2DConfig) .config(pooling2DConfig)
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
MaxPooling2D maxPooling2D = MaxPooling2D.builder() MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
.input(input) .input(input)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(pooling2DConfig) .config(pooling2DConfig)
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
SConv2D sconv2D = SConv2D.sBuilder() SConv2D sconv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.conv2DConfig(conv2DConfig) .conv2DConfig(conv2DConfig)
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
SConv2D depthWiseConv2D = SConv2D.sBuilder() SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.conv2DConfig(depthConv2DConfig) .conv2DConfig(depthConv2DConfig)
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
DeConv2D deconv2D = DeConv2D.builder() DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
.inputs(inputs) .inputs(inputs)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.config(deconv2DConfig) .config(deconv2DConfig)
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
* @return * @return
*/ */
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
Conv3D conv3D = Conv3D.builder() Conv3D conv3D = Conv3D.sameDiffBuilder()
.inputFunctions(inputs) .inputFunctions(inputs)
.conv3DConfig(conv3DConfig) .config(conv3DConfig)
.sameDiff(sameDiff()) .sameDiff(sameDiff())
.build(); .build();
@ -1260,6 +1288,22 @@ public class DifferentialFunctionFactory {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
} }
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
return new BitwiseOr(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
return new BitwiseXor(sameDiff(), x, y).outputVariable();
}
public SDVariable eq(SDVariable iX, SDVariable i_y) { public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
} }

View File

@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
*/ */
public final SDImage image = new SDImage(this); public final SDImage image = new SDImage(this);
/**
* Op creator object for bitwise operations
*/
public final SDBitwise bitwise = new SDBitwise(this);
/** /**
* Op creator object for math operations * Op creator object for math operations
*/ */
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
return image; return image;
} }
/**
* Op creator object for bitwise operations
*/
public SDBitwise bitwise(){
return bitwise;
}
/** /**
* For import, many times we have variables * For import, many times we have variables

View File

@ -0,0 +1,205 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
/**
*
*/
public class SDBitwise extends SDOps {
public SDBitwise(SameDiff sameDiff) {
super(sameDiff);
}
/**
* See {@link #leftShift(String, SDVariable, SDVariable)}
*/
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
return leftShift(null, x, y);
}
/**
* Bitwise left shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift", x);
validateInteger("bitwise left shift", y);
SDVariable ret = f().shift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShift(String, SDVariable, SDVariable)}
*/
public SDVariable rightShift(SDVariable x, SDVariable y){
return rightShift(null, x, y);
}
/**
* Bitwise right shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift", x);
validateInteger("bitwise right shift", y);
SDVariable ret = f().rshift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
return leftShiftCyclic(null, x, y);
}
/**
* Bitwise left cyclical shift operation. Supports broadcasting.
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift (cyclic)", x);
validateInteger("bitwise left shift (cyclic)", y);
SDVariable ret = f().rotl(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
return rightShiftCyclic(null, x, y);
}
/**
* Bitwise right cyclical shift operation. Supports broadcasting.
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift (cyclic)", x);
validateInteger("bitwise right shift (cyclic)", y);
SDVariable ret = f().rotr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
*/
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
return bitsHammingDistance(null, x, y);
}
/**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return
*/
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
validateInteger("bitwise hamming distance", x);
validateInteger("bitwise hamming distance", y);
SDVariable ret = f().bitwiseHammingDist(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #and(String, SDVariable, SDVariable)}
*/
public SDVariable and(SDVariable x, SDVariable y){
return and(null, x, y);
}
/**
* Bitwise AND operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise AND array
*/
public SDVariable and(String name, SDVariable x, SDVariable y){
validateInteger("bitwise AND", x);
validateInteger("bitwise AND", y);
SDVariable ret = f().bitwiseAnd(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #or(String, SDVariable, SDVariable)}
*/
public SDVariable or(SDVariable x, SDVariable y){
return or(null, x, y);
}
/**
* Bitwise OR operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise OR array
*/
public SDVariable or(String name, SDVariable x, SDVariable y){
validateInteger("bitwise OR", x);
validateInteger("bitwise OR", y);
SDVariable ret = f().bitwiseOr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #xor(String, SDVariable, SDVariable)}
*/
public SDVariable xor(SDVariable x, SDVariable y){
return xor(null, x, y);
}
/**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise XOR array
*/
public SDVariable xor(String name, SDVariable x, SDVariable y){
validateInteger("bitwise XOR", x);
validateInteger("bitwise XOR", y);
SDVariable ret = f().bitwiseXor(x, y);
return updateVariableNameAndReference(ret, name);
}
}

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops; package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
} }
/** /**
* 2D Convolution layer operation - average pooling 2d * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration for
* @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return avgPooling2d(null, input, pooling2DConfig); return avgPooling2d(null, input, pooling2DConfig);
} }
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration * @param pooling2DConfig the configuration
* @return Result after applying average pooling on the input * @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateFloatingPoint("avgPooling2d", input); validateFloatingPoint("avgPooling2d", input);
SDVariable ret = f().avgPooling2d(input, pooling2DConfig); SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 3D convolution layer operation - average pooling 3d * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return avgPooling3d(null, input, pooling3DConfig); return avgPooling3d(null, input, pooling3DConfig);
} }
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration * @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input * @return Result after applying average pooling on the input
*/ */
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateFloatingPoint("avgPooling3d", input); validateFloatingPoint("avgPooling3d", input);
SDVariable ret = f().avgPooling3d(input, pooling3DConfig); SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #batchToSpace(String, SDVariable, int[], int[][]) * @see #batchToSpace(String, SDVariable, int[], int[][])
*/ */
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
return batchToSpace(null, x, blocks, crops); return batchToSpace(null, x, blocks, crops);
} }
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #spaceToBatch(String, SDVariable, int[], int[][]) * @see #spaceToBatch(String, SDVariable, int[], int[][])
*/ */
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
validateNumerical("batchToSpace", x); validateNumerical("batchToSpace", x);
SDVariable ret = f().batchToSpace(x, blocks, crops); SDVariable ret = f().batchToSpace(x, blocks, crops);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
/** /**
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape * See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
* [minibatch, inputChannels, height, width]
*
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
* @param config Convolution configuration for the col2im operation
* @return Col2Im output variable
*/ */
public SDVariable col2Im(SDVariable in, Conv2DConfig config) { public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return col2Im(null, in, config); return col2Im(null, in, config);
} }
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the col2im operation * @param config Convolution configuration for the col2im operation
* @return Col2Im output variable * @return Col2Im output variable
*/ */
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) { public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().col2Im(in, config); SDVariable ret = f().col2Im(in, config);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 1D Convolution layer operation - Conv1d * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*
* @param input the input array/activations for the conv1d op
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
*/ */
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, conv1DConfig); return conv1d((String) null, input, weights, conv1DConfig);
} }
/** /**
* Conv1d operation. * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
*/ */
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input); validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights); validateFloatingPoint("conv1d", weights);
SDVariable ret = f().conv1d(input, weights, conv1DConfig); SDVariable ret = f().conv1d(input, weights, conv1DConfig);
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
} }
/** /**
* 2D Convolution operation (without bias) * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
* @param config Conv2DConfig configuration
* @return result of conv2d op
*/ */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) { public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, bias, conv1DConfig);
}
/**
* Conv1d operation.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights);
validateFloatingPoint("conv1d", bias);
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(layerInput, weights, null, config); return conv2d(layerInput, weights, null, config);
} }
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(name, layerInput, weights, null, config);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
return conv2d(null, layerInput, weights, bias, config);
}
/** /**
* 2D Convolution operation with optional bias * 2D Convolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of conv2d op * @return result of conv2d op
*/ */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("conv2d", "input", layerInput); validateFloatingPoint("conv2d", "input", layerInput);
validateFloatingPoint("conv2d", "weights", weights); validateFloatingPoint("conv2d", "weights", weights);
validateFloatingPoint("conv2d", "bias", bias); validateFloatingPoint("conv2d", "bias", bias);
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
arr[1] = weights; arr[1] = weights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return conv2d(arr, config); return conv2d(name, arr, config);
} }
/** /**
* 2D Convolution operation with optional bias * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param config Conv2DConfig configuration
* @return result of convolution 2d operation
*/ */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
return conv2d(null, inputs, config); return conv2d(null, inputs, config);
} }
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of convolution 2d operation * @return result of convolution 2d operation
*/ */
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateNumerical("conv2d", v); validateNumerical("conv2d", v);
SDVariable ret = f().conv2d(inputs, config); SDVariable ret = f().conv2d(inputs, config);
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
} }
/** /**
* Convolution 3D operation without bias * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/ */
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, null, conv3DConfig); return conv3d(null, input, weights, null, conv3DConfig);
} }
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*/
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
*/
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/** /**
* Convolution 3D operation with optional bias * Convolution 3D operation with optional bias
* *
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
* @param conv3DConfig the configuration * @param conv3DConfig the configuration
* @return Conv3d output variable * @return Conv3d output variable
*/ */
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
validateFloatingPoint("conv3d", "input", input); validateFloatingPoint("conv3d", "input", input);
validateFloatingPoint("conv3d", "weights", weights); validateFloatingPoint("conv3d", "weights", weights);
validateFloatingPoint("conv3d", "bias", bias); validateFloatingPoint("conv3d", "bias", bias);
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
} }
/** /**
* Convolution 3D operation with optional bias * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/ */
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/**
* Convolution 3D operation without bias
*
* @param name Name of the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* 2D deconvolution operation without bias
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
* @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op
*/
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
return deconv2d(layerInput, weights, null, deconv2DConfig); return deconv2d(layerInput, weights, null, deconv2DConfig);
} }
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*/
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
}
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
*/
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
}
/** /**
* 2D deconvolution operation with optional bias * 2D deconvolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig DeConv2DConfig configuration * @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op * @return result of deconv2d op
*/ */
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
validateFloatingPoint("deconv2d", "input", layerInput); validateFloatingPoint("deconv2d", "input", layerInput);
validateFloatingPoint("deconv2d", "weights", weights); validateFloatingPoint("deconv2d", "weights", weights);
validateFloatingPoint("deconv2d", "bias", bias); validateFloatingPoint("deconv2d", "bias", bias);
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
arr[1] = weights; arr[1] = weights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return deconv2d(arr, deconv2DConfig); return deconv2d(name, arr, deconv2DConfig);
} }
/** /**
* 2D deconvolution operation with or without optional bias * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
*
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
* @param deconv2DConfig the configuration
* @return result of deconv2d op
*/ */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, inputs, deconv2DConfig); return deconv2d(null, inputs, deconv2DConfig);
} }
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig the configuration * @param deconv2DConfig the configuration
* @return result of deconv2d op * @return result of deconv2d op
*/ */
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateNumerical("deconv2d", v); validateNumerical("deconv2d", v);
SDVariable ret = f().deconv2d(inputs, deconv2DConfig); SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(name, input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
return deconv3d(null, input, weights, bias, config);
}
/** /**
* 3D CNN deconvolution operation with or without optional bias * 3D CNN deconvolution operation with or without optional bias
* *
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration * @param config Configuration
*/ */
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
validateFloatingPoint("conv3d", input); validateFloatingPoint("conv3d", input);
validateFloatingPoint("conv3d", weights); validateFloatingPoint("conv3d", weights);
validateFloatingPoint("conv3d", bias); validateFloatingPoint("conv3d", bias);
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
} }
/** /**
* 3D CNN deconvolution operation with or without optional bias * See {@link #depthToSpace(String, SDVariable, int, String)}.
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration
*/ */
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
return deconv3d(null, input, weights, bias, config);
}
/**
* 3D CNN deconvolution operation with no bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param config Configuration
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
* = [mb, 2, 4, 4]
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return Output variable
*/
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
return depthToSpace(null, x, blockSize, dataFormat); return depthToSpace(null, x, blockSize, dataFormat);
} }
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #depthToSpace(String, SDVariable, int, String) * @see #depthToSpace(String, SDVariable, int, String)
*/ */
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) { public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* Depth-wise 2D convolution operation without bias * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param config Conv2DConfig configuration
* @return result of conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) { public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(layerInput, depthWeights, null, config); return depthWiseConv2d(layerInput, depthWeights, null, config);
} }
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
}
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
}
/** /**
* Depth-wise 2D convolution operation with optional bias * Depth-wise 2D convolution operation with optional bias
* *
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of depthwise conv2d op * @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("depthwiseConv2d", "input", layerInput); validateFloatingPoint("depthwiseConv2d", "input", layerInput);
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
validateFloatingPoint("depthwiseConv2d", "bias", bias); validateFloatingPoint("depthwiseConv2d", "bias", bias);
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
arr[1] = depthWeights; arr[1] = depthWeights;
if (bias != null) if (bias != null)
arr[2] = bias; arr[2] = bias;
return depthWiseConv2d(arr, config); return depthWiseConv2d(name, arr, config);
} }
/** /**
* Depth-wise convolution 2D operation. * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
* or 3 elements (layerInput, depthWeights, bias) as described in
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
return depthWiseConv2d(null, inputs, depthConv2DConfig); return depthWiseConv2d(null, inputs, depthConv2DConfig);
} }
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
* @param depthConv2DConfig the configuration * @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op * @return result of depthwise conv2d op
*/ */
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateFloatingPoint("depthWiseConv2d", v); validateFloatingPoint("depthWiseConv2d", v);
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
} }
/** /**
* TODO doc string * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
*
* @param df
* @param weights
* @param strides
* @param rates
* @param isSameMode
* @return
*/ */
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
int[] rates, boolean isSameMode) { @NonNull int[] rates, @NonNull boolean isSameMode) {
return dilation2D(null, df, weights, strides, rates, isSameMode); return dilation2D(null, df, weights, strides, rates, isSameMode);
} }
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
* @param isSameMode * @param isSameMode
* @return * @return
*/ */
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
int[] rates, boolean isSameMode) { @NonNull int[] rates, @NonNull boolean isSameMode) {
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
* @param sameMode If true: use same mode padding. If false * @param sameMode If true: use same mode padding. If false
* @return * @return
*/ */
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape * See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
*
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
* @param config Convolution configuration for the im2col operation
* @return Im2Col output variable
*/ */
public SDVariable im2Col(SDVariable in, Conv2DConfig config) { public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return im2Col(null, in, config); return im2Col(null, in, config);
} }
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the im2col operation * @param config Convolution configuration for the im2col operation
* @return Im2Col output variable * @return Im2Col output variable
*/ */
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) { public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().im2Col(in, config); SDVariable ret = f().im2Col(in, config);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 2D convolution layer operation - local response normalization * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
*
* @param inputs the inputs to lrn
* @param lrnConfig the configuration
* @return
*/ */
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) { public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
return localResponseNormalization(null, inputs, lrnConfig); return localResponseNormalization(null, inputs, lrnConfig);
} }
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
* @param lrnConfig the configuration * @param lrnConfig the configuration
* @return * @return
*/ */
public SDVariable localResponseNormalization(String name, SDVariable input, public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
LocalResponseNormalizationConfig lrnConfig) { @NonNull LocalResponseNormalizationConfig lrnConfig) {
validateFloatingPoint("local response normalization", input); validateFloatingPoint("local response normalization", input);
SDVariable ret = f().localResponseNormalization(input, lrnConfig); SDVariable ret = f().localResponseNormalization(input, lrnConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
/** /**
* 2D Convolution layer operation - max pooling 2d * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return maxPooling2d(null, input, pooling2DConfig); return maxPooling2d(null, input, pooling2DConfig);
} }
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration * @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input * @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateNumerical("maxPooling2d", input); validateNumerical("maxPooling2d", input);
SDVariable ret = f().maxPooling2d(input, pooling2DConfig); SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 3D convolution layer operation - max pooling 3d operation. * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return maxPooling3d(null, input, pooling3DConfig); return maxPooling3d(null, input, pooling3DConfig);
} }
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration * @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input * @return Result after applying max pooling on the input
*/ */
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateNumerical("maxPooling3d", input); validateNumerical("maxPooling3d", input);
SDVariable ret = f().maxPooling3d(input, pooling3DConfig); SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
/** /**
* Separable 2D convolution operation without bias * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
* May be null
* @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation
*/ */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
Conv2DConfig config) { @NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config); return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
} }
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
@NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
}
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, @NonNull Conv2DConfig config) {
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
}
/** /**
* Separable 2D convolution operation with optional bias * Separable 2D convolution operation with optional bias
* *
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration * @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation * @return result of separable convolution 2d operation
*/ */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, Conv2DConfig config) { SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("separableConv2d", "input", layerInput); validateFloatingPoint("separableConv2d", "input", layerInput);
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
arr[2] = pointWeights; arr[2] = pointWeights;
if (bias != null) if (bias != null)
arr[3] = bias; arr[3] = bias;
return sconv2d(arr, config); return sconv2d(name, arr, config);
} }
/** /**
* Separable 2D convolution operation with/without optional bias * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
*
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param conv2DConfig the configuration
* @return result of separable convolution 2d operation
*/ */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
return sconv2d(null, inputs, conv2DConfig); return sconv2d(null, inputs, conv2DConfig);
} }
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
* @param conv2DConfig the configuration * @param conv2DConfig the configuration
* @return result of separable convolution 2d operation * @return result of separable convolution 2d operation
*/ */
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
for(SDVariable v : inputs) for(SDVariable v : inputs)
validateFloatingPoint("sconv2d", v); validateFloatingPoint("sconv2d", v);
SDVariable ret = f().sconv2d(inputs, conv2DConfig); SDVariable ret = f().sconv2d(inputs, conv2DConfig);
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #spaceToBatch(String, SDVariable, int[], int[][]) * @see #spaceToBatch(String, SDVariable, int[], int[][])
*/ */
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) { public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
return spaceToBatch(null, x, blocks, padding); return spaceToBatch(null, x, blocks, padding);
} }
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #batchToSpace(String, SDVariable, int[], int[][]) * @see #batchToSpace(String, SDVariable, int[], int[][])
*/ */
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) { public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
SDVariable ret = f().spaceToBatch(x, blocks, padding); SDVariable ret = f().spaceToBatch(x, blocks, padding);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
/** /**
* @see #spaceToDepth(String, SDVariable, int, String) * @see #spaceToDepth(String, SDVariable, int, String)
*/ */
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) { public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
return spaceToDepth(null, x, blockSize, dataFormat); return spaceToDepth(null, x, blockSize, dataFormat);
} }
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
* @return Output variable * @return Output variable
* @see #depthToSpace(String, SDVariable, int, String) * @see #depthToSpace(String, SDVariable, int, String)
*/ */
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) { public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/** /**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. * See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
* *
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) * @param scale The scale for both height and width dimensions.
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
*/ */
public SDVariable upsampling2d(SDVariable input, int scale) { public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
return upsampling2d(null, input, true, scale, scale); return upsampling2d(null, input, true, scale, scale);
} }
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
*
* @param scale The scale for both height and width dimensions.
*/
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
*/
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
/** /**
* 2D Convolution layer operation - Upsampling 2d * 2D Convolution layer operation - Upsampling 2d
* *
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
* @param scaleW Scale to upsample in width dimension * @param scaleW Scale to upsample in width dimension
* @return Upsampled input * @return Upsampled input
*/ */
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) { public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
*/
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* 2D Convolution layer operation - Upsampling 2d
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* or NHWC format (shape [minibatch, height, width, channels])
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
* @param scaleH Scale to upsample in height dimension
* @param scaleW Scale to upsample in width dimension
* @return Upsampled input
*/
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
} }

View File

@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType; import org.nd4j.graph.DType;
import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties; import org.nd4j.graph.FlatProperties;
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) { public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
switch (type) { switch (type) {
case FLOAT: case FLOAT:
return DataType.FLOAT; return DType.FLOAT;
case DOUBLE: case DOUBLE:
return DataType.DOUBLE; return DType.DOUBLE;
case HALF: case HALF:
return DataType.HALF; return DType.HALF;
case INT: case INT:
return DataType.INT32; return DType.INT32;
case LONG: case LONG:
return DataType.INT64; return DType.INT64;
case BOOL: case BOOL:
return DataType.BOOL; return DType.BOOL;
case SHORT: case SHORT:
return DataType.INT16; return DType.INT16;
case BYTE: case BYTE:
return DataType.INT8; return DType.INT8;
case UBYTE: case UBYTE:
return DataType.UINT8; return DType.UINT8;
case UTF8: case UTF8:
return DataType.UTF8; return DType.UTF8;
case UINT16: case UINT16:
return DataType.UINT16; return DType.UINT16;
case UINT32: case UINT32:
return DataType.UINT32; return DType.UINT32;
case UINT64: case UINT64:
return DataType.UINT64; return DType.UINT64;
case BFLOAT16: case BFLOAT16:
return DataType.BFLOAT16; return DType.BFLOAT16;
default: default:
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
} }
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
* This method converts enums for DataType * This method converts enums for DataType
*/ */
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) { public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
if (val == DataType.FLOAT) { if (val == DType.FLOAT) {
return org.nd4j.linalg.api.buffer.DataType.FLOAT; return org.nd4j.linalg.api.buffer.DataType.FLOAT;
} else if (val == DataType.DOUBLE) { } else if (val == DType.DOUBLE) {
return org.nd4j.linalg.api.buffer.DataType.DOUBLE; return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
} else if (val == DataType.HALF) { } else if (val == DType.HALF) {
return org.nd4j.linalg.api.buffer.DataType.HALF; return org.nd4j.linalg.api.buffer.DataType.HALF;
} else if (val == DataType.INT32) { } else if (val == DType.INT32) {
return org.nd4j.linalg.api.buffer.DataType.INT; return org.nd4j.linalg.api.buffer.DataType.INT;
} else if (val == DataType.INT64) { } else if (val == DType.INT64) {
return org.nd4j.linalg.api.buffer.DataType.LONG; return org.nd4j.linalg.api.buffer.DataType.LONG;
} else if (val == DataType.INT8) { } else if (val == DType.INT8) {
return org.nd4j.linalg.api.buffer.DataType.BYTE; return org.nd4j.linalg.api.buffer.DataType.BYTE;
} else if (val == DataType.BOOL) { } else if (val == DType.BOOL) {
return org.nd4j.linalg.api.buffer.DataType.BOOL; return org.nd4j.linalg.api.buffer.DataType.BOOL;
} else if (val == DataType.UINT8) { } else if (val == DType.UINT8) {
return org.nd4j.linalg.api.buffer.DataType.UBYTE; return org.nd4j.linalg.api.buffer.DataType.UBYTE;
} else if (val == DataType.INT16) { } else if (val == DType.INT16) {
return org.nd4j.linalg.api.buffer.DataType.SHORT; return org.nd4j.linalg.api.buffer.DataType.SHORT;
} else if (val == DataType.UTF8) { } else if (val == DType.UTF8) {
return org.nd4j.linalg.api.buffer.DataType.UTF8; return org.nd4j.linalg.api.buffer.DataType.UTF8;
} else if (val == DataType.UINT16) { } else if (val == DType.UINT16) {
return org.nd4j.linalg.api.buffer.DataType.UINT16; return org.nd4j.linalg.api.buffer.DataType.UINT16;
} else if (val == DataType.UINT32) { } else if (val == DType.UINT32) {
return org.nd4j.linalg.api.buffer.DataType.UINT32; return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) { } else if (val == DType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64; return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else if (val == DataType.BFLOAT16){ } else if (val == DType.BFLOAT16){
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
} else { } else {
throw new RuntimeException("Unknown datatype: " + val); throw new RuntimeException("Unknown datatype: " + val);

View File

@ -2,8 +2,8 @@
package org.nd4j.graph; package org.nd4j.graph;
public final class DataType { public final class DType {
private DataType() { } private DType() { }
public static final byte INHERIT = 0; public static final byte INHERIT = 0;
public static final byte BOOL = 1; public static final byte BOOL = 1;
public static final byte FLOAT8 = 2; public static final byte FLOAT8 = 2;

View File

@ -353,6 +353,10 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,

View File

@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty()));
} }
@Override
public void setShape(long[] shape) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty()));
}
@Override
public void setStride(long[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty()));
}
@Override @Override
public void setShapeAndStride(int[] shape, int[] stride) { public void setShapeAndStride(int[] shape, int[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return scalar.getDouble(0); return scalar.getDouble(0);
} }
/**
* Returns entropy value for this INDArray
* @return
*/
@Override @Override
public Number entropyNumber() { public Number entropyNumber() {
return entropy(Integer.MAX_VALUE).getDouble(0); return entropy(Integer.MAX_VALUE).getDouble(0);
} }
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
*/
@Override @Override
public Number shannonEntropyNumber() { public Number shannonEntropyNumber() {
return shannonEntropy(Integer.MAX_VALUE).getDouble(0); return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
} }
/**
* Returns log entropy value for this INDArray
* @return
*/
@Override @Override
public Number logEntropyNumber() { public Number logEntropyNumber() {
return logEntropy(Integer.MAX_VALUE).getDouble(0); return logEntropy(Integer.MAX_VALUE).getDouble(0);
@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return size(0); return size(0);
} }
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
Nd4j.getCompressor().autoDecompress(this);
int n = shape.length;
// FIXME: shapeInfo should be used here
if (shape.length < 1)
return create(Nd4j.createBufferDetached(shape));
if (offsets.length != n)
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
if (stride.length != n)
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
if (Shape.contentEquals(shape, shapeOf())) {
if (ArrayUtil.isZero(offsets)) {
return this;
} else {
throw new IllegalArgumentException("Invalid subArray offsets");
}
}
long[] dotProductOffsets = offsets;
int[] dotProductStride = stride;
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
if (offset >= data().length())
offset = ArrayUtil.sumLong(offsets);
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering());
}
protected INDArray create(DataBuffer buffer) { protected INDArray create(DataBuffer buffer) {
return Nd4j.create(buffer); return Nd4j.create(buffer);
} }
@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMin(this, dimension)); return Nd4j.getExecutioner().exec(new AMin(this, dimension));
} }
/**
* Returns the sum along the specified dimension(s) of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override @Override
public INDArray sum(int... dimension) { public INDArray sum(int... dimension) {
validateNumericalArray("sum", true); validateNumericalArray("sum", true);
return Nd4j.getExecutioner().exec(new Sum(this, dimension)); return Nd4j.getExecutioner().exec(new Sum(this, dimension));
} }
/**
* Returns the sum along the last dimension of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override @Override
public INDArray sum(boolean keepDim, int... dimension) { public INDArray sum(boolean keepDim, int... dimension) {
validateNumericalArray("sum", true); validateNumericalArray("sum", true);
return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension));
} }
/**
* Returns entropy along dimension
* @param dimension
* @return
*/
@Override @Override
public INDArray entropy(int... dimension) { public INDArray entropy(int... dimension) {
validateNumericalArray("entropy", false); validateNumericalArray("entropy", false);
return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
} }
/**
* Returns non-normalized Shannon entropy along dimension
* @param dimension
* @return
*/
@Override @Override
public INDArray shannonEntropy(int... dimension) { public INDArray shannonEntropy(int... dimension) {
validateNumericalArray("shannonEntropy", false); validateNumericalArray("shannonEntropy", false);
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
} }
/**
* Returns log entropy along dimension
* @param dimension
* @return
*/
@Override @Override
public INDArray logEntropy(int... dimension) { public INDArray logEntropy(int... dimension) {
validateNumericalArray("logEntropy", false); validateNumericalArray("logEntropy", false);

View File

@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public void setStride(long... stride) {
throw new UnsupportedOperationException();
}
@Override
public void setShape(long... shape) {
throw new UnsupportedOperationException();
}
@Override @Override
public INDArray putScalar(long row, long col, double value) { public INDArray putScalar(long row, long col, double value) {
return null; return null;
@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
@Override @Override
public void setShapeAndStride(int[] shape, int[] stride) { public void setShapeAndStride(int[] shape, int[] stride) {
} }
@Override @Override
public void setOrder(char order) { public void setOrder(char order) {
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
return null;
} }
@Override @Override
@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null; return null;
} }
/**
* Returns entropy value for this INDArray
* @return
*/
@Override @Override
public Number entropyNumber() { public Number entropyNumber() {
return entropy(Integer.MAX_VALUE).getDouble(0); return entropy(Integer.MAX_VALUE).getDouble(0);
} }
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
*/
@Override @Override
public Number shannonEntropyNumber() { public Number shannonEntropyNumber() {
return shannonEntropy(Integer.MAX_VALUE).getDouble(0); return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
} }
/**
* Returns log entropy value for this INDArray
* @return
*/
@Override @Override
public Number logEntropyNumber() { public Number logEntropyNumber() {
return logEntropy(Integer.MAX_VALUE).getDouble(0); return logEntropy(Integer.MAX_VALUE).getDouble(0);
} }
/**
* Returns entropy along dimension
* @param dimension
* @return
*/
@Override @Override
public INDArray entropy(int... dimension) { public INDArray entropy(int... dimension) {
return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
} }
/**
* Returns non-normalized Shannon entropy along dimension
* @param dimension
* @return
*/
@Override @Override
public INDArray shannonEntropy(int... dimension) { public INDArray shannonEntropy(int... dimension) {
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));

View File

@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return extendedFlags; return extendedFlags;
} }
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
throw new UnsupportedOperationException();
}
/** /**
* Returns the underlying indices of the element of the given index * Returns the underlying indices of the element of the given index
* such as there really are in the original ndarray * such as there really are in the original ndarray
@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return null; return null;
} }
@Override
public void setStride(long... stride) {
}
@Override
public void setShape(long... shape) {
}
/** /**
* This method returns true if this INDArray is special case: no-value INDArray * This method returns true if this INDArray is special case: no-value INDArray
* *

View File

@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return shapeInformation; return shapeInformation;
} }
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
throw new UnsupportedOperationException();
}
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
//TODO use op //TODO use op

View File

@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable {
/** /**
* Returns entropy value for this INDArray * Returns entropy value for this INDArray
* @return * @return entropy value
*/ */
Number entropyNumber(); Number entropyNumber();
/** /**
* Returns non-normalized Shannon entropy value for this INDArray * Returns non-normalized Shannon entropy value for this INDArray
* @return * @return non-normalized Shannon entropy
*/ */
Number shannonEntropyNumber(); Number shannonEntropyNumber();
/** /**
* Returns log entropy value for this INDArray * Returns log entropy value for this INDArray
* @return * @return log entropy value
*/ */
Number logEntropyNumber(); Number logEntropyNumber();
/** /**
* Returns entropy value for this INDArray along specified dimension(s) * Returns entropy value for this INDArray along specified dimension(s)
* @return * @param dimension specified dimension(s)
* @return entropy value
*/ */
INDArray entropy(int... dimension); INDArray entropy(int... dimension);
/** /**
* Returns entropy value for this INDArray along specified dimension(s) * Returns Shannon entropy value for this INDArray along specified dimension(s)
* @return * @param dimension specified dimension(s)
* @return Shannon entropy
*/ */
INDArray shannonEntropy(int... dimension); INDArray shannonEntropy(int... dimension);
/** /**
* Returns entropy value for this INDArray along specified dimension(s) * Returns log entropy value for this INDArray along specified dimension(s)
* @return * @param dimension specified dimension(s)
* @return log entropy value
*/ */
INDArray logEntropy(int... dimension); INDArray logEntropy(int... dimension);
/**
* stride setter
* @param stride
* @deprecated, use {@link #reshape(int...) }
*/
@Deprecated
void setStride(long... stride);
/**
* Shape setter
* @param shape
* @deprecated, use {@link #reshape(int...) }
*/
@Deprecated
void setShape(long... shape);
/** /**
* Shape and stride setter * Shape and stride setter
* @param shape * @param shape new value for shape
* @param stride * @param stride new value for stride
*/ */
void setShapeAndStride(int[] shape, int[] stride); void setShapeAndStride(int[] shape, int[] stride);
@ -1920,14 +1904,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
void setOrder(char order); void setOrder(char order);
/**
* @param offsets
* @param shape
* @param stride
* @return
*/
INDArray subArray(long[] offsets, int[] shape, int[] stride);
/** /**
* Returns the elements at the specified indices * Returns the elements at the specified indices
* *

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false); super(sameDiff, new SDVariable[]{input});
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
config.setType(Pooling2D.Pooling2DType.AVG); config.setType(Pooling2D.Pooling2DType.AVG);
this.config = config;
addArgs();
}
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.AVG);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
addArgs(); addArgs();
} }

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
protected Conv1DConfig config; protected Conv1DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv1D(SameDiff sameDiff, public Conv1D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv1DConfig config) { Conv1DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff; initConfig(config);
}
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv1DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputFunctions, this);
} }
protected void addArgs() { protected void addArgs() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
protected Conv2DConfig config; protected Conv2DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv2D(SameDiff sameDiff, public Conv2D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) { Conv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff;
initConfig(config);
}
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
protected void initConfig(Conv2DConfig config){
this.config = config; this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION, INVALID_CONFIGURATION,
config.getSH(), config.getPH(), config.getDW()); config.getSH(), config.getPH(), config.getDW());
addArgs(); addArgs();
if(sameDiff != null) {
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this);
}
} }
protected void addArgs() { protected void addArgs() {
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder() Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
.sameDiff(sameDiff) .sameDiff(sameDiff)
.config(config) .config(config)
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.build(); .build();
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables()); List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());

View File

@ -37,8 +37,8 @@ import java.util.List;
public class Conv2DDerivative extends Conv2D { public class Conv2DDerivative extends Conv2D {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
super(sameDiff, inputFunctions, inputArrays, outputs, config); super(sameDiff, inputFunctions, config);
} }
public Conv2DDerivative() {} public Conv2DDerivative() {}

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
public Conv3D() { public Conv3D() {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
Conv3DConfig conv3DConfig) { super(sameDiff, inputFunctions);
super(null, sameDiff, inputFunctions, false); initConfig(config);
setSameDiff(sameDiff); }
if (inputs != null) public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
addInputArgument(inputs); super(inputs, outputs);
if (outputs != null) initConfig(config);
addOutputArgument(outputs); }
this.config = conv3DConfig;
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv3DConfig config){
this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION, INVALID_CONFIGURATION,
config.getSW(), config.getPH(), config.getDW()); config.getSW(), config.getPH(), config.getDW());
addArgs(); addArgs();
//for (val arg: iArgs())
// System.out.println(getIArgument(arg));
} }
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
inputs.add(f1.get(0)); inputs.add(f1.get(0));
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder() Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
.conv3DConfig(config) .conv3DConfig(config)
.inputFunctions(args())
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.sameDiff(sameDiff) .sameDiff(sameDiff)
.build(); .build();

View File

@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
public Conv3DDerivative() {} public Conv3DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) { public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig); super(sameDiff, inputFunctions, conv3DConfig);
} }
@Override @Override

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
protected DeConv2DConfig config; protected DeConv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DeConv2D(SameDiff sameDiff, public DeConv2D(SameDiff sameDiff,
SDVariable[] inputs, SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) { DeConv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputs);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
if (inputArrays != null) { addArgs();
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
} }
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this); }
sameDiff.addArgsFor(inputs, this);
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
@Override @Override

View File

@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
public DeConv2DDerivative() {} public DeConv2DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
super(sameDiff, inputs, inputArrays, outputs, config); super(sameDiff, inputs, config);
} }
@Override @Override

View File

@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
protected DeConv2DConfig config; protected DeConv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DeConv2DTF(SameDiff sameDiff, public DeConv2DTF(SameDiff sameDiff,
SDVariable[] inputs, SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) { DeConv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputs);
this.sameDiff = sameDiff;
this.config = config;
addArgs();
}
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config; this.config = config;
if (inputArrays != null) {
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
}
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this);
} }
@Override @Override

View File

@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
protected DeConv3DConfig config; protected DeConv3DConfig config;
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) { public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
super(sameDiff, toArr(input, weights, bias)); super(sameDiff, toArr(input, weights, bias));
this.config = config; this.config = config;
addArgs(); addArgs();
} }
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
if(bias != null){ if(bias != null){
return new SDVariable[]{input, weights, bias}; return new SDVariable[]{input, weights, bias};

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
protected Conv2DConfig config; protected Conv2DConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public DepthwiseConv2D(SameDiff sameDiff, public DepthwiseConv2D(SameDiff sameDiff,
SDVariable[] inputFunctions, SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) { Conv2DConfig config) {
super(null, inputArrays, outputs); super(sameDiff, inputFunctions);
this.sameDiff = sameDiff;
this.config = config; this.config = config;
addArgs(); addArgs();
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point }
sameDiff.addArgsFor(inputFunctions, this);
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
} }
public DepthwiseConv2D() { public DepthwiseConv2D() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
protected LocalResponseNormalizationConfig config; protected LocalResponseNormalizationConfig config;
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
LocalResponseNormalizationConfig config) { LocalResponseNormalizationConfig config) {
super(null,sameDiff, inputFunctions, inPlace); super(null,sameDiff, inputFunctions, inPlace);
this.config = config; this.config = config;
if(inputs != null) { addArgs();
addInputArgument(inputs);
}
if(outputs!= null) {
addOutputArgument(outputs);
} }
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
addArgs(); addArgs();
} }

View File

@ -33,8 +33,8 @@ import java.util.List;
@Slf4j @Slf4j
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization { public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) { public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config); super(sameDiff, inputFunctions, inPlace, config);
} }
public LocalResponseNormalizationDerivative() {} public LocalResponseNormalizationDerivative() {}

View File

@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
public MaxPooling2D() { public MaxPooling2D() {
} }
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok") @SuppressWarnings("Used in lombok")
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false); super(null, sameDiff, new SDVariable[]{input}, false);
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;
this.sameDiff = sameDiff;
addArgs(); addArgs();
} }
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output}); super(null, new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.MAX); config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config; this.config = config;

View File

@ -16,8 +16,14 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
/** /**
* Pooling2D operation * Pooling2D operation
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
public Pooling2D() {} public Pooling2D() {}
@Builder(builderMethodName = "builder") @Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok") @SuppressWarnings("Used in lombok")
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) { public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
super(null,sameDiff, inputs, false); Pooling2DConfig config) {
if(arrayInputs != null) { super(null, sameDiff, inputs, false);
addInputArgument(arrayInputs);
}
if(arrayOutputs != null) {
addOutputArgument(arrayOutputs);
}
this.config = config; this.config = config;
addArgs();
}
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
addArgs(); addArgs();
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,8 +37,12 @@ import java.util.List;
@Slf4j @Slf4j
public class Pooling2DDerivative extends Pooling2D { public class Pooling2DDerivative extends Pooling2D {
@Builder(builderMethodName = "derivativeBuilder") @Builder(builderMethodName = "derivativeBuilder")
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) { public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
super(sameDiff, inputs, arrayInputs, arrayOutputs, config); super(sameDiff, inputs, config);
}
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
} }
public Pooling2DDerivative() {} public Pooling2DDerivative() {}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution; package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder; import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -39,9 +40,17 @@ import java.util.List;
@Slf4j @Slf4j
public class SConv2D extends Conv2D { public class SConv2D extends Conv2D {
@Builder(builderMethodName = "sBuilder") @Builder(builderMethodName = "sameDiffSBuilder")
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); super(sameDiff, inputFunctions, conv2DConfig);
}
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs, config);
}
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
} }
public SConv2D() {} public SConv2D() {}

View File

@ -38,8 +38,8 @@ import java.util.List;
public class SConv2DDerivative extends SConv2D { public class SConv2DDerivative extends SConv2D {
@Builder(builderMethodName = "sDerviativeBuilder") @Builder(builderMethodName = "sDerviativeBuilder")
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); super(sameDiff, inputFunctions, conv2DConfig);
} }
public SConv2DDerivative() {} public SConv2DDerivative() {}

View File

@ -0,0 +1,37 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class BitsHammingDistance extends DynamicCustomOp {
public BitsHammingDistance(){ }
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
super(sd, new SDVariable[]{x, y});
}
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
super(new INDArray[]{x, y}, null);
}
@Override
public String opName() {
return "bits_hamming_distance";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
return Collections.singletonList(DataType.LONG);
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise AND operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseAnd extends BaseDynamicTransformOp {
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseAnd(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseAnd() {}
@Override
public String opName() {
return "bitwise_and";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_and";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise OR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseOr extends BaseDynamicTransformOp {
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseOr(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseOr() {}
@Override
public String opName() {
return "bitwise_or";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_or";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise XOR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseXor extends BaseDynamicTransformOp {
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseXor(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseXor() {}
@Override
public String opName() {
return "bitwise_xor";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_xor";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName()); throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
} }

View File

@ -235,10 +235,7 @@ public class Convolution {
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
double extra, int virtualHeight, int virtualWidth, INDArray out) { double extra, int virtualHeight, int virtualWidth, INDArray out) {
Pooling2D pooling = Pooling2D.builder() Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
.arrayInputs(new INDArray[]{img})
.arrayOutputs(new INDArray[]{out})
.config(Pooling2DConfig.builder()
.dH(dh) .dH(dh)
.dW(dw) .dW(dw)
.extra(extra) .extra(extra)
@ -251,8 +248,7 @@ public class Convolution {
.sW(sx) .sW(sx)
.type(type) .type(type)
.divisor(divisor) .divisor(divisor)
.build()) .build());
.build();
Nd4j.getExecutioner().execAndReturn(pooling); Nd4j.getExecutioner().execAndReturn(pooling);
return out; return out;
} }

View File

@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex {
return offset(arr.stride(), Indices.offsets(arr.shape(), indices)); return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
} }
/**
* Set the shape and stride for
* new axes based dimensions
* @param arr the array to update
* the shape/strides for
* @param indexes the indexes to update based on
*/
public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) {
int numNewAxes = NDArrayIndex.numNewAxis(indexes);
if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) {
List<Long> newShape = new ArrayList<>();
List<Long> newStrides = new ArrayList<>();
int currDimension = 0;
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] instanceof NewAxis) {
newShape.add(1L);
newStrides.add(0L);
} else {
newShape.add(arr.size(currDimension));
newStrides.add(arr.size(currDimension));
currDimension++;
}
}
while (currDimension < arr.rank()) {
newShape.add((long) currDimension);
newStrides.add((long) currDimension);
currDimension++;
}
long[] newShapeArr = Longs.toArray(newShape);
long[] newStrideArr = Longs.toArray(newStrides);
// FIXME: this is wrong, it breaks shapeInfo immutability
arr.setShape(newShapeArr);
arr.setStride(newStrideArr);
} else {
if (numNewAxes > 0) {
long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape());
long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride());
arr.setShape(newShape);
arr.setStride(newStrides);
}
}
}
/** /**
* Compute the offset given an array of offsets. * Compute the offset given an array of offsets.
* The offset is computed(for both fortran an d c ordering) as: * The offset is computed(for both fortran an d c ordering) as:

View File

@ -54,7 +54,7 @@ public class DeallocatorService {
deallocatorThreads = new Thread[numThreads]; deallocatorThreads = new Thread[numThreads];
queues = new ReferenceQueue[numThreads]; queues = new ReferenceQueue[numThreads];
for (int e = 0; e < numThreads; e++) { for (int e = 0; e < numThreads; e++) {
log.debug("Starting deallocator thread {}", e + 1); log.trace("Starting deallocator thread {}", e + 1);
queues[e] = new ReferenceQueue<>(); queues[e] = new ReferenceQueue<>();
int deviceId = e % numDevices; int deviceId = e % numDevices;

View File

@ -1151,4 +1151,6 @@ public interface NativeOps {
int lastErrorCode(); int lastErrorCode();
String lastErrorMessage(); String lastErrorMessage();
boolean isBlasVersionMatches(int major, int minor, int build);
} }

View File

@ -101,7 +101,7 @@ public class NativeOpsHolder {
} }
//deviceNativeOps.setOmpNumThreads(4); //deviceNativeOps.setOmpNumThreads(4);
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads()); log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
} catch (Exception | Error e) { } catch (Exception | Error e) {
throw new RuntimeException( throw new RuntimeException(
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html", "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",

View File

@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors()); numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
setMaxThreads(numThreads); setMaxThreads(numThreads);
} }
log.info("Number of threads used for BLAS: {}", getMaxThreads());
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
} }
} }

View File

@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
throw new RuntimeException("No CUDA devices were found in system"); throw new RuntimeException("No CUDA devices were found in system");
} }
Loader.load(org.bytedeco.cuda.global.cublas.class); Loader.load(org.bytedeco.cuda.global.cublas.class);
return true; return true;
} }

View File

@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage());
/*
val major = new int[1];
val minor = new int[1];
val build = new int[1];
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
val pew = new int[100];
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
*/
} }
@Override @Override

View File

@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.blas.impl.BaseLevel3; import org.nd4j.linalg.api.blas.impl.BaseLevel3;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.factory.DataTypeValidation; import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer; import org.nd4j.linalg.jcublas.CublasPointer;
@ -113,8 +115,13 @@ public class JcublasLevel3 extends BaseLevel3 {
@Override @Override
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
INDArray B, int ldb, float beta, INDArray C, int ldc) { INDArray B, int ldb, float beta, INDArray C, int ldc) {
//A = Shape.toOffsetZero(A); /*
//B = Shape.toOffsetZero(B); val ctx = AtomicAllocator.getInstance().getDeviceContext();
val handle = ctx.getCublasHandle();
synchronized (handle) {
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
}
*/
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
@ -141,6 +148,7 @@ public class JcublasLevel3 extends BaseLevel3 {
} }
allocator.registerAction(ctx, C, A, B); allocator.registerAction(ctx, C, A, B);
OpExecutionerUtil.checkForAny(C); OpExecutionerUtil.checkForAny(C);
} }

View File

@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); } public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance(); public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose(); public native @Cast("bool") boolean isVerbose();
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads); public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/** /**
* *

View File

@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); } public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance(); public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose(); public native @Cast("bool") boolean isVerbose();
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads); public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/** /**
* *
@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_and)
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_and(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_and position(long position) {
return (bitwise_and)super.position(position);
}
public bitwise_and() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_or)
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_or(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_or position(long position) {
return (bitwise_or)super.position(position);
}
public bitwise_or() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_xor)
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_xor(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_xor position(long position) {
return (bitwise_xor)super.position(position);
}
public bitwise_xor() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/** /**
* This operation returns hamming distance based on bits * This operation returns hamming distance based on bits
* *

View File

@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
INDArray input = Nd4j.create(inSize); INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
.arrayInput(input)
.config(conf)
.build();
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
//Test backprop: //Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
.arrayInputs(new INDArray[]{input, grad})
.config(conf)
.build();
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size()); assertEquals(1, outSizesBP.size());
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build(); .build();
INDArray input = Nd4j.create(inSize); INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder() AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
.arrayInput(input)
.config(conf)
.build();
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
assertEquals(1, outSizes.size()); assertEquals(1, outSizes.size());
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
INDArray grad = Nd4j.create(exp); INDArray grad = Nd4j.create(exp);
//Test backprop: //Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
.config(conf)
.build();
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size()); assertEquals(1, outSizesBP.size());
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false) .isSameMode(false)
.build(); .build();
SDVariable out = sd.cnn().conv2d(vars, c); SDVariable out = sd.cnn().conv2d("conv", vars, c);
out = sd.nn().tanh("out", out); out = sd.nn().tanh("out", out);
INDArray outArr = sd.execAndEndResult(); INDArray outArr = sd.execAndEndResult();