Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
AlexDBlack 2019-09-05 00:53:49 +10:00
commit b7226bdd7a
91 changed files with 1622 additions and 864 deletions

View File

@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
break;
}
Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder()
.config(conf)
.arrayInputs(new INDArray[]{input, epsilon})
.arrayOutputs(new INDArray[]{gradAtInput})
.build();
Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf);
Nd4j.exec(d);
return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput);

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_BLASVERSIONHELPER_H
#define SAMEDIFF_BLASVERSIONHELPER_H
#include <dll.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace nd4j {
class ND4J_EXPORT BlasVersionHelper {
public:
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
BlasVersionHelper();
~BlasVersionHelper() = default;
};
}
#endif //DEV_TESTS_BLASVERSIONHELPER_H

View File

@ -253,20 +253,20 @@ if(CUDA_BLAS)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
else()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
endif()

View File

@ -35,7 +35,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include "BlasVersionHelper.h"
#endif
namespace nd4j {
@ -66,6 +66,13 @@ namespace nd4j {
#endif
#ifdef __CUDABLAS__
BlasVersionHelper ver;
_blasMajorVersion = ver._blasMajorVersion;
_blasMinorVersion = ver._blasMinorVersion;
_blasPatchVersion = ver._blasPatchVersion;
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
fflush(stdout);
int devCnt = 0;
cudaGetDeviceCount(&devCnt);
auto devProperties = new cudaDeviceProp[devCnt];

View File

@ -56,6 +56,13 @@ namespace nd4j{
Environment();
~Environment();
public:
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
static Environment* getInstance();
bool isVerbose();

View File

@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
ND4J_EXPORT void setOmpMinThreads(int threads);
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
/**
*

View File

@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index());
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
variables_vector.push_back(fv);
arrays++;

View File

@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
}
}
bool isBlasVersionMatches(int major, int minor, int build) {
return true;
}
/**
*
* @param opNum

View File

@ -0,0 +1,29 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../BlasVersionHelper.h"
namespace nd4j {
BlasVersionHelper::BlasVersionHelper() {
_blasMajorVersion = __CUDACC_VER_MAJOR__;
_blasMinorVersion = __CUDACC_VER_MINOR__;
_blasPatchVersion = __CUDACC_VER_BUILD__;
}
}

View File

@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
delete ptr;
}
bool isBlasVersionMatches(int major, int minor, int build) {
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
if (!result) {
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
}
return result;
}
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
}

View File

@ -38,7 +38,7 @@ namespace nd4j {
public:
static int asInt(DataType type);
static DataType fromInt(int dtype);
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
static DataType fromFlatDataType(nd4j::graph::DType dtype);
FORCEINLINE static std::string asString(DataType dataType);
template <typename T>

View File

@ -27,7 +27,7 @@ namespace nd4j {
return (DataType) val;
}
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
return (DataType) dtype;
}

View File

@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
return EnumNamesByteOrder()[index];
}
enum DataType {
DataType_INHERIT = 0,
DataType_BOOL = 1,
DataType_FLOAT8 = 2,
DataType_HALF = 3,
DataType_HALF2 = 4,
DataType_FLOAT = 5,
DataType_DOUBLE = 6,
DataType_INT8 = 7,
DataType_INT16 = 8,
DataType_INT32 = 9,
DataType_INT64 = 10,
DataType_UINT8 = 11,
DataType_UINT16 = 12,
DataType_UINT32 = 13,
DataType_UINT64 = 14,
DataType_QINT8 = 15,
DataType_QINT16 = 16,
DataType_BFLOAT16 = 17,
DataType_UTF8 = 50,
DataType_MIN = DataType_INHERIT,
DataType_MAX = DataType_UTF8
enum DType {
DType_INHERIT = 0,
DType_BOOL = 1,
DType_FLOAT8 = 2,
DType_HALF = 3,
DType_HALF2 = 4,
DType_FLOAT = 5,
DType_DOUBLE = 6,
DType_INT8 = 7,
DType_INT16 = 8,
DType_INT32 = 9,
DType_INT64 = 10,
DType_UINT8 = 11,
DType_UINT16 = 12,
DType_UINT32 = 13,
DType_UINT64 = 14,
DType_QINT8 = 15,
DType_QINT16 = 16,
DType_BFLOAT16 = 17,
DType_UTF8 = 50,
DType_MIN = DType_INHERIT,
DType_MAX = DType_UTF8
};
inline const DataType (&EnumValuesDataType())[19] {
static const DataType values[] = {
DataType_INHERIT,
DataType_BOOL,
DataType_FLOAT8,
DataType_HALF,
DataType_HALF2,
DataType_FLOAT,
DataType_DOUBLE,
DataType_INT8,
DataType_INT16,
DataType_INT32,
DataType_INT64,
DataType_UINT8,
DataType_UINT16,
DataType_UINT32,
DataType_UINT64,
DataType_QINT8,
DataType_QINT16,
DataType_BFLOAT16,
DataType_UTF8
inline const DType (&EnumValuesDType())[19] {
static const DType values[] = {
DType_INHERIT,
DType_BOOL,
DType_FLOAT8,
DType_HALF,
DType_HALF2,
DType_FLOAT,
DType_DOUBLE,
DType_INT8,
DType_INT16,
DType_INT32,
DType_INT64,
DType_UINT8,
DType_UINT16,
DType_UINT32,
DType_UINT64,
DType_QINT8,
DType_QINT16,
DType_BFLOAT16,
DType_UTF8
};
return values;
}
inline const char * const *EnumNamesDataType() {
inline const char * const *EnumNamesDType() {
static const char * const names[] = {
"INHERIT",
"BOOL",
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
return names;
}
inline const char *EnumNameDataType(DataType e) {
inline const char *EnumNameDType(DType e) {
const size_t index = static_cast<int>(e);
return EnumNamesDataType()[index];
return EnumNamesDType()[index];
}
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<int8_t> *buffer() const {
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
}
DataType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
DType dtype() const {
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
}
ByteOrder byteOrder() const {
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
}
void add_dtype(DataType dtype) {
void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
}
void add_byteOrder(ByteOrder byteOrder) {
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) {
FlatArrayBuilder builder_(_fbb);
builder_.add_buffer(buffer);
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<int64_t> *shape = nullptr,
const std::vector<int8_t> *buffer = nullptr,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
ByteOrder byteOrder = ByteOrder_LE) {
return nd4j::graph::CreateFlatArray(
_fbb,

View File

@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
/**
* @enum
*/
nd4j.graph.DataType = {
nd4j.graph.DType = {
INHERIT: 0,
BOOL: 1,
FLOAT8: 2,
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
};
/**
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatArray.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
};
/**
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
/**
* @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype
* @param {nd4j.graph.DType} dtype
*/
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
};
/**

View File

@ -5,7 +5,7 @@
namespace nd4j.graph
{
public enum DataType : sbyte
public enum DType : sbyte
{
INHERIT = 0,
BOOL = 1,

View File

@ -2,8 +2,8 @@
package nd4j.graph;
public final class DataType {
private DataType() { }
public final class DType {
private DType() { }
public static final byte INHERIT = 0;
public static final byte BOOL = 1;
public static final byte FLOAT8 = 2;

View File

@ -2,7 +2,7 @@
# namespace: graph
class DataType(object):
class DType(object):
INHERIT = 0
BOOL = 1
FLOAT8 = 2

View File

@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
#endif
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
VectorOffset shapeOffset = default(VectorOffset),
VectorOffset bufferOffset = default(VectorOffset),
DataType dtype = DataType.INHERIT,
DType dtype = DType.INHERIT,
ByteOrder byteOrder = ByteOrder.LE) {
builder.StartObject(4);
FlatArray.AddBuffer(builder, bufferOffset);
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
int o = builder.EndObject();

View File

@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
#endif
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
#else
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
#endif
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {

View File

@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
#endif
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
Offset<IntPair> idOffset = default(Offset<IntPair>),
StringOffset nameOffset = default(StringOffset),
DataType dtype = DataType.INHERIT,
DType dtype = DType.INHERIT,
VectorOffset shapeOffset = default(VectorOffset),
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
int device = 0,
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
/**
* @param {number} index
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
var offset = this.bb.__offset(this.bb_pos, 38);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
};
/**
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<nd4j.graph.DataType>} data
* @param {Array.<nd4j.graph.DType>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {

View File

@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::String *name() const {
return GetPointer<const flatbuffers::String *>(VT_NAME);
}
DataType dtype() const {
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
DType dtype() const {
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
}
const flatbuffers::Vector<int64_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
fbb_.AddOffset(FlatVariable::VT_NAME, name);
}
void add_dtype(DataType dtype) {
void add_dtype(DType dtype) {
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
}
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0,
flatbuffers::Offset<flatbuffers::String> name = 0,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
flatbuffers::FlatBufferBuilder &_fbb,
flatbuffers::Offset<IntPair> id = 0,
const char *name = nullptr,
DataType dtype = DataType_INHERIT,
DType dtype = DType_INHERIT,
const std::vector<int64_t> *shape = nullptr,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,

View File

@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
};
/**
* @returns {nd4j.graph.DataType}
* @returns {nd4j.graph.DType}
*/
nd4j.graph.FlatVariable.prototype.dtype = function() {
var offset = this.bb.__offset(this.bb_pos, 8);
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
};
/**
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
/**
* @param {flatbuffers.Builder} builder
* @param {nd4j.graph.DataType} dtype
* @param {nd4j.graph.DType} dtype
*/
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
};
/**

View File

@ -111,7 +111,7 @@ namespace nd4j {
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
}
}
}

View File

@ -219,7 +219,7 @@ namespace nd4j {
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
auto ar = flatVariable->ndarray();
if (ar->dtype() == DataType_UTF8) {
if (ar->dtype() == DType_UTF8) {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
} else {
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
@ -320,7 +320,7 @@ namespace nd4j {
auto fBuffer = builder.CreateVector(array->asByteVector());
// packing array
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
// packing id/index of this var
auto fVid = CreateIntPair(builder, this->_id, this->_index);
@ -331,7 +331,7 @@ namespace nd4j {
stringId = builder.CreateString(this->_name);
// returning array
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
} else {
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
}

View File

@ -23,7 +23,7 @@ enum ByteOrder:byte {
}
// DataType for arrays/buffers
enum DataType:byte {
enum DType:byte {
INHERIT,
BOOL,
FLOAT8,
@ -49,7 +49,7 @@ enum DataType:byte {
table FlatArray {
shape:[long]; // shape in Nd4j format
buffer:[byte]; // byte buffer with data
dtype:DataType; // data type of actual data within buffer
dtype:DType; // data type of actual data within buffer
byteOrder:ByteOrder; // byte order of buffer
}

View File

@ -48,7 +48,7 @@ table FlatNode {
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
// output data types (optional)
outputTypes:[DataType];
outputTypes:[DType];
//Scalar value - used for scalar ops. Should be single value only.
scalar:FlatArray;

View File

@ -51,7 +51,7 @@ table UIVariable {
id:IntPair; //Existing IntPair class
name:string;
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
datatype:DataType;
datatype:DType;
shape:[long];
controlDeps:[string]; //Input control dependencies: variable x -> this
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of

View File

@ -30,7 +30,7 @@ enum VarType:byte {
table FlatVariable {
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
name:string; // symbolic ID of the Variable (if defined)
dtype:DataType;
dtype:DType;
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
ndarray:FlatArray;

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_and)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_and) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_or)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_or) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_xor)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_xor) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -29,21 +29,26 @@ namespace ops {
CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) {
int numOfData = block.width();
// int k = 0;
// checking input data size
REQUIRE_TRUE(numOfData % 2 == 0, 0,
"dynamic_stitch: The input params should contains"
" both indeces and data lists with same length.");
// split input data list on two equal parts
numOfData /= 2;
// form input lists to use with helpers - both indices and float data inputs
auto output = OUTPUT_VARIABLE(0);
std::vector<NDArray*> inputs(numOfData);
std::vector<NDArray*> indices(numOfData);
for (int e = 0; e < numOfData; e++) {
auto data = INPUT_VARIABLE(numOfData + e);
auto index = INPUT_VARIABLE(e);
inputs[e] = data;
indices[e] = index;
}
// run helper
return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output);
}
@ -59,17 +64,17 @@ namespace ops {
numOfData /= 2; // only index part it's needed to review
auto restShape = inputShape->at(numOfData);
auto firstShape = inputShape->at(0);
// check up inputs to avoid non-int indices and calculate max value from indices to output shape length
for(int i = 0; i < numOfData; i++) {
auto input = INPUT_VARIABLE(i);
REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() );
// FIXME: we have reduce::Max, cinsider using it instead
auto maxV = input->reduceNumber(reduce::Max);
if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0);
}
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1;
// calculate output rank - difference between indices shape and data shape
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor
std::vector<Nd4jLong> outShape(outRank);
// fill up output shape template: the first to max index, and rests - to vals from the first data input
outShape[0] = maxValue + 1;
for(int i = 1; i < outRank; ++i)
outShape[i] = shape::sizeAt(restShape, i);

View File

@ -33,12 +33,13 @@ namespace nd4j {
* 0: 1D row-vector (or with shape (1, m))
* 1: 1D integer vector with slice nums
* 2: 1D float-point values vector with same shape as above
* 3: 2D float-point matrix with data to search
*
* Int args:
* 0: N - number of slices
*
* Output:
* 0: 1D vector with edge forces for input and values
* 0: 2D matrix with the same shape and type as the 3th argument
*/
#if NOT_EXCLUDED(OP_barnes_edge_forces)
DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1);
@ -52,9 +53,11 @@ namespace nd4j {
* 0: 1D int row-vector
* 1: 1D int col-vector
* 2: 1D float vector with values
*
*
* Output:
* 0: symmetric 2D matrix with given values on given places
* 0: 1D int result row-vector
* 1: 1D int result col-vector
* 2: a float-point tensor with shape 1xN, with values from the last input vector
*/
#if NOT_EXCLUDED(OP_barnes_symmetrized)
DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1);

View File

@ -81,6 +81,39 @@ namespace nd4j {
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
#endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_and)
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
#endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_or)
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
#endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_xor)
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
#endif
/**
* This operation returns hamming distance based on bits
*

View File

@ -120,7 +120,7 @@ namespace nd4j {
#endif
/**
* This operation unstacks given NDArray into NDArrayList
* This operation unstacks given NDArray into NDArrayList by the first dimension
*/
#if NOT_EXCLUDED(OP_unstack_list)
DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0);

View File

@ -594,21 +594,46 @@ namespace nd4j {
/**
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
* of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension
* are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input
* block size and how the data is moved.
* Input:
* 0 - 4D tensor on given type
* Output:
* 0 - 4D tensor of given type and proper shape
*
*
*
* Int arguments:
* 0 - block size
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
* 1 ("NCHW"): shape{ batch, channels, height, width }
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
* optional (default 0)
*/
#if NOT_EXCLUDED(OP_depth_to_space)
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2);
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1);
#endif
/**
* This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor
* where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates
* the input block size.
*
* Input:
* - 4D tensor of given type
* Output:
* - 4D tensor
*
* Int arguments:
* 0 - block size
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
* 1 ("NCHW"): shape{ batch, channels, height, width }
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
* optional (default 0)
*
*/
#if NOT_EXCLUDED(OP_space_to_depth)
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2);
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1);
#endif
/**
@ -622,13 +647,42 @@ namespace nd4j {
#endif
/**
* Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op
* outputs a copy of the input tensor where values from the height and width dimensions are moved to the
* batch dimension. After the zero-padding, both height and width of the input must be divisible by the block
* size.
*
* Inputs:
* 0 - input tensor
* 1 - 2D paddings tensor (shape {M, 2})
*
* Output:
* - result tensor
*
* Int args:
* 0 - block size (M)
*
*/
#if NOT_EXCLUDED(OP_space_to_batch)
DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
#endif
/*
* This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape
* block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output,
* the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension
* combines both the position within a spatial block and the original batch position. Prior to division into
* blocks, the spatial dimensions of the input are optionally zero padded according to paddings.
*
* Inputs:
* 0 - input (N-D tensor)
* 1 - block_shape - int 1D tensor with M length
* 2 - paddings - int 2D tensor with shape {M, 2}
*
* Output:
* - N-D tensor with the same type as input 0.
*
* */
#if NOT_EXCLUDED(OP_space_to_batch_nd)
DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0);
#endif
@ -973,7 +1027,7 @@ namespace nd4j {
* return value:
* tensor with min values according to indices sets.
*/
#if NOT_EXCLUDED(OP_segment_min_bp)
#if NOT_EXCLUDED(OP_segment_min)
DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0);
#endif
#if NOT_EXCLUDED(OP_segment_min_bp)

View File

@ -118,19 +118,19 @@ namespace nd4j {
PointersManager pm(context, "dynamicPartition");
if (sourceDimsLen) {
if (sourceDimsLen) { // non-linear case
std::vector<int> sourceDims(sourceDimsLen);
for (int i = sourceDimsLen; i > 0; i--)
sourceDims[sourceDimsLen - i] = input->rankOf() - i;
//compute tad array for given dimensions
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims);
std::vector<void *> outBuffers(outSize);
std::vector<Nd4jLong *> tadShapes(outSize);
std::vector<Nd4jLong *> tadOffsets(outSize);
std::vector<Nd4jLong> numTads(outSize);
// fill up dimensions array for before kernel
for (unsigned int i = 0; i < outSize; i++) {
outputs[i].first = outputList[i];
std::vector<int> outDims(outputs[i].first->rankOf() - 1);
@ -151,10 +151,10 @@ namespace nd4j {
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
// run kernel on device
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
} else {
} else { // linear case
auto numThreads = 256;
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
@ -169,7 +169,6 @@ namespace nd4j {
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *)));
dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize);
}

View File

@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE);
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {10}, {2});
@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer();
// result->printIndexedBuffer("Result2");
// exp.printIndexedBuffer("Expect2");
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));

View File

@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto fBuffer = builder.CreateVector(array->asByteVector());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto fVid = CreateIntPair(builder, -1);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
std::vector<int> outputs1, outputs2, inputs1, inputs2;
outputs1.push_back(2);
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
auto name1 = builder.CreateString("wow1");
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
variables_vector.push_back(fXVar);

View File

@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
builder.Finish(flatVar);
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar);
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
auto fBuffer = builder.CreateVector(vec);
auto fVid = CreateIntPair(builder, 1, 12);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
builder.Finish(flatVar);
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
auto fVid = CreateIntPair(builder, 37, 12);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
builder.Finish(flatVar);

View File

@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
LocalResponseNormalization lrn = LocalResponseNormalization.builder()
LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input})
.sameDiff(sameDiff())
.config(lrnConfig)
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
Conv1D conv1D = Conv1D.builder()
Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(new SDVariable[]{input, weights})
.sameDiff(sameDiff())
.config(conv1DConfig)
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
return conv1D.outputVariable();
}
/**
* Conv1d operation.
*
* @param input the inputs to conv1d
* @param weights conv1d weights
* @param bias conv1d bias
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
SDVariable[] args;
if(bias == null){
args = new SDVariable[]{input, weights};
} else {
args = new SDVariable[]{input, weights, bias};
}
Conv1D conv1D = Conv1D.sameDiffBuilder()
.inputFunctions(args)
.sameDiff(sameDiff())
.config(conv1DConfig)
.build();
return conv1D.outputVariable();
}
/**
* Conv2d operation.
*
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
Conv2D conv2D = Conv2D.builder()
Conv2D conv2D = Conv2D.sameDiffBuilder()
.inputFunctions(inputs)
.sameDiff(sameDiff())
.config(conv2DConfig)
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
.input(input)
.sameDiff(sameDiff())
.config(pooling2DConfig)
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
MaxPooling2D maxPooling2D = MaxPooling2D.builder()
MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
.input(input)
.sameDiff(sameDiff())
.config(pooling2DConfig)
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
SConv2D sconv2D = SConv2D.sBuilder()
SConv2D sconv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs)
.sameDiff(sameDiff())
.conv2DConfig(conv2DConfig)
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
SConv2D depthWiseConv2D = SConv2D.sBuilder()
SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
.inputFunctions(inputs)
.sameDiff(sameDiff())
.conv2DConfig(depthConv2DConfig)
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
DeConv2D deconv2D = DeConv2D.builder()
DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
.inputs(inputs)
.sameDiff(sameDiff())
.config(deconv2DConfig)
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
* @return
*/
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
Conv3D conv3D = Conv3D.builder()
Conv3D conv3D = Conv3D.sameDiffBuilder()
.inputFunctions(inputs)
.conv3DConfig(conv3DConfig)
.config(conv3DConfig)
.sameDiff(sameDiff())
.build();
@ -1260,6 +1288,22 @@ public class DifferentialFunctionFactory {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
return new BitwiseOr(sameDiff(), x, y).outputVariable();
}
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
return new BitwiseXor(sameDiff(), x, y).outputVariable();
}
public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
}

View File

@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
*/
public final SDImage image = new SDImage(this);
/**
* Op creator object for bitwise operations
*/
public final SDBitwise bitwise = new SDBitwise(this);
/**
* Op creator object for math operations
*/
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
return image;
}
/**
* Op creator object for bitwise operations
*/
public SDBitwise bitwise(){
return bitwise;
}
/**
* For import, many times we have variables

View File

@ -0,0 +1,205 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
/**
*
*/
public class SDBitwise extends SDOps {
public SDBitwise(SameDiff sameDiff) {
super(sameDiff);
}
/**
* See {@link #leftShift(String, SDVariable, SDVariable)}
*/
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
return leftShift(null, x, y);
}
/**
* Bitwise left shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift", x);
validateInteger("bitwise left shift", y);
SDVariable ret = f().shift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShift(String, SDVariable, SDVariable)}
*/
public SDVariable rightShift(SDVariable x, SDVariable y){
return rightShift(null, x, y);
}
/**
* Bitwise right shift operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise shifted input x
*/
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift", x);
validateInteger("bitwise right shift", y);
SDVariable ret = f().rshift(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
return leftShiftCyclic(null, x, y);
}
/**
* Bitwise left cyclical shift operation. Supports broadcasting.
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise left shift (cyclic)", x);
validateInteger("bitwise left shift (cyclic)", y);
SDVariable ret = f().rotl(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
*/
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
return rightShiftCyclic(null, x, y);
}
/**
* Bitwise right cyclical shift operation. Supports broadcasting.
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
*
* @param name Name of the output variable. May be null.
* @param x Input to be bit shifted (must be an integer type)
* @param y Amount to shift elements of x array (must be an integer type)
* @return Bitwise cyclic shifted input x
*/
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
validateInteger("bitwise right shift (cyclic)", x);
validateInteger("bitwise right shift (cyclic)", y);
SDVariable ret = f().rotr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
*/
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
return bitsHammingDistance(null, x, y);
}
/**
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return
*/
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
validateInteger("bitwise hamming distance", x);
validateInteger("bitwise hamming distance", y);
SDVariable ret = f().bitwiseHammingDist(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #and(String, SDVariable, SDVariable)}
*/
public SDVariable and(SDVariable x, SDVariable y){
return and(null, x, y);
}
/**
* Bitwise AND operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise AND array
*/
public SDVariable and(String name, SDVariable x, SDVariable y){
validateInteger("bitwise AND", x);
validateInteger("bitwise AND", y);
SDVariable ret = f().bitwiseAnd(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #or(String, SDVariable, SDVariable)}
*/
public SDVariable or(SDVariable x, SDVariable y){
return or(null, x, y);
}
/**
* Bitwise OR operation. Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise OR array
*/
public SDVariable or(String name, SDVariable x, SDVariable y){
validateInteger("bitwise OR", x);
validateInteger("bitwise OR", y);
SDVariable ret = f().bitwiseOr(x, y);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #xor(String, SDVariable, SDVariable)}
*/
public SDVariable xor(SDVariable x, SDVariable y){
return xor(null, x, y);
}
/**
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
*
* @param name Name of the output variable. May be null.
* @param x First input array. Must be integer type.
* @param y First input array. Must be integer type, same type as x
* @return Bitwise XOR array
*/
public SDVariable xor(String name, SDVariable x, SDVariable y){
validateInteger("bitwise XOR", x);
validateInteger("bitwise XOR", y);
SDVariable ret = f().bitwiseXor(x, y);
return updateVariableNameAndReference(ret, name);
}
}

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
}
/**
* 2D Convolution layer operation - average pooling 2d
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration for
* @return Result after applying average pooling on the input
* See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
*/
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return avgPooling2d(null, input, pooling2DConfig);
}
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration
* @return Result after applying average pooling on the input
*/
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateFloatingPoint("avgPooling2d", input);
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* 3D convolution layer operation - average pooling 3d
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input
* See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
*/
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return avgPooling3d(null, input, pooling3DConfig);
}
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration
* @return Result after applying average pooling on the input
*/
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateFloatingPoint("avgPooling3d", input);
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name);
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
/**
* @see #batchToSpace(String, SDVariable, int[], int[][])
*/
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) {
public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
return batchToSpace(null, x, blocks, crops);
}
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
* @return Output variable
* @see #spaceToBatch(String, SDVariable, int[], int[][])
*/
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) {
public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
validateNumerical("batchToSpace", x);
SDVariable ret = f().batchToSpace(x, blocks, crops);
return updateVariableNameAndReference(ret, name);
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
/**
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
* [minibatch, inputChannels, height, width]
*
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
* @param config Convolution configuration for the col2im operation
* @return Col2Im output variable
* See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
*/
public SDVariable col2Im(SDVariable in, Conv2DConfig config) {
public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return col2Im(null, in, config);
}
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the col2im operation
* @return Col2Im output variable
*/
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) {
public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().col2Im(in, config);
return updateVariableNameAndReference(ret, name);
}
/**
* 1D Convolution layer operation - Conv1d
*
* @param input the input array/activations for the conv1d op
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*/
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, conv1DConfig);
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
return conv1d((String) null, input, weights, conv1DConfig);
}
/**
* Conv1d operation.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
* @param conv1DConfig the configuration
* @return
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
*/
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights);
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
}
/**
* 2D Convolution operation (without bias)
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
* @param config Conv2DConfig configuration
* @return result of conv2d op
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
*/
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) {
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
return conv1d(null, input, weights, bias, conv1DConfig);
}
/**
* Conv1d operation.
*
* @param name name of the operation in SameDiff
* @param input the inputs to conv1d
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
* @param conv1DConfig the configuration
* @return
*/
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
validateFloatingPoint("conv1d", input);
validateFloatingPoint("conv1d", weights);
validateFloatingPoint("conv1d", bias);
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(layerInput, weights, null, config);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
return conv2d(name, layerInput, weights, null, config);
}
/**
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
return conv2d(null, layerInput, weights, bias, config);
}
/**
* 2D Convolution operation with optional bias
*
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration
* @return result of conv2d op
*/
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) {
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("conv2d", "input", layerInput);
validateFloatingPoint("conv2d", "weights", weights);
validateFloatingPoint("conv2d", "bias", bias);
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
arr[1] = weights;
if (bias != null)
arr[2] = bias;
return conv2d(arr, config);
return conv2d(name, arr, config);
}
/**
* 2D Convolution operation with optional bias
*
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param config Conv2DConfig configuration
* @return result of convolution 2d operation
* See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
*/
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
return conv2d(null, inputs, config);
}
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration
* @return result of convolution 2d operation
*/
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
for(SDVariable v : inputs)
validateNumerical("conv2d", v);
SDVariable ret = f().conv2d(inputs, config);
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
}
/**
* Convolution 3D operation without bias
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*/
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, null, conv3DConfig);
}
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
*/
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
*/
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/**
* Convolution 3D operation with optional bias
*
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
validateFloatingPoint("conv3d", "input", input);
validateFloatingPoint("conv3d", "weights", weights);
validateFloatingPoint("conv3d", "bias", bias);
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
}
/**
* Convolution 3D operation with optional bias
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
* @param conv3DConfig the configuration
* @return Conv3d output variable
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*/
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
return conv3d(null, input, weights, bias, conv3DConfig);
}
/**
* Convolution 3D operation without bias
*
* @param name Name of the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
* @param conv3DConfig the configuration
* @return Conv3d output variable
*/
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
return conv3d(name, input, weights, null, conv3DConfig);
}
/**
* 2D deconvolution operation without bias
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
* @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op
*/
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(layerInput, weights, null, deconv2DConfig);
}
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
*/
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
}
/**
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
*/
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
}
/**
* 2D deconvolution operation with optional bias
*
* @param name name of the operation in SameDiff
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig DeConv2DConfig configuration
* @return result of deconv2d op
*/
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) {
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
validateFloatingPoint("deconv2d", "input", layerInput);
validateFloatingPoint("deconv2d", "weights", weights);
validateFloatingPoint("deconv2d", "bias", bias);
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
arr[1] = weights;
if (bias != null)
arr[2] = bias;
return deconv2d(arr, deconv2DConfig);
return deconv2d(name, arr, deconv2DConfig);
}
/**
* 2D deconvolution operation with or without optional bias
*
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
* @param deconv2DConfig the configuration
* @return result of deconv2d op
* See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
*/
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
return deconv2d(null, inputs, deconv2DConfig);
}
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
* @param deconv2DConfig the configuration
* @return result of deconv2d op
*/
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
for(SDVariable v : inputs)
validateNumerical("deconv2d", v);
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
*/
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
return deconv3d(name, input, weights, null, config);
}
/**
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
*/
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
return deconv3d(null, input, weights, bias, config);
}
/**
* 3D CNN deconvolution operation with or without optional bias
*
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration
*/
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
validateFloatingPoint("conv3d", input);
validateFloatingPoint("conv3d", weights);
validateFloatingPoint("conv3d", bias);
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
}
/**
* 3D CNN deconvolution operation with or without optional bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
* @param config Configuration
* See {@link #depthToSpace(String, SDVariable, int, String)}.
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
return deconv3d(null, input, weights, bias, config);
}
/**
* 3D CNN deconvolution operation with no bias
*
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
* @param config Configuration
*/
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
return deconv3d(input, weights, null, config);
}
/**
* Convolution 2d layer batch to space operation on 4d input.<br>
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
* = [mb, 2, 4, 4]
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return Output variable
*/
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
return depthToSpace(null, x, blockSize, dataFormat);
}
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
* @return Output variable
* @see #depthToSpace(String, SDVariable, int, String)
*/
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) {
public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name);
}
/**
* Depth-wise 2D convolution operation without bias
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param config Conv2DConfig configuration
* @return result of conv2d op
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) {
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(layerInput, depthWeights, null, config);
}
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
}
/**
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
}
/**
* Depth-wise 2D convolution operation with optional bias
*
* @param name name of the operation in SameDiff
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration
* @return result of depthwise conv2d op
*/
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) {
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
validateFloatingPoint("depthwiseConv2d", "bias", bias);
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
arr[1] = depthWeights;
if (bias != null)
arr[2] = bias;
return depthWiseConv2d(arr, config);
return depthWiseConv2d(name, arr, config);
}
/**
* Depth-wise convolution 2D operation.
*
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
* or 3 elements (layerInput, depthWeights, bias) as described in
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op
* See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
*/
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
return depthWiseConv2d(null, inputs, depthConv2DConfig);
}
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
* @param depthConv2DConfig the configuration
* @return result of depthwise conv2d op
*/
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
for(SDVariable v : inputs)
validateFloatingPoint("depthWiseConv2d", v);
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
}
/**
* TODO doc string
*
* @param df
* @param weights
* @param strides
* @param rates
* @param isSameMode
* @return
* See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
*/
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
int[] rates, boolean isSameMode) {
public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
@NonNull int[] rates, @NonNull boolean isSameMode) {
return dilation2D(null, df, weights, strides, rates, isSameMode);
}
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
* @param isSameMode
* @return
*/
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
int[] rates, boolean isSameMode) {
public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
@NonNull int[] rates, @NonNull boolean isSameMode) {
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
return updateVariableNameAndReference(ret, name);
}
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
* @param sameMode If true: use same mode padding. If false
* @return
*/
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
return updateVariableNameAndReference(ret, name);
}
/**
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
*
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
* @param config Convolution configuration for the im2col operation
* @return Im2Col output variable
* See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
*/
public SDVariable im2Col(SDVariable in, Conv2DConfig config) {
public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
return im2Col(null, in, config);
}
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
* @param config Convolution configuration for the im2col operation
* @return Im2Col output variable
*/
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) {
public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
SDVariable ret = f().im2Col(in, config);
return updateVariableNameAndReference(ret, name);
}
/**
* 2D convolution layer operation - local response normalization
*
* @param inputs the inputs to lrn
* @param lrnConfig the configuration
* @return
* See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
*/
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
return localResponseNormalization(null, inputs, lrnConfig);
}
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
* @param lrnConfig the configuration
* @return
*/
public SDVariable localResponseNormalization(String name, SDVariable input,
LocalResponseNormalizationConfig lrnConfig) {
public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
@NonNull LocalResponseNormalizationConfig lrnConfig) {
validateFloatingPoint("local response normalization", input);
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
return updateVariableNameAndReference(ret, name);
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
/**
* 2D Convolution layer operation - max pooling 2d
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input
* See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
*/
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
return maxPooling2d(null, input, pooling2DConfig);
}
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
* @param pooling2DConfig the configuration
* @return Result after applying max pooling on the input
*/
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
validateNumerical("maxPooling2d", input);
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
return updateVariableNameAndReference(ret, name);
}
/**
* 3D convolution layer operation - max pooling 3d operation.
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels])
* @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input
* See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
*/
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
return maxPooling3d(null, input, pooling3DConfig);
}
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
* @param pooling3DConfig the configuration
* @return Result after applying max pooling on the input
*/
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
validateNumerical("maxPooling3d", input);
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
return updateVariableNameAndReference(ret, name);
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
/**
* Separable 2D convolution operation without bias
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
* May be null
* @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
Conv2DConfig config) {
public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
@NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
}
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
*/
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
@NonNull Conv2DConfig config) {
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
}
/**
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
*/
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, @NonNull Conv2DConfig config) {
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
}
/**
* Separable 2D convolution operation with optional bias
*
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
* @param config Conv2DConfig configuration
* @return result of separable convolution 2d operation
*/
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, Conv2DConfig config) {
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
SDVariable bias, @NonNull Conv2DConfig config) {
validateFloatingPoint("separableConv2d", "input", layerInput);
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
arr[2] = pointWeights;
if (bias != null)
arr[3] = bias;
return sconv2d(arr, config);
return sconv2d(name, arr, config);
}
/**
* Separable 2D convolution operation with/without optional bias
*
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
* @param conv2DConfig the configuration
* @return result of separable convolution 2d operation
* See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
*/
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
return sconv2d(null, inputs, conv2DConfig);
}
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
* @param conv2DConfig the configuration
* @return result of separable convolution 2d operation
*/
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
for(SDVariable v : inputs)
validateFloatingPoint("sconv2d", v);
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
/**
* @see #spaceToBatch(String, SDVariable, int[], int[][])
*/
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) {
public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
return spaceToBatch(null, x, blocks, padding);
}
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
* @return Output variable
* @see #batchToSpace(String, SDVariable, int[], int[][])
*/
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) {
public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
SDVariable ret = f().spaceToBatch(x, blocks, padding);
return updateVariableNameAndReference(ret, name);
}
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
/**
* @see #spaceToDepth(String, SDVariable, int, String)
*/
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) {
public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
return spaceToDepth(null, x, blockSize, dataFormat);
}
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
* @return Output variable
* @see #depthToSpace(String, SDVariable, int, String)
*/
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) {
public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name);
}
/**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
* @param scale The scale for both height and width dimensions.
*/
public SDVariable upsampling2d(SDVariable input, int scale) {
public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
return upsampling2d(null, input, true, scale, scale);
}
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
* scale is used for both height and width dimensions.
*
* @param scale The scale for both height and width dimensions.
*/
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
*/
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
/**
* 2D Convolution layer operation - Upsampling 2d
*
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
* @param scaleW Scale to upsample in width dimension
* @return Upsampled input
*/
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) {
public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
return updateVariableNameAndReference(ret, name);
}
/**
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* @param scale Scale to upsample in both H and W dimensions
* @return Upsampled input
*/
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
return upsampling2d(name, input, true, scale, scale);
}
/**
* 2D Convolution layer operation - Upsampling 2d
*
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
* or NHWC format (shape [minibatch, height, width, channels])
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
* @param scaleH Scale to upsample in height dimension
* @param scaleW Scale to upsample in width dimension
* @return Upsampled input
*/
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
return upsampling2d(null, input, nchw, scaleH, scaleW);
}
}

View File

@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType;
import org.nd4j.graph.DType;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
switch (type) {
case FLOAT:
return DataType.FLOAT;
return DType.FLOAT;
case DOUBLE:
return DataType.DOUBLE;
return DType.DOUBLE;
case HALF:
return DataType.HALF;
return DType.HALF;
case INT:
return DataType.INT32;
return DType.INT32;
case LONG:
return DataType.INT64;
return DType.INT64;
case BOOL:
return DataType.BOOL;
return DType.BOOL;
case SHORT:
return DataType.INT16;
return DType.INT16;
case BYTE:
return DataType.INT8;
return DType.INT8;
case UBYTE:
return DataType.UINT8;
return DType.UINT8;
case UTF8:
return DataType.UTF8;
return DType.UTF8;
case UINT16:
return DataType.UINT16;
return DType.UINT16;
case UINT32:
return DataType.UINT32;
return DType.UINT32;
case UINT64:
return DataType.UINT64;
return DType.UINT64;
case BFLOAT16:
return DataType.BFLOAT16;
return DType.BFLOAT16;
default:
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
}
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
* This method converts enums for DataType
*/
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
if (val == DataType.FLOAT) {
if (val == DType.FLOAT) {
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
} else if (val == DataType.DOUBLE) {
} else if (val == DType.DOUBLE) {
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
} else if (val == DataType.HALF) {
} else if (val == DType.HALF) {
return org.nd4j.linalg.api.buffer.DataType.HALF;
} else if (val == DataType.INT32) {
} else if (val == DType.INT32) {
return org.nd4j.linalg.api.buffer.DataType.INT;
} else if (val == DataType.INT64) {
} else if (val == DType.INT64) {
return org.nd4j.linalg.api.buffer.DataType.LONG;
} else if (val == DataType.INT8) {
} else if (val == DType.INT8) {
return org.nd4j.linalg.api.buffer.DataType.BYTE;
} else if (val == DataType.BOOL) {
} else if (val == DType.BOOL) {
return org.nd4j.linalg.api.buffer.DataType.BOOL;
} else if (val == DataType.UINT8) {
} else if (val == DType.UINT8) {
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
} else if (val == DataType.INT16) {
} else if (val == DType.INT16) {
return org.nd4j.linalg.api.buffer.DataType.SHORT;
} else if (val == DataType.UTF8) {
} else if (val == DType.UTF8) {
return org.nd4j.linalg.api.buffer.DataType.UTF8;
} else if (val == DataType.UINT16) {
} else if (val == DType.UINT16) {
return org.nd4j.linalg.api.buffer.DataType.UINT16;
} else if (val == DataType.UINT32) {
} else if (val == DType.UINT32) {
return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) {
} else if (val == DType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else if (val == DataType.BFLOAT16){
} else if (val == DType.BFLOAT16){
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
} else {
throw new RuntimeException("Unknown datatype: " + val);

View File

@ -2,8 +2,8 @@
package org.nd4j.graph;
public final class DataType {
private DataType() { }
public final class DType {
private DType() { }
public static final byte INHERIT = 0;
public static final byte BOOL = 1;
public static final byte FLOAT8 = 2;

View File

@ -353,6 +353,10 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,

View File

@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty()));
}
@Override
public void setShape(long[] shape) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty()));
}
@Override
public void setStride(long[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty()));
}
@Override
public void setShapeAndStride(int[] shape, int[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return scalar.getDouble(0);
}
/**
* Returns entropy value for this INDArray
* @return
*/
@Override
public Number entropyNumber() {
return entropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
*/
@Override
public Number shannonEntropyNumber() {
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns log entropy value for this INDArray
* @return
*/
@Override
public Number logEntropyNumber() {
return logEntropy(Integer.MAX_VALUE).getDouble(0);
@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return size(0);
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
Nd4j.getCompressor().autoDecompress(this);
int n = shape.length;
// FIXME: shapeInfo should be used here
if (shape.length < 1)
return create(Nd4j.createBufferDetached(shape));
if (offsets.length != n)
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
if (stride.length != n)
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
if (Shape.contentEquals(shape, shapeOf())) {
if (ArrayUtil.isZero(offsets)) {
return this;
} else {
throw new IllegalArgumentException("Invalid subArray offsets");
}
}
long[] dotProductOffsets = offsets;
int[] dotProductStride = stride;
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
if (offset >= data().length())
offset = ArrayUtil.sumLong(offsets);
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering());
}
protected INDArray create(DataBuffer buffer) {
return Nd4j.create(buffer);
}
@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.getExecutioner().exec(new AMin(this, dimension));
}
/**
* Returns the sum along the specified dimension(s) of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override
public INDArray sum(int... dimension) {
validateNumericalArray("sum", true);
return Nd4j.getExecutioner().exec(new Sum(this, dimension));
}
/**
* Returns the sum along the last dimension of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override
public INDArray sum(boolean keepDim, int... dimension) {
validateNumericalArray("sum", true);
return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension));
}
/**
* Returns entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray entropy(int... dimension) {
validateNumericalArray("entropy", false);
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
}
/**
* Returns non-normalized Shannon entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray shannonEntropy(int... dimension) {
validateNumericalArray("shannonEntropy", false);
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
}
/**
* Returns log entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray logEntropy(int... dimension) {
validateNumericalArray("logEntropy", false);

View File

@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
throw new UnsupportedOperationException();
}
@Override
public void setStride(long... stride) {
throw new UnsupportedOperationException();
}
@Override
public void setShape(long... shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray putScalar(long row, long col, double value) {
return null;
@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
@Override
public void setShapeAndStride(int[] shape, int[] stride) {
}
@Override
public void setOrder(char order) {
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
return null;
}
@Override
@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
return null;
}
/**
* Returns entropy value for this INDArray
* @return
*/
@Override
public Number entropyNumber() {
return entropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
*/
@Override
public Number shannonEntropyNumber() {
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns log entropy value for this INDArray
* @return
*/
@Override
public Number logEntropyNumber() {
return logEntropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray entropy(int... dimension) {
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
}
/**
* Returns non-normalized Shannon entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray shannonEntropy(int... dimension) {
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));

View File

@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return extendedFlags;
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
throw new UnsupportedOperationException();
}
/**
* Returns the underlying indices of the element of the given index
* such as there really are in the original ndarray
@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return null;
}
@Override
public void setStride(long... stride) {
}
@Override
public void setShape(long... shape) {
}
/**
* This method returns true if this INDArray is special case: no-value INDArray
*

View File

@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return shapeInformation;
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
throw new UnsupportedOperationException();
}
@Override
public boolean equals(Object o) {
//TODO use op

View File

@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable {
/**
* Returns entropy value for this INDArray
* @return
* @return entropy value
*/
Number entropyNumber();
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
* @return non-normalized Shannon entropy
*/
Number shannonEntropyNumber();
/**
* Returns log entropy value for this INDArray
* @return
* @return log entropy value
*/
Number logEntropyNumber();
/**
* Returns entropy value for this INDArray along specified dimension(s)
* @return
* @param dimension specified dimension(s)
* @return entropy value
*/
INDArray entropy(int... dimension);
/**
* Returns entropy value for this INDArray along specified dimension(s)
* @return
* Returns Shannon entropy value for this INDArray along specified dimension(s)
* @param dimension specified dimension(s)
* @return Shannon entropy
*/
INDArray shannonEntropy(int... dimension);
/**
* Returns entropy value for this INDArray along specified dimension(s)
* @return
* Returns log entropy value for this INDArray along specified dimension(s)
* @param dimension specified dimension(s)
* @return log entropy value
*/
INDArray logEntropy(int... dimension);
/**
* stride setter
* @param stride
* @deprecated, use {@link #reshape(int...) }
*/
@Deprecated
void setStride(long... stride);
/**
* Shape setter
* @param shape
* @deprecated, use {@link #reshape(int...) }
*/
@Deprecated
void setShape(long... shape);
/**
* Shape and stride setter
* @param shape
* @param stride
* @param shape new value for shape
* @param stride new value for stride
*/
void setShapeAndStride(int[] shape, int[] stride);
@ -1919,15 +1903,7 @@ public interface INDArray extends Serializable, AutoCloseable {
* @param order the ordering to set
*/
void setOrder(char order);
/**
* @param offsets
* @param shape
* @param stride
* @return
*/
INDArray subArray(long[] offsets, int[] shape, int[] stride);
/**
* Returns the elements at the specified indices
*

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
}
@Builder(builderMethodName = "builder")
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false);
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
@Builder(builderMethodName = "sameDiffBuilder")
public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(sameDiff, new SDVariable[]{input});
config.setType(Pooling2D.Pooling2DType.AVG);
this.config = config;
addArgs();
}
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.AVG);
this.sameDiff = sameDiff;
this.config = config;
addArgs();
}

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
protected Conv1DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
public Conv1D(SameDiff sameDiff,
SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv1DConfig config) {
super(null, inputArrays, outputs);
this.sameDiff = sameDiff;
super(sameDiff, inputFunctions);
initConfig(config);
}
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv1DConfig config){
this.config = config;
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputFunctions, this);
}
protected void addArgs() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
protected Conv2DConfig config;
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
public Conv2D(SameDiff sameDiff,
SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) {
super(null, inputArrays, outputs);
this.sameDiff = sameDiff;
super(sameDiff, inputFunctions);
initConfig(config);
}
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
protected void initConfig(Conv2DConfig config){
this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION,
config.getSH(), config.getPH(), config.getDW());
INVALID_CONFIGURATION,
config.getSH(), config.getPH(), config.getDW());
addArgs();
if(sameDiff != null) {
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this);
}
}
protected void addArgs() {
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
.sameDiff(sameDiff)
.config(config)
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.build();
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());

View File

@ -37,8 +37,8 @@ import java.util.List;
public class Conv2DDerivative extends Conv2D {
@Builder(builderMethodName = "derivativeBuilder")
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) {
super(sameDiff, inputFunctions, inputArrays, outputs, config);
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
super(sameDiff, inputFunctions, config);
}
public Conv2DDerivative() {}

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
public Conv3D() {
}
@Builder(builderMethodName = "builder")
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs,
Conv3DConfig conv3DConfig) {
super(null, sameDiff, inputFunctions, false);
setSameDiff(sameDiff);
@Builder(builderMethodName = "sameDiffBuilder")
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
super(sameDiff, inputFunctions);
initConfig(config);
}
if (inputs != null)
addInputArgument(inputs);
if (outputs != null)
addOutputArgument(outputs);
this.config = conv3DConfig;
public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
super(inputs, outputs);
initConfig(config);
}
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private void initConfig(Conv3DConfig config){
this.config = config;
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
INVALID_CONFIGURATION,
config.getSW(), config.getPH(), config.getDW());
INVALID_CONFIGURATION,
config.getSW(), config.getPH(), config.getDW());
addArgs();
//for (val arg: iArgs())
// System.out.println(getIArgument(arg));
}
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
inputs.add(f1.get(0));
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
.conv3DConfig(config)
.inputFunctions(args())
.outputs(outputArguments())
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.sameDiff(sameDiff)
.build();

View File

@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
public Conv3DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder")
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) {
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
super(sameDiff, inputFunctions, conv3DConfig);
}
@Override

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
protected DeConv2DConfig config;
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
public DeConv2D(SameDiff sameDiff,
SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) {
super(null, inputArrays, outputs);
this.sameDiff = sameDiff;
super(sameDiff, inputs);
this.config = config;
if (inputArrays != null) {
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
}
addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this);
}
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
@Override

View File

@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
public DeConv2DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder")
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) {
super(sameDiff, inputs, inputArrays, outputs, config);
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
super(sameDiff, inputs, config);
}
@Override

View File

@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
protected DeConv2DConfig config;
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
public DeConv2DTF(SameDiff sameDiff,
SDVariable[] inputs,
INDArray[] inputArrays, INDArray[] outputs,
DeConv2DConfig config) {
super(null, inputArrays, outputs);
this.sameDiff = sameDiff;
super(sameDiff, inputs);
this.config = config;
addArgs();
}
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
super(inputs, outputs);
this.config = config;
if (inputArrays != null) {
addInputArgument(inputArrays);
}
if (outputs != null) {
addOutputArgument(outputs);
}
addArgs();
sameDiff.putOpForId(this.getOwnName(), this);
sameDiff.addArgsFor(inputs, this);
}
@Override

View File

@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
protected DeConv3DConfig config;
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) {
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
super(sameDiff, toArr(input, weights, bias));
this.config = config;
addArgs();
}
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
if(bias != null){
return new SDVariable[]{input, weights, bias};

View File

@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
protected Conv2DConfig config;
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
public DepthwiseConv2D(SameDiff sameDiff,
SDVariable[] inputFunctions,
INDArray[] inputArrays, INDArray[] outputs,
Conv2DConfig config) {
super(null, inputArrays, outputs);
this.sameDiff = sameDiff;
super(sameDiff, inputFunctions);
this.config = config;
addArgs();
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
sameDiff.addArgsFor(inputFunctions, this);
}
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
}
public DepthwiseConv2D() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
protected LocalResponseNormalizationConfig config;
@Builder(builderMethodName = "builder")
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions,
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
@Builder(builderMethodName = "sameDiffBuilder")
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
LocalResponseNormalizationConfig config) {
super(null,sameDiff, inputFunctions, inPlace);
this.config = config;
addArgs();
}
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
if(inputs != null) {
addInputArgument(inputs);
}
if(outputs!= null) {
addOutputArgument(outputs);
}
addArgs();
}

View File

@ -33,8 +33,8 @@ import java.util.List;
@Slf4j
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
@Builder(builderMethodName = "derivativeBuilder")
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) {
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config);
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
super(sameDiff, inputFunctions, inPlace, config);
}
public LocalResponseNormalizationDerivative() {}

View File

@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
public MaxPooling2D() {
}
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok")
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false);
if (arrayInput != null) {
addInputArgument(arrayInput);
}
if (arrayOutput != null) {
addOutputArgument(arrayOutput);
}
config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config;
this.sameDiff = sameDiff;
addArgs();
}
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output});
super(null, new INDArray[]{input}, wrapOrNull(output));
config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config;

View File

@ -16,8 +16,14 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
/**
* Pooling2D operation
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
public Pooling2D() {}
@Builder(builderMethodName = "builder")
@Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok")
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) {
super(null,sameDiff, inputs, false);
if(arrayInputs != null) {
addInputArgument(arrayInputs);
}
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
Pooling2DConfig config) {
super(null, sameDiff, inputs, false);
if(arrayOutputs != null) {
addOutputArgument(arrayOutputs);
}
this.config = config;
addArgs();
}
this.config = config;
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
super(inputs, outputs);
this.config = config;
addArgs();
}
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
super(new INDArray[]{input}, wrapOrNull(output));
this.config = config;
addArgs();
}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,8 +37,12 @@ import java.util.List;
@Slf4j
public class Pooling2DDerivative extends Pooling2D {
@Builder(builderMethodName = "derivativeBuilder")
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) {
super(sameDiff, inputs, arrayInputs, arrayOutputs, config);
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
super(sameDiff, inputs, config);
}
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
}
public Pooling2DDerivative() {}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -39,9 +40,17 @@ import java.util.List;
@Slf4j
public class SConv2D extends Conv2D {
@Builder(builderMethodName = "sBuilder")
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
@Builder(builderMethodName = "sameDiffSBuilder")
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, conv2DConfig);
}
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
super(inputs, outputs, config);
}
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
}
public SConv2D() {}

View File

@ -38,8 +38,8 @@ import java.util.List;
public class SConv2DDerivative extends SConv2D {
@Builder(builderMethodName = "sDerviativeBuilder")
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
super(sameDiff, inputFunctions, conv2DConfig);
}
public SConv2DDerivative() {}

View File

@ -0,0 +1,37 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class BitsHammingDistance extends DynamicCustomOp {
public BitsHammingDistance(){ }
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
super(sd, new SDVariable[]{x, y});
}
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
super(new INDArray[]{x, y}, null);
}
@Override
public String opName() {
return "bits_hamming_distance";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
return Collections.singletonList(DataType.LONG);
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise AND operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseAnd extends BaseDynamicTransformOp {
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseAnd(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseAnd() {}
@Override
public String opName() {
return "bitwise_and";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_and";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise OR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseOr extends BaseDynamicTransformOp {
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseOr(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseOr() {}
@Override
public String opName() {
return "bitwise_or";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_or";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise XOR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseXor extends BaseDynamicTransformOp {
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseXor(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public BitwiseXor() {}
@Override
public String opName() {
return "bitwise_xor";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_xor";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}

View File

@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}

View File

@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}

View File

@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
}

View File

@ -235,24 +235,20 @@ public class Convolution {
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
double extra, int virtualHeight, int virtualWidth, INDArray out) {
Pooling2D pooling = Pooling2D.builder()
.arrayInputs(new INDArray[]{img})
.arrayOutputs(new INDArray[]{out})
.config(Pooling2DConfig.builder()
.dH(dh)
.dW(dw)
.extra(extra)
.kH(kh)
.kW(kw)
.pH(ph)
.pW(pw)
.isSameMode(isSameMode)
.sH(sy)
.sW(sx)
.type(type)
.divisor(divisor)
.build())
.build();
Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
.dH(dh)
.dW(dw)
.extra(extra)
.kH(kh)
.kW(kw)
.pH(ph)
.pW(pw)
.isSameMode(isSameMode)
.sH(sy)
.sW(sx)
.type(type)
.divisor(divisor)
.build());
Nd4j.getExecutioner().execAndReturn(pooling);
return out;
}

View File

@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex {
return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
}
/**
* Set the shape and stride for
* new axes based dimensions
* @param arr the array to update
* the shape/strides for
* @param indexes the indexes to update based on
*/
public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) {
int numNewAxes = NDArrayIndex.numNewAxis(indexes);
if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) {
List<Long> newShape = new ArrayList<>();
List<Long> newStrides = new ArrayList<>();
int currDimension = 0;
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] instanceof NewAxis) {
newShape.add(1L);
newStrides.add(0L);
} else {
newShape.add(arr.size(currDimension));
newStrides.add(arr.size(currDimension));
currDimension++;
}
}
while (currDimension < arr.rank()) {
newShape.add((long) currDimension);
newStrides.add((long) currDimension);
currDimension++;
}
long[] newShapeArr = Longs.toArray(newShape);
long[] newStrideArr = Longs.toArray(newStrides);
// FIXME: this is wrong, it breaks shapeInfo immutability
arr.setShape(newShapeArr);
arr.setStride(newStrideArr);
} else {
if (numNewAxes > 0) {
long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape());
long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride());
arr.setShape(newShape);
arr.setStride(newStrides);
}
}
}
/**
* Compute the offset given an array of offsets.
* The offset is computed(for both fortran an d c ordering) as:

View File

@ -54,7 +54,7 @@ public class DeallocatorService {
deallocatorThreads = new Thread[numThreads];
queues = new ReferenceQueue[numThreads];
for (int e = 0; e < numThreads; e++) {
log.debug("Starting deallocator thread {}", e + 1);
log.trace("Starting deallocator thread {}", e + 1);
queues[e] = new ReferenceQueue<>();
int deviceId = e % numDevices;

View File

@ -1151,4 +1151,6 @@ public interface NativeOps {
int lastErrorCode();
String lastErrorMessage();
boolean isBlasVersionMatches(int major, int minor, int build);
}

View File

@ -101,7 +101,7 @@ public class NativeOpsHolder {
}
//deviceNativeOps.setOmpNumThreads(4);
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
} catch (Exception | Error e) {
throw new RuntimeException(
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",

View File

@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
setMaxThreads(numThreads);
}
log.info("Number of threads used for BLAS: {}", getMaxThreads());
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
}
}

View File

@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
throw new RuntimeException("No CUDA devices were found in system");
}
Loader.load(org.bytedeco.cuda.global.cublas.class);
return true;
}

View File

@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
/*
val major = new int[1];
val minor = new int[1];
val build = new int[1];
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
val pew = new int[100];
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
*/
}
@Override

View File

@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
@ -113,8 +115,13 @@ public class JcublasLevel3 extends BaseLevel3 {
@Override
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
INDArray B, int ldb, float beta, INDArray C, int ldc) {
//A = Shape.toOffsetZero(A);
//B = Shape.toOffsetZero(B);
/*
val ctx = AtomicAllocator.getInstance().getDeviceContext();
val handle = ctx.getCublasHandle();
synchronized (handle) {
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
}
*/
Nd4j.getExecutioner().push();
@ -141,6 +148,7 @@ public class JcublasLevel3 extends BaseLevel3 {
}
allocator.registerAction(ctx, C, A, B);
OpExecutionerUtil.checkForAny(C);
}

View File

@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose();
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/**
*

View File

@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose();
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/**
*
@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_and)
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_and(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_and position(long position) {
return (bitwise_and)super.position(position);
}
public bitwise_and() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_or)
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_or(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_or position(long position) {
return (bitwise_or)super.position(position);
}
public bitwise_or() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_xor)
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_xor(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_xor position(long position) {
return (bitwise_xor)super.position(position);
}
public bitwise_xor() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation returns hamming distance based on bits
*

View File

@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build();
INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
.arrayInput(input)
.config(conf)
.build();
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
//Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
.arrayInputs(new INDArray[]{input, grad})
.config(conf)
.build();
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size());
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build();
INDArray input = Nd4j.create(inSize);
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
.arrayInput(input)
.config(conf)
.build();
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
assertEquals(1, outSizes.size());
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
INDArray grad = Nd4j.create(exp);
//Test backprop:
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
.config(conf)
.build();
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
assertEquals(1, outSizesBP.size());
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(false)
.build();
SDVariable out = sd.cnn().conv2d(vars, c);
SDVariable out = sd.cnn().conv2d("conv", vars, c);
out = sd.nn().tanh("out", out);
INDArray outArr = sd.execAndEndResult();