commit
b7226bdd7a
|
@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
|
|||
break;
|
||||
}
|
||||
|
||||
Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder()
|
||||
.config(conf)
|
||||
.arrayInputs(new INDArray[]{input, epsilon})
|
||||
.arrayOutputs(new INDArray[]{gradAtInput})
|
||||
.build();
|
||||
Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf);
|
||||
|
||||
Nd4j.exec(d);
|
||||
return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput);
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SAMEDIFF_BLASVERSIONHELPER_H
|
||||
#define SAMEDIFF_BLASVERSIONHELPER_H
|
||||
|
||||
#include <dll.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT BlasVersionHelper {
|
||||
public:
|
||||
int _blasMajorVersion = 0;
|
||||
int _blasMinorVersion = 0;
|
||||
int _blasPatchVersion = 0;
|
||||
|
||||
BlasVersionHelper();
|
||||
~BlasVersionHelper() = default;
|
||||
};
|
||||
}
|
||||
|
||||
#endif //DEV_TESTS_BLASVERSIONHELPER_H
|
|
@ -253,20 +253,20 @@ if(CUDA_BLAS)
|
|||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
||||
|
||||
if (NOT BUILD_TESTS)
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||
else()
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
||||
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||
endif()
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "BlasVersionHelper.h"
|
||||
#endif
|
||||
|
||||
namespace nd4j {
|
||||
|
@ -66,6 +66,13 @@ namespace nd4j {
|
|||
#endif
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
BlasVersionHelper ver;
|
||||
_blasMajorVersion = ver._blasMajorVersion;
|
||||
_blasMinorVersion = ver._blasMinorVersion;
|
||||
_blasPatchVersion = ver._blasPatchVersion;
|
||||
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
|
||||
fflush(stdout);
|
||||
|
||||
int devCnt = 0;
|
||||
cudaGetDeviceCount(&devCnt);
|
||||
auto devProperties = new cudaDeviceProp[devCnt];
|
||||
|
|
|
@ -56,6 +56,13 @@ namespace nd4j{
|
|||
Environment();
|
||||
~Environment();
|
||||
public:
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
int _blasMajorVersion = 0;
|
||||
int _blasMinorVersion = 0;
|
||||
int _blasPatchVersion = 0;
|
||||
|
||||
static Environment* getInstance();
|
||||
|
||||
bool isVerbose();
|
||||
|
|
|
@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
|
|||
ND4J_EXPORT void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
auto fName = builder.CreateString(*(var->getName()));
|
||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||
|
||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||
|
||||
variables_vector.push_back(fv);
|
||||
arrays++;
|
||||
|
|
|
@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
|||
}
|
||||
}
|
||||
|
||||
bool isBlasVersionMatches(int major, int minor, int build) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param opNum
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "../BlasVersionHelper.h"
|
||||
|
||||
namespace nd4j {
|
||||
BlasVersionHelper::BlasVersionHelper() {
|
||||
_blasMajorVersion = __CUDACC_VER_MAJOR__;
|
||||
_blasMinorVersion = __CUDACC_VER_MINOR__;
|
||||
_blasPatchVersion = __CUDACC_VER_BUILD__;
|
||||
}
|
||||
}
|
|
@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
|
|||
delete ptr;
|
||||
}
|
||||
|
||||
bool isBlasVersionMatches(int major, int minor, int build) {
|
||||
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
|
||||
|
||||
if (!result) {
|
||||
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
|
||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace nd4j {
|
|||
public:
|
||||
static int asInt(DataType type);
|
||||
static DataType fromInt(int dtype);
|
||||
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
|
||||
static DataType fromFlatDataType(nd4j::graph::DType dtype);
|
||||
FORCEINLINE static std::string asString(DataType dataType);
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
|||
return (DataType) val;
|
||||
}
|
||||
|
||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
|
||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
|
||||
return (DataType) dtype;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
|
|||
return EnumNamesByteOrder()[index];
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
DataType_INHERIT = 0,
|
||||
DataType_BOOL = 1,
|
||||
DataType_FLOAT8 = 2,
|
||||
DataType_HALF = 3,
|
||||
DataType_HALF2 = 4,
|
||||
DataType_FLOAT = 5,
|
||||
DataType_DOUBLE = 6,
|
||||
DataType_INT8 = 7,
|
||||
DataType_INT16 = 8,
|
||||
DataType_INT32 = 9,
|
||||
DataType_INT64 = 10,
|
||||
DataType_UINT8 = 11,
|
||||
DataType_UINT16 = 12,
|
||||
DataType_UINT32 = 13,
|
||||
DataType_UINT64 = 14,
|
||||
DataType_QINT8 = 15,
|
||||
DataType_QINT16 = 16,
|
||||
DataType_BFLOAT16 = 17,
|
||||
DataType_UTF8 = 50,
|
||||
DataType_MIN = DataType_INHERIT,
|
||||
DataType_MAX = DataType_UTF8
|
||||
enum DType {
|
||||
DType_INHERIT = 0,
|
||||
DType_BOOL = 1,
|
||||
DType_FLOAT8 = 2,
|
||||
DType_HALF = 3,
|
||||
DType_HALF2 = 4,
|
||||
DType_FLOAT = 5,
|
||||
DType_DOUBLE = 6,
|
||||
DType_INT8 = 7,
|
||||
DType_INT16 = 8,
|
||||
DType_INT32 = 9,
|
||||
DType_INT64 = 10,
|
||||
DType_UINT8 = 11,
|
||||
DType_UINT16 = 12,
|
||||
DType_UINT32 = 13,
|
||||
DType_UINT64 = 14,
|
||||
DType_QINT8 = 15,
|
||||
DType_QINT16 = 16,
|
||||
DType_BFLOAT16 = 17,
|
||||
DType_UTF8 = 50,
|
||||
DType_MIN = DType_INHERIT,
|
||||
DType_MAX = DType_UTF8
|
||||
};
|
||||
|
||||
inline const DataType (&EnumValuesDataType())[19] {
|
||||
static const DataType values[] = {
|
||||
DataType_INHERIT,
|
||||
DataType_BOOL,
|
||||
DataType_FLOAT8,
|
||||
DataType_HALF,
|
||||
DataType_HALF2,
|
||||
DataType_FLOAT,
|
||||
DataType_DOUBLE,
|
||||
DataType_INT8,
|
||||
DataType_INT16,
|
||||
DataType_INT32,
|
||||
DataType_INT64,
|
||||
DataType_UINT8,
|
||||
DataType_UINT16,
|
||||
DataType_UINT32,
|
||||
DataType_UINT64,
|
||||
DataType_QINT8,
|
||||
DataType_QINT16,
|
||||
DataType_BFLOAT16,
|
||||
DataType_UTF8
|
||||
inline const DType (&EnumValuesDType())[19] {
|
||||
static const DType values[] = {
|
||||
DType_INHERIT,
|
||||
DType_BOOL,
|
||||
DType_FLOAT8,
|
||||
DType_HALF,
|
||||
DType_HALF2,
|
||||
DType_FLOAT,
|
||||
DType_DOUBLE,
|
||||
DType_INT8,
|
||||
DType_INT16,
|
||||
DType_INT32,
|
||||
DType_INT64,
|
||||
DType_UINT8,
|
||||
DType_UINT16,
|
||||
DType_UINT32,
|
||||
DType_UINT64,
|
||||
DType_QINT8,
|
||||
DType_QINT16,
|
||||
DType_BFLOAT16,
|
||||
DType_UTF8
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesDataType() {
|
||||
inline const char * const *EnumNamesDType() {
|
||||
static const char * const names[] = {
|
||||
"INHERIT",
|
||||
"BOOL",
|
||||
|
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
|
|||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameDataType(DataType e) {
|
||||
inline const char *EnumNameDType(DType e) {
|
||||
const size_t index = static_cast<int>(e);
|
||||
return EnumNamesDataType()[index];
|
||||
return EnumNamesDType()[index];
|
||||
}
|
||||
|
||||
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
|
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::Vector<int8_t> *buffer() const {
|
||||
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
||||
}
|
||||
DataType dtype() const {
|
||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
DType dtype() const {
|
||||
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
}
|
||||
ByteOrder byteOrder() const {
|
||||
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
||||
|
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
|
|||
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
||||
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
||||
}
|
||||
void add_dtype(DataType dtype) {
|
||||
void add_dtype(DType dtype) {
|
||||
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||
}
|
||||
void add_byteOrder(ByteOrder byteOrder) {
|
||||
|
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder_LE) {
|
||||
FlatArrayBuilder builder_(_fbb);
|
||||
builder_.add_buffer(buffer);
|
||||
|
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
const std::vector<int64_t> *shape = nullptr,
|
||||
const std::vector<int8_t> *buffer = nullptr,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder_LE) {
|
||||
return nd4j::graph::CreateFlatArray(
|
||||
_fbb,
|
||||
|
|
|
@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
|
|||
/**
|
||||
* @enum
|
||||
*/
|
||||
nd4j.graph.DataType = {
|
||||
nd4j.graph.DType = {
|
||||
INHERIT: 0,
|
||||
BOOL: 1,
|
||||
FLOAT8: 2,
|
||||
|
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
|
|||
};
|
||||
|
||||
/**
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatArray.prototype.dtype = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {nd4j.graph.DataType} dtype
|
||||
* @param {nd4j.graph.DType} dtype
|
||||
*/
|
||||
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
namespace nd4j.graph
|
||||
{
|
||||
|
||||
public enum DataType : sbyte
|
||||
public enum DType : sbyte
|
||||
{
|
||||
INHERIT = 0,
|
||||
BOOL = 1,
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
package nd4j.graph;
|
||||
|
||||
public final class DataType {
|
||||
private DataType() { }
|
||||
public final class DType {
|
||||
private DType() { }
|
||||
public static final byte INHERIT = 0;
|
||||
public static final byte BOOL = 1;
|
||||
public static final byte FLOAT8 = 2;
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
# namespace: graph
|
||||
|
||||
class DataType(object):
|
||||
class DType(object):
|
||||
INHERIT = 0
|
||||
BOOL = 1
|
||||
FLOAT8 = 2
|
|
@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
||||
#endif
|
||||
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
||||
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
||||
|
||||
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
||||
VectorOffset shapeOffset = default(VectorOffset),
|
||||
VectorOffset bufferOffset = default(VectorOffset),
|
||||
DataType dtype = DataType.INHERIT,
|
||||
DType dtype = DType.INHERIT,
|
||||
ByteOrder byteOrder = ByteOrder.LE) {
|
||||
builder.StartObject(4);
|
||||
FlatArray.AddBuffer(builder, bufferOffset);
|
||||
|
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
|
|||
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
||||
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
||||
int o = builder.EndObject();
|
||||
|
|
|
@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
||||
#endif
|
||||
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
||||
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
|
||||
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
#if ENABLE_SPAN_T
|
||||
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
||||
#else
|
||||
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
||||
#endif
|
||||
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
|
||||
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
|
||||
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||
|
||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||
|
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
|
|||
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
||||
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||
|
|
|
@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
||||
#endif
|
||||
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
||||
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
||||
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
#if ENABLE_SPAN_T
|
||||
|
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
||||
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
||||
StringOffset nameOffset = default(StringOffset),
|
||||
DataType dtype = DataType.INHERIT,
|
||||
DType dtype = DType.INHERIT,
|
||||
VectorOffset shapeOffset = default(VectorOffset),
|
||||
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
||||
int device = 0,
|
||||
|
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
||||
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
||||
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
||||
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
||||
|
|
|
@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
|
|||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 38);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<nd4j.graph.DataType>} data
|
||||
* @param {Array.<nd4j.graph.DType>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
||||
|
|
|
@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const flatbuffers::String *name() const {
|
||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||
}
|
||||
DataType dtype() const {
|
||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
DType dtype() const {
|
||||
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||
}
|
||||
const flatbuffers::Vector<int64_t> *shape() const {
|
||||
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
||||
|
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
|
|||
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
||||
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
||||
}
|
||||
void add_dtype(DataType dtype) {
|
||||
void add_dtype(DType dtype) {
|
||||
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||
}
|
||||
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
||||
|
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<IntPair> id = 0,
|
||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
|
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
|||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
flatbuffers::Offset<IntPair> id = 0,
|
||||
const char *name = nullptr,
|
||||
DataType dtype = DataType_INHERIT,
|
||||
DType dtype = DType_INHERIT,
|
||||
const std::vector<int64_t> *shape = nullptr,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
|
|
|
@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
|
|||
};
|
||||
|
||||
/**
|
||||
* @returns {nd4j.graph.DataType}
|
||||
* @returns {nd4j.graph.DType}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
||||
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
|
|||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {nd4j.graph.DataType} dtype
|
||||
* @param {nd4j.graph.DType} dtype
|
||||
*/
|
||||
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
||||
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -111,7 +111,7 @@ namespace nd4j {
|
|||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -219,7 +219,7 @@ namespace nd4j {
|
|||
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
||||
|
||||
auto ar = flatVariable->ndarray();
|
||||
if (ar->dtype() == DataType_UTF8) {
|
||||
if (ar->dtype() == DType_UTF8) {
|
||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||
} else {
|
||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||
|
@ -320,7 +320,7 @@ namespace nd4j {
|
|||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||
|
||||
// packing array
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
|
||||
|
||||
// packing id/index of this var
|
||||
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
||||
|
@ -331,7 +331,7 @@ namespace nd4j {
|
|||
stringId = builder.CreateString(this->_name);
|
||||
|
||||
// returning array
|
||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||
} else {
|
||||
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ enum ByteOrder:byte {
|
|||
}
|
||||
|
||||
// DataType for arrays/buffers
|
||||
enum DataType:byte {
|
||||
enum DType:byte {
|
||||
INHERIT,
|
||||
BOOL,
|
||||
FLOAT8,
|
||||
|
@ -49,7 +49,7 @@ enum DataType:byte {
|
|||
table FlatArray {
|
||||
shape:[long]; // shape in Nd4j format
|
||||
buffer:[byte]; // byte buffer with data
|
||||
dtype:DataType; // data type of actual data within buffer
|
||||
dtype:DType; // data type of actual data within buffer
|
||||
byteOrder:ByteOrder; // byte order of buffer
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ table FlatNode {
|
|||
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
||||
|
||||
// output data types (optional)
|
||||
outputTypes:[DataType];
|
||||
outputTypes:[DType];
|
||||
|
||||
//Scalar value - used for scalar ops. Should be single value only.
|
||||
scalar:FlatArray;
|
||||
|
|
|
@ -51,7 +51,7 @@ table UIVariable {
|
|||
id:IntPair; //Existing IntPair class
|
||||
name:string;
|
||||
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
||||
datatype:DataType;
|
||||
datatype:DType;
|
||||
shape:[long];
|
||||
controlDeps:[string]; //Input control dependencies: variable x -> this
|
||||
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||
|
|
|
@ -30,7 +30,7 @@ enum VarType:byte {
|
|||
table FlatVariable {
|
||||
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
||||
name:string; // symbolic ID of the Variable (if defined)
|
||||
dtype:DataType;
|
||||
dtype:DType;
|
||||
|
||||
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
||||
ndarray:FlatArray;
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_bitwise_and)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||
|
||||
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(bitwise_and) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_bitwise_or)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||
|
||||
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(bitwise_or) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_bitwise_xor)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/shift.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||
|
||||
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(bitwise_xor) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -29,21 +29,26 @@ namespace ops {
|
|||
CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) {
|
||||
int numOfData = block.width();
|
||||
// int k = 0;
|
||||
// checking input data size
|
||||
REQUIRE_TRUE(numOfData % 2 == 0, 0,
|
||||
"dynamic_stitch: The input params should contains"
|
||||
" both indeces and data lists with same length.");
|
||||
// split input data list on two equal parts
|
||||
numOfData /= 2;
|
||||
|
||||
// form input lists to use with helpers - both indices and float data inputs
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
std::vector<NDArray*> inputs(numOfData);
|
||||
std::vector<NDArray*> indices(numOfData);
|
||||
|
||||
for (int e = 0; e < numOfData; e++) {
|
||||
auto data = INPUT_VARIABLE(numOfData + e);
|
||||
auto index = INPUT_VARIABLE(e);
|
||||
|
||||
inputs[e] = data;
|
||||
indices[e] = index;
|
||||
}
|
||||
|
||||
// run helper
|
||||
return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output);
|
||||
}
|
||||
|
||||
|
@ -59,17 +64,17 @@ namespace ops {
|
|||
numOfData /= 2; // only index part it's needed to review
|
||||
auto restShape = inputShape->at(numOfData);
|
||||
auto firstShape = inputShape->at(0);
|
||||
// check up inputs to avoid non-int indices and calculate max value from indices to output shape length
|
||||
for(int i = 0; i < numOfData; i++) {
|
||||
auto input = INPUT_VARIABLE(i);
|
||||
REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() );
|
||||
// FIXME: we have reduce::Max, cinsider using it instead
|
||||
auto maxV = input->reduceNumber(reduce::Max);
|
||||
if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0);
|
||||
}
|
||||
|
||||
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1;
|
||||
// calculate output rank - difference between indices shape and data shape
|
||||
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor
|
||||
std::vector<Nd4jLong> outShape(outRank);
|
||||
|
||||
// fill up output shape template: the first to max index, and rests - to vals from the first data input
|
||||
outShape[0] = maxValue + 1;
|
||||
for(int i = 1; i < outRank; ++i)
|
||||
outShape[i] = shape::sizeAt(restShape, i);
|
||||
|
|
|
@ -33,12 +33,13 @@ namespace nd4j {
|
|||
* 0: 1D row-vector (or with shape (1, m))
|
||||
* 1: 1D integer vector with slice nums
|
||||
* 2: 1D float-point values vector with same shape as above
|
||||
* 3: 2D float-point matrix with data to search
|
||||
*
|
||||
* Int args:
|
||||
* 0: N - number of slices
|
||||
*
|
||||
* Output:
|
||||
* 0: 1D vector with edge forces for input and values
|
||||
* 0: 2D matrix with the same shape and type as the 3th argument
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_barnes_edge_forces)
|
||||
DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1);
|
||||
|
@ -52,9 +53,11 @@ namespace nd4j {
|
|||
* 0: 1D int row-vector
|
||||
* 1: 1D int col-vector
|
||||
* 2: 1D float vector with values
|
||||
*
|
||||
*
|
||||
* Output:
|
||||
* 0: symmetric 2D matrix with given values on given places
|
||||
* 0: 1D int result row-vector
|
||||
* 1: 1D int result col-vector
|
||||
* 2: a float-point tensor with shape 1xN, with values from the last input vector
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_barnes_symmetrized)
|
||||
DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1);
|
||||
|
|
|
@ -81,6 +81,39 @@ namespace nd4j {
|
|||
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise AND
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_and)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise OR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_or)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise XOR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_xor)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation returns hamming distance based on bits
|
||||
*
|
||||
|
|
|
@ -120,7 +120,7 @@ namespace nd4j {
|
|||
#endif
|
||||
|
||||
/**
|
||||
* This operation unstacks given NDArray into NDArrayList
|
||||
* This operation unstacks given NDArray into NDArrayList by the first dimension
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_unstack_list)
|
||||
DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0);
|
||||
|
|
|
@ -594,21 +594,46 @@ namespace nd4j {
|
|||
|
||||
|
||||
/**
|
||||
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
|
||||
* of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension
|
||||
* are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input
|
||||
* block size and how the data is moved.
|
||||
* Input:
|
||||
* 0 - 4D tensor on given type
|
||||
* Output:
|
||||
* 0 - 4D tensor of given type and proper shape
|
||||
*
|
||||
*
|
||||
*
|
||||
* Int arguments:
|
||||
* 0 - block size
|
||||
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
|
||||
* 1 ("NCHW"): shape{ batch, channels, height, width }
|
||||
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
|
||||
* optional (default 0)
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_depth_to_space)
|
||||
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2);
|
||||
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor
|
||||
* where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates
|
||||
* the input block size.
|
||||
*
|
||||
* Input:
|
||||
* - 4D tensor of given type
|
||||
* Output:
|
||||
* - 4D tensor
|
||||
*
|
||||
* Int arguments:
|
||||
* 0 - block size
|
||||
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
|
||||
* 1 ("NCHW"): shape{ batch, channels, height, width }
|
||||
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
|
||||
* optional (default 0)
|
||||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_space_to_depth)
|
||||
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2);
|
||||
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -622,13 +647,42 @@ namespace nd4j {
|
|||
#endif
|
||||
|
||||
/**
|
||||
* Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op
|
||||
* outputs a copy of the input tensor where values from the height and width dimensions are moved to the
|
||||
* batch dimension. After the zero-padding, both height and width of the input must be divisible by the block
|
||||
* size.
|
||||
*
|
||||
* Inputs:
|
||||
* 0 - input tensor
|
||||
* 1 - 2D paddings tensor (shape {M, 2})
|
||||
*
|
||||
* Output:
|
||||
* - result tensor
|
||||
*
|
||||
* Int args:
|
||||
* 0 - block size (M)
|
||||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_space_to_batch)
|
||||
DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape
|
||||
* block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output,
|
||||
* the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension
|
||||
* combines both the position within a spatial block and the original batch position. Prior to division into
|
||||
* blocks, the spatial dimensions of the input are optionally zero padded according to paddings.
|
||||
*
|
||||
* Inputs:
|
||||
* 0 - input (N-D tensor)
|
||||
* 1 - block_shape - int 1D tensor with M length
|
||||
* 2 - paddings - int 2D tensor with shape {M, 2}
|
||||
*
|
||||
* Output:
|
||||
* - N-D tensor with the same type as input 0.
|
||||
*
|
||||
* */
|
||||
#if NOT_EXCLUDED(OP_space_to_batch_nd)
|
||||
DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0);
|
||||
#endif
|
||||
|
@ -973,7 +1027,7 @@ namespace nd4j {
|
|||
* return value:
|
||||
* tensor with min values according to indices sets.
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_segment_min_bp)
|
||||
#if NOT_EXCLUDED(OP_segment_min)
|
||||
DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
#if NOT_EXCLUDED(OP_segment_min_bp)
|
||||
|
|
|
@ -118,19 +118,19 @@ namespace nd4j {
|
|||
|
||||
PointersManager pm(context, "dynamicPartition");
|
||||
|
||||
if (sourceDimsLen) {
|
||||
if (sourceDimsLen) { // non-linear case
|
||||
std::vector<int> sourceDims(sourceDimsLen);
|
||||
|
||||
for (int i = sourceDimsLen; i > 0; i--)
|
||||
sourceDims[sourceDimsLen - i] = input->rankOf() - i;
|
||||
|
||||
//compute tad array for given dimensions
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims);
|
||||
|
||||
std::vector<void *> outBuffers(outSize);
|
||||
std::vector<Nd4jLong *> tadShapes(outSize);
|
||||
std::vector<Nd4jLong *> tadOffsets(outSize);
|
||||
std::vector<Nd4jLong> numTads(outSize);
|
||||
|
||||
// fill up dimensions array for before kernel
|
||||
for (unsigned int i = 0; i < outSize; i++) {
|
||||
outputs[i].first = outputList[i];
|
||||
std::vector<int> outDims(outputs[i].first->rankOf() - 1);
|
||||
|
@ -151,10 +151,10 @@ namespace nd4j {
|
|||
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
||||
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
|
||||
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
|
||||
|
||||
// run kernel on device
|
||||
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
|
||||
|
||||
} else {
|
||||
} else { // linear case
|
||||
auto numThreads = 256;
|
||||
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
|
||||
|
||||
|
@ -169,7 +169,6 @@ namespace nd4j {
|
|||
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
||||
auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *)));
|
||||
|
||||
|
||||
dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize);
|
||||
}
|
||||
|
||||
|
|
|
@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {10}, {2});
|
||||
|
@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printIndexedBuffer();
|
||||
// result->printIndexedBuffer("Result2");
|
||||
// exp.printIndexedBuffer("Expect2");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
|
|
@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
|
|||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||
auto fVid = CreateIntPair(builder, -1);
|
||||
|
||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||
|
||||
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
||||
outputs1.push_back(2);
|
||||
|
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
|
|||
|
||||
auto name1 = builder.CreateString("wow1");
|
||||
|
||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
|
||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
|
||||
|
||||
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
||||
variables_vector.push_back(fXVar);
|
||||
|
|
|
@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
|
|||
auto fBuffer = builder.CreateVector(vec);
|
||||
auto fVid = CreateIntPair(builder, 1, 12);
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
|
|||
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
||||
auto fVid = CreateIntPair(builder, 37, 12);
|
||||
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||
|
||||
builder.Finish(flatVar);
|
||||
|
||||
|
|
|
@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
|
||||
LocalResponseNormalization lrn = LocalResponseNormalization.builder()
|
||||
LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
|
||||
.inputFunctions(new SDVariable[]{input})
|
||||
.sameDiff(sameDiff())
|
||||
.config(lrnConfig)
|
||||
|
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
||||
Conv1D conv1D = Conv1D.builder()
|
||||
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||
.inputFunctions(new SDVariable[]{input, weights})
|
||||
.sameDiff(sameDiff())
|
||||
.config(conv1DConfig)
|
||||
|
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
|
|||
return conv1D.outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Conv1d operation.
|
||||
*
|
||||
* @param input the inputs to conv1d
|
||||
* @param weights conv1d weights
|
||||
* @param bias conv1d bias
|
||||
* @param conv1DConfig the configuration
|
||||
* @return
|
||||
*/
|
||||
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
|
||||
|
||||
SDVariable[] args;
|
||||
|
||||
if(bias == null){
|
||||
args = new SDVariable[]{input, weights};
|
||||
} else {
|
||||
args = new SDVariable[]{input, weights, bias};
|
||||
}
|
||||
|
||||
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||
.inputFunctions(args)
|
||||
.sameDiff(sameDiff())
|
||||
.config(conv1DConfig)
|
||||
.build();
|
||||
|
||||
return conv1D.outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Conv2d operation.
|
||||
*
|
||||
|
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||
Conv2D conv2D = Conv2D.builder()
|
||||
Conv2D conv2D = Conv2D.sameDiffBuilder()
|
||||
.inputFunctions(inputs)
|
||||
.sameDiff(sameDiff())
|
||||
.config(conv2DConfig)
|
||||
|
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
||||
AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
|
||||
.input(input)
|
||||
.sameDiff(sameDiff())
|
||||
.config(pooling2DConfig)
|
||||
|
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
MaxPooling2D maxPooling2D = MaxPooling2D.builder()
|
||||
MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
|
||||
.input(input)
|
||||
.sameDiff(sameDiff())
|
||||
.config(pooling2DConfig)
|
||||
|
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||
SConv2D sconv2D = SConv2D.sBuilder()
|
||||
SConv2D sconv2D = SConv2D.sameDiffSBuilder()
|
||||
.inputFunctions(inputs)
|
||||
.sameDiff(sameDiff())
|
||||
.conv2DConfig(conv2DConfig)
|
||||
|
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
||||
SConv2D depthWiseConv2D = SConv2D.sBuilder()
|
||||
SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
|
||||
.inputFunctions(inputs)
|
||||
.sameDiff(sameDiff())
|
||||
.conv2DConfig(depthConv2DConfig)
|
||||
|
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
||||
DeConv2D deconv2D = DeConv2D.builder()
|
||||
DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
|
||||
.inputs(inputs)
|
||||
.sameDiff(sameDiff())
|
||||
.config(deconv2DConfig)
|
||||
|
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
|
|||
* @return
|
||||
*/
|
||||
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
|
||||
Conv3D conv3D = Conv3D.builder()
|
||||
Conv3D conv3D = Conv3D.sameDiffBuilder()
|
||||
.inputFunctions(inputs)
|
||||
.conv3DConfig(conv3DConfig)
|
||||
.config(conv3DConfig)
|
||||
.sameDiff(sameDiff())
|
||||
.build();
|
||||
|
||||
|
@ -1260,6 +1288,22 @@ public class DifferentialFunctionFactory {
|
|||
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
|
||||
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
|
||||
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
|
||||
return new BitwiseOr(sameDiff(), x, y).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
|
||||
return new BitwiseXor(sameDiff(), x, y).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
||||
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
||||
}
|
||||
|
|
|
@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
|
|||
*/
|
||||
public final SDImage image = new SDImage(this);
|
||||
|
||||
/**
|
||||
* Op creator object for bitwise operations
|
||||
*/
|
||||
public final SDBitwise bitwise = new SDBitwise(this);
|
||||
|
||||
/**
|
||||
* Op creator object for math operations
|
||||
*/
|
||||
|
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
|
|||
return image;
|
||||
}
|
||||
|
||||
/**
|
||||
* Op creator object for bitwise operations
|
||||
*/
|
||||
public SDBitwise bitwise(){
|
||||
return bitwise;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* For import, many times we have variables
|
||||
|
|
|
@ -0,0 +1,205 @@
|
|||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public class SDBitwise extends SDOps {
|
||||
public SDBitwise(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #leftShift(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
|
||||
return leftShift(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise left shift operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise shifted input x
|
||||
*/
|
||||
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise left shift", x);
|
||||
validateInteger("bitwise left shift", y);
|
||||
|
||||
SDVariable ret = f().shift(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #rightShift(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable rightShift(SDVariable x, SDVariable y){
|
||||
return rightShift(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise right shift operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise shifted input x
|
||||
*/
|
||||
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise right shift", x);
|
||||
validateInteger("bitwise right shift", y);
|
||||
|
||||
SDVariable ret = f().rshift(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
|
||||
return leftShiftCyclic(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.
|
||||
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise cyclic shifted input x
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise left shift (cyclic)", x);
|
||||
validateInteger("bitwise left shift (cyclic)", y);
|
||||
|
||||
SDVariable ret = f().rotl(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
|
||||
return rightShiftCyclic(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.
|
||||
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise cyclic shifted input x
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise right shift (cyclic)", x);
|
||||
validateInteger("bitwise right shift (cyclic)", y);
|
||||
|
||||
SDVariable ret = f().rotr(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
|
||||
return bitsHammingDistance(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise hamming distance", x);
|
||||
validateInteger("bitwise hamming distance", y);
|
||||
|
||||
SDVariable ret = f().bitwiseHammingDist(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #and(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable and(SDVariable x, SDVariable y){
|
||||
return and(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise AND operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise AND array
|
||||
*/
|
||||
public SDVariable and(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise AND", x);
|
||||
validateInteger("bitwise AND", y);
|
||||
|
||||
SDVariable ret = f().bitwiseAnd(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #or(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable or(SDVariable x, SDVariable y){
|
||||
return or(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise OR operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise OR array
|
||||
*/
|
||||
public SDVariable or(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise OR", x);
|
||||
validateInteger("bitwise OR", y);
|
||||
|
||||
SDVariable ret = f().bitwiseOr(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #xor(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable xor(SDVariable x, SDVariable y){
|
||||
return xor(null, x, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise XOR array
|
||||
*/
|
||||
public SDVariable xor(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise XOR", x);
|
||||
validateInteger("bitwise XOR", y);
|
||||
|
||||
SDVariable ret = f().bitwiseXor(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||
|
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - average pooling 2d
|
||||
*
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param pooling2DConfig the configuration for
|
||||
* @return Result after applying average pooling on the input
|
||||
* See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||
*/
|
||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||
return avgPooling2d(null, input, pooling2DConfig);
|
||||
}
|
||||
|
||||
|
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
|
|||
* @param pooling2DConfig the configuration
|
||||
* @return Result after applying average pooling on the input
|
||||
*/
|
||||
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||
validateFloatingPoint("avgPooling2d", input);
|
||||
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* 3D convolution layer operation - average pooling 3d
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels])
|
||||
* @param pooling3DConfig the configuration
|
||||
* @return Result after applying average pooling on the input
|
||||
* See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||
*/
|
||||
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
||||
public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||
return avgPooling3d(null, input, pooling3DConfig);
|
||||
}
|
||||
|
||||
|
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
|
|||
* @param pooling3DConfig the configuration
|
||||
* @return Result after applying average pooling on the input
|
||||
*/
|
||||
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
||||
public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||
validateFloatingPoint("avgPooling3d", input);
|
||||
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
|
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||
*/
|
||||
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) {
|
||||
public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||
return batchToSpace(null, x, blocks, crops);
|
||||
}
|
||||
|
||||
|
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
|
|||
* @return Output variable
|
||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||
*/
|
||||
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) {
|
||||
public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||
validateNumerical("batchToSpace", x);
|
||||
SDVariable ret = f().batchToSpace(x, blocks, crops);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
|
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
|
|||
|
||||
|
||||
/**
|
||||
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
|
||||
* [minibatch, inputChannels, height, width]
|
||||
*
|
||||
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
||||
* @param config Convolution configuration for the col2im operation
|
||||
* @return Col2Im output variable
|
||||
* See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable col2Im(SDVariable in, Conv2DConfig config) {
|
||||
public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||
return col2Im(null, in, config);
|
||||
}
|
||||
|
||||
|
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
|
|||
* @param config Convolution configuration for the col2im operation
|
||||
* @return Col2Im output variable
|
||||
*/
|
||||
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) {
|
||||
public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||
SDVariable ret = f().col2Im(in, config);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* 1D Convolution layer operation - Conv1d
|
||||
*
|
||||
* @param input the input array/activations for the conv1d op
|
||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
||||
* @param conv1DConfig the configuration
|
||||
* @return
|
||||
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
||||
return conv1d(null, input, weights, conv1DConfig);
|
||||
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||
return conv1d((String) null, input, weights, conv1DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Conv1d operation.
|
||||
*
|
||||
* @param name name of the operation in SameDiff
|
||||
* @param input the inputs to conv1d
|
||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
||||
* @param conv1DConfig the configuration
|
||||
* @return
|
||||
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
||||
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||
validateFloatingPoint("conv1d", input);
|
||||
validateFloatingPoint("conv1d", weights);
|
||||
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
|
||||
|
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution operation (without bias)
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
||||
* @param config Conv2DConfig configuration
|
||||
* @return result of conv2d op
|
||||
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
|
||||
*/
|
||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) {
|
||||
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||
return conv1d(null, input, weights, bias, conv1DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Conv1d operation.
|
||||
*
|
||||
* @param name name of the operation in SameDiff
|
||||
* @param input the inputs to conv1d
|
||||
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
|
||||
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
|
||||
* @param conv1DConfig the configuration
|
||||
* @return
|
||||
*/
|
||||
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||
validateFloatingPoint("conv1d", input);
|
||||
validateFloatingPoint("conv1d", weights);
|
||||
validateFloatingPoint("conv1d", bias);
|
||||
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||
return conv2d(layerInput, weights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||
return conv2d(name, layerInput, weights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
return conv2d(null, layerInput, weights, bias, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution operation with optional bias
|
||||
*
|
||||
* @param name name of the operation in SameDiff
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
||||
|
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
|
|||
* @param config Conv2DConfig configuration
|
||||
* @return result of conv2d op
|
||||
*/
|
||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) {
|
||||
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
validateFloatingPoint("conv2d", "input", layerInput);
|
||||
validateFloatingPoint("conv2d", "weights", weights);
|
||||
validateFloatingPoint("conv2d", "bias", bias);
|
||||
|
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
|
|||
arr[1] = weights;
|
||||
if (bias != null)
|
||||
arr[2] = bias;
|
||||
return conv2d(arr, config);
|
||||
return conv2d(name, arr, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution operation with optional bias
|
||||
*
|
||||
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
|
||||
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
||||
* @param config Conv2DConfig configuration
|
||||
* @return result of convolution 2d operation
|
||||
* See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
|
||||
public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||
return conv2d(null, inputs, config);
|
||||
}
|
||||
|
||||
|
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
|
|||
* @param config Conv2DConfig configuration
|
||||
* @return result of convolution 2d operation
|
||||
*/
|
||||
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
|
||||
public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||
for(SDVariable v : inputs)
|
||||
validateNumerical("conv2d", v);
|
||||
SDVariable ret = f().conv2d(inputs, config);
|
||||
|
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* Convolution 3D operation without bias
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels])
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
||||
* @param conv3DConfig the configuration
|
||||
* @return Conv3d output variable
|
||||
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
||||
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||
return conv3d(null, input, weights, null, conv3DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||
return conv3d(name, input, weights, null, conv3DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
|
||||
*/
|
||||
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||
return conv3d(null, input, weights, bias, conv3DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convolution 3D operation with optional bias
|
||||
*
|
||||
|
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
|
|||
* @param conv3DConfig the configuration
|
||||
* @return Conv3d output variable
|
||||
*/
|
||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
||||
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||
validateFloatingPoint("conv3d", "input", input);
|
||||
validateFloatingPoint("conv3d", "weights", weights);
|
||||
validateFloatingPoint("conv3d", "bias", bias);
|
||||
|
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* Convolution 3D operation with optional bias
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels])
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
|
||||
* @param conv3DConfig the configuration
|
||||
* @return Conv3d output variable
|
||||
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
||||
return conv3d(null, input, weights, bias, conv3DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convolution 3D operation without bias
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels])
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
||||
* @param conv3DConfig the configuration
|
||||
* @return Conv3d output variable
|
||||
*/
|
||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
||||
return conv3d(name, input, weights, null, conv3DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D deconvolution operation without bias
|
||||
*
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
||||
* @param deconv2DConfig DeConv2DConfig configuration
|
||||
* @return result of deconv2d op
|
||||
*/
|
||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
|
||||
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
return deconv2d(layerInput, weights, null, deconv2DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
|
||||
*/
|
||||
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D deconvolution operation with optional bias
|
||||
*
|
||||
* @param name name of the operation in SameDiff
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
||||
|
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
|
|||
* @param deconv2DConfig DeConv2DConfig configuration
|
||||
* @return result of deconv2d op
|
||||
*/
|
||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) {
|
||||
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
validateFloatingPoint("deconv2d", "input", layerInput);
|
||||
validateFloatingPoint("deconv2d", "weights", weights);
|
||||
validateFloatingPoint("deconv2d", "bias", bias);
|
||||
|
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
|
|||
arr[1] = weights;
|
||||
if (bias != null)
|
||||
arr[2] = bias;
|
||||
return deconv2d(arr, deconv2DConfig);
|
||||
return deconv2d(name, arr, deconv2DConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D deconvolution operation with or without optional bias
|
||||
*
|
||||
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
|
||||
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
|
||||
* @param deconv2DConfig the configuration
|
||||
* @return result of deconv2d op
|
||||
* See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
|
||||
*/
|
||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
||||
public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
return deconv2d(null, inputs, deconv2DConfig);
|
||||
}
|
||||
|
||||
|
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
|
|||
* @param deconv2DConfig the configuration
|
||||
* @return result of deconv2d op
|
||||
*/
|
||||
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
||||
public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||
for(SDVariable v : inputs)
|
||||
validateNumerical("deconv2d", v);
|
||||
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||
return deconv3d(input, weights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||
return deconv3d(name, input, weights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
|
||||
*/
|
||||
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||
return deconv3d(null, input, weights, bias, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* 3D CNN deconvolution operation with or without optional bias
|
||||
*
|
||||
|
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
|
|||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
||||
* @param config Configuration
|
||||
*/
|
||||
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
||||
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||
validateFloatingPoint("conv3d", input);
|
||||
validateFloatingPoint("conv3d", weights);
|
||||
validateFloatingPoint("conv3d", bias);
|
||||
|
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* 3D CNN deconvolution operation with or without optional bias
|
||||
*
|
||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
||||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
||||
* @param config Configuration
|
||||
* See {@link #depthToSpace(String, SDVariable, int, String)}.
|
||||
*/
|
||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
||||
return deconv3d(null, input, weights, bias, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* 3D CNN deconvolution operation with no bias
|
||||
*
|
||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
||||
* @param config Configuration
|
||||
*/
|
||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
|
||||
return deconv3d(input, weights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convolution 2d layer batch to space operation on 4d input.<br>
|
||||
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
|
||||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
|
||||
* = [mb, 2, 4, 4]
|
||||
*
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return Output variable
|
||||
*/
|
||||
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
|
||||
public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||
return depthToSpace(null, x, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
|
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
|
|||
* @return Output variable
|
||||
* @see #depthToSpace(String, SDVariable, int, String)
|
||||
*/
|
||||
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) {
|
||||
public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Depth-wise 2D convolution operation without bias
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
||||
* @param config Conv2DConfig configuration
|
||||
* @return result of conv2d op
|
||||
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) {
|
||||
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||
return depthWiseConv2d(layerInput, depthWeights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Depth-wise 2D convolution operation with optional bias
|
||||
*
|
||||
* @param name name of the operation in SameDiff
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
||||
|
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
|
|||
* @param config Conv2DConfig configuration
|
||||
* @return result of depthwise conv2d op
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) {
|
||||
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
|
||||
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
|
||||
validateFloatingPoint("depthwiseConv2d", "bias", bias);
|
||||
|
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
|
|||
arr[1] = depthWeights;
|
||||
if (bias != null)
|
||||
arr[2] = bias;
|
||||
return depthWiseConv2d(arr, config);
|
||||
return depthWiseConv2d(name, arr, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Depth-wise convolution 2D operation.
|
||||
*
|
||||
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
|
||||
* or 3 elements (layerInput, depthWeights, bias) as described in
|
||||
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
||||
* @param depthConv2DConfig the configuration
|
||||
* @return result of depthwise conv2d op
|
||||
* See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
||||
public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||
return depthWiseConv2d(null, inputs, depthConv2DConfig);
|
||||
}
|
||||
|
||||
|
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
|
|||
* @param depthConv2DConfig the configuration
|
||||
* @return result of depthwise conv2d op
|
||||
*/
|
||||
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
||||
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||
for(SDVariable v : inputs)
|
||||
validateFloatingPoint("depthWiseConv2d", v);
|
||||
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
|
||||
|
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* TODO doc string
|
||||
*
|
||||
* @param df
|
||||
* @param weights
|
||||
* @param strides
|
||||
* @param rates
|
||||
* @param isSameMode
|
||||
* @return
|
||||
* See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
|
||||
*/
|
||||
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
|
||||
int[] rates, boolean isSameMode) {
|
||||
public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||
return dilation2D(null, df, weights, strides, rates, isSameMode);
|
||||
}
|
||||
|
||||
|
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
|
|||
* @param isSameMode
|
||||
* @return
|
||||
*/
|
||||
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
|
||||
int[] rates, boolean isSameMode) {
|
||||
public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
|
|||
* @param sameMode If true: use same mode padding. If false
|
||||
* @return
|
||||
*/
|
||||
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||
public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
|
||||
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
||||
*
|
||||
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
|
||||
* @param config Convolution configuration for the im2col operation
|
||||
* @return Im2Col output variable
|
||||
* See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable im2Col(SDVariable in, Conv2DConfig config) {
|
||||
public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||
return im2Col(null, in, config);
|
||||
}
|
||||
|
||||
|
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
|
|||
* @param config Convolution configuration for the im2col operation
|
||||
* @return Im2Col output variable
|
||||
*/
|
||||
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) {
|
||||
public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||
SDVariable ret = f().im2Col(in, config);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 2D convolution layer operation - local response normalization
|
||||
*
|
||||
* @param inputs the inputs to lrn
|
||||
* @param lrnConfig the configuration
|
||||
* @return
|
||||
* See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
|
||||
*/
|
||||
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
|
||||
public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||
return localResponseNormalization(null, inputs, lrnConfig);
|
||||
}
|
||||
|
||||
|
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
|
|||
* @param lrnConfig the configuration
|
||||
* @return
|
||||
*/
|
||||
public SDVariable localResponseNormalization(String name, SDVariable input,
|
||||
LocalResponseNormalizationConfig lrnConfig) {
|
||||
public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
|
||||
@NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||
validateFloatingPoint("local response normalization", input);
|
||||
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
|
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
|
|||
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - max pooling 2d
|
||||
*
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param pooling2DConfig the configuration
|
||||
* @return Result after applying max pooling on the input
|
||||
* See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||
*/
|
||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||
return maxPooling2d(null, input, pooling2DConfig);
|
||||
}
|
||||
|
||||
|
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
|
|||
* @param pooling2DConfig the configuration
|
||||
* @return Result after applying max pooling on the input
|
||||
*/
|
||||
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||
public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||
validateNumerical("maxPooling2d", input);
|
||||
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* 3D convolution layer operation - max pooling 3d operation.
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels])
|
||||
* @param pooling3DConfig the configuration
|
||||
* @return Result after applying max pooling on the input
|
||||
* See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||
*/
|
||||
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
||||
public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||
return maxPooling3d(null, input, pooling3DConfig);
|
||||
}
|
||||
|
||||
|
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
|
|||
* @param pooling3DConfig the configuration
|
||||
* @return Result after applying max pooling on the input
|
||||
*/
|
||||
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
||||
public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||
validateNumerical("maxPooling3d", input);
|
||||
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
|
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
|
|||
|
||||
|
||||
/**
|
||||
* Separable 2D convolution operation without bias
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
|
||||
* May be null
|
||||
* @param config Conv2DConfig configuration
|
||||
* @return result of separable convolution 2d operation
|
||||
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
||||
Conv2DConfig config) {
|
||||
public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||
@NonNull Conv2DConfig config) {
|
||||
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||
*/
|
||||
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||
@NonNull Conv2DConfig config) {
|
||||
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Separable 2D convolution operation with optional bias
|
||||
*
|
||||
|
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
|
|||
* @param config Conv2DConfig configuration
|
||||
* @return result of separable convolution 2d operation
|
||||
*/
|
||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
||||
SDVariable bias, Conv2DConfig config) {
|
||||
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||
validateFloatingPoint("separableConv2d", "input", layerInput);
|
||||
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
|
||||
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
|
||||
|
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
|
|||
arr[2] = pointWeights;
|
||||
if (bias != null)
|
||||
arr[3] = bias;
|
||||
return sconv2d(arr, config);
|
||||
return sconv2d(name, arr, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Separable 2D convolution operation with/without optional bias
|
||||
*
|
||||
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
|
||||
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
||||
* @param conv2DConfig the configuration
|
||||
* @return result of separable convolution 2d operation
|
||||
* See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
|
||||
*/
|
||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||
public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||
return sconv2d(null, inputs, conv2DConfig);
|
||||
}
|
||||
|
||||
|
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
|
|||
* @param conv2DConfig the configuration
|
||||
* @return result of separable convolution 2d operation
|
||||
*/
|
||||
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||
public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||
for(SDVariable v : inputs)
|
||||
validateFloatingPoint("sconv2d", v);
|
||||
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
|
||||
|
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||
*/
|
||||
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) {
|
||||
public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||
return spaceToBatch(null, x, blocks, padding);
|
||||
}
|
||||
|
||||
|
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
|
|||
* @return Output variable
|
||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||
*/
|
||||
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) {
|
||||
public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||
SDVariable ret = f().spaceToBatch(x, blocks, padding);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* @see #spaceToDepth(String, SDVariable, int, String)
|
||||
*/
|
||||
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) {
|
||||
public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||
return spaceToDepth(null, x, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
|
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
|
|||
* @return Output variable
|
||||
* @see #depthToSpace(String, SDVariable, int, String)
|
||||
*/
|
||||
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) {
|
||||
public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
||||
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||
* scale is used for both height and width dimensions.
|
||||
*
|
||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
||||
* @param scale Scale to upsample in both H and W dimensions
|
||||
* @return Upsampled input
|
||||
* @param scale The scale for both height and width dimensions.
|
||||
*/
|
||||
public SDVariable upsampling2d(SDVariable input, int scale) {
|
||||
public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
|
||||
return upsampling2d(null, input, true, scale, scale);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||
* scale is used for both height and width dimensions.
|
||||
*
|
||||
* @param scale The scale for both height and width dimensions.
|
||||
*/
|
||||
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
|
||||
return upsampling2d(name, input, true, scale, scale);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
|
||||
*/
|
||||
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - Upsampling 2d
|
||||
*
|
||||
|
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
|
|||
* @param scaleW Scale to upsample in width dimension
|
||||
* @return Upsampled input
|
||||
*/
|
||||
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||
public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
||||
*
|
||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
||||
* @param scale Scale to upsample in both H and W dimensions
|
||||
* @return Upsampled input
|
||||
*/
|
||||
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
|
||||
return upsampling2d(name, input, true, scale, scale);
|
||||
}
|
||||
|
||||
/**
|
||||
* 2D Convolution layer operation - Upsampling 2d
|
||||
*
|
||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
||||
* or NHWC format (shape [minibatch, height, width, channels])
|
||||
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
|
||||
* @param scaleH Scale to upsample in height dimension
|
||||
* @param scaleW Scale to upsample in width dimension
|
||||
* @return Upsampled input
|
||||
*/
|
||||
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.graph.DataType;
|
||||
import org.nd4j.graph.DType;
|
||||
import org.nd4j.graph.FlatArray;
|
||||
import org.nd4j.graph.FlatNode;
|
||||
import org.nd4j.graph.FlatProperties;
|
||||
|
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
|
|||
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
||||
switch (type) {
|
||||
case FLOAT:
|
||||
return DataType.FLOAT;
|
||||
return DType.FLOAT;
|
||||
case DOUBLE:
|
||||
return DataType.DOUBLE;
|
||||
return DType.DOUBLE;
|
||||
case HALF:
|
||||
return DataType.HALF;
|
||||
return DType.HALF;
|
||||
case INT:
|
||||
return DataType.INT32;
|
||||
return DType.INT32;
|
||||
case LONG:
|
||||
return DataType.INT64;
|
||||
return DType.INT64;
|
||||
case BOOL:
|
||||
return DataType.BOOL;
|
||||
return DType.BOOL;
|
||||
case SHORT:
|
||||
return DataType.INT16;
|
||||
return DType.INT16;
|
||||
case BYTE:
|
||||
return DataType.INT8;
|
||||
return DType.INT8;
|
||||
case UBYTE:
|
||||
return DataType.UINT8;
|
||||
return DType.UINT8;
|
||||
case UTF8:
|
||||
return DataType.UTF8;
|
||||
return DType.UTF8;
|
||||
case UINT16:
|
||||
return DataType.UINT16;
|
||||
return DType.UINT16;
|
||||
case UINT32:
|
||||
return DataType.UINT32;
|
||||
return DType.UINT32;
|
||||
case UINT64:
|
||||
return DataType.UINT64;
|
||||
return DType.UINT64;
|
||||
case BFLOAT16:
|
||||
return DataType.BFLOAT16;
|
||||
return DType.BFLOAT16;
|
||||
default:
|
||||
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
||||
}
|
||||
|
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
|
|||
* This method converts enums for DataType
|
||||
*/
|
||||
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
||||
if (val == DataType.FLOAT) {
|
||||
if (val == DType.FLOAT) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||
} else if (val == DataType.DOUBLE) {
|
||||
} else if (val == DType.DOUBLE) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
||||
} else if (val == DataType.HALF) {
|
||||
} else if (val == DType.HALF) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
||||
} else if (val == DataType.INT32) {
|
||||
} else if (val == DType.INT32) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.INT;
|
||||
} else if (val == DataType.INT64) {
|
||||
} else if (val == DType.INT64) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
||||
} else if (val == DataType.INT8) {
|
||||
} else if (val == DType.INT8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
||||
} else if (val == DataType.BOOL) {
|
||||
} else if (val == DType.BOOL) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||
} else if (val == DataType.UINT8) {
|
||||
} else if (val == DType.UINT8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
||||
} else if (val == DataType.INT16) {
|
||||
} else if (val == DType.INT16) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
||||
} else if (val == DataType.UTF8) {
|
||||
} else if (val == DType.UTF8) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
||||
} else if (val == DataType.UINT16) {
|
||||
} else if (val == DType.UINT16) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
||||
} else if (val == DataType.UINT32) {
|
||||
} else if (val == DType.UINT32) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
||||
} else if (val == DataType.UINT64) {
|
||||
} else if (val == DType.UINT64) {
|
||||
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
||||
} else if (val == DataType.BFLOAT16){
|
||||
} else if (val == DType.BFLOAT16){
|
||||
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
||||
} else {
|
||||
throw new RuntimeException("Unknown datatype: " + val);
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
package org.nd4j.graph;
|
||||
|
||||
public final class DataType {
|
||||
private DataType() { }
|
||||
public final class DType {
|
||||
private DType() { }
|
||||
public static final byte INHERIT = 0;
|
||||
public static final byte BOOL = 1;
|
||||
public static final byte FLOAT8 = 2;
|
|
@ -353,6 +353,10 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
|
||||
|
|
|
@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setShape(long[] shape) {
|
||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStride(long[] stride) {
|
||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setShapeAndStride(int[] shape, int[] stride) {
|
||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
|
||||
|
@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return scalar.getDouble(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number entropyNumber() {
|
||||
return entropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns non-normalized Shannon entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number shannonEntropyNumber() {
|
||||
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns log entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number logEntropyNumber() {
|
||||
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||
|
@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return size(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
||||
Nd4j.getCompressor().autoDecompress(this);
|
||||
int n = shape.length;
|
||||
|
||||
// FIXME: shapeInfo should be used here
|
||||
if (shape.length < 1)
|
||||
return create(Nd4j.createBufferDetached(shape));
|
||||
if (offsets.length != n)
|
||||
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
|
||||
if (stride.length != n)
|
||||
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
|
||||
|
||||
if (Shape.contentEquals(shape, shapeOf())) {
|
||||
if (ArrayUtil.isZero(offsets)) {
|
||||
return this;
|
||||
} else {
|
||||
throw new IllegalArgumentException("Invalid subArray offsets");
|
||||
}
|
||||
}
|
||||
|
||||
long[] dotProductOffsets = offsets;
|
||||
int[] dotProductStride = stride;
|
||||
|
||||
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
|
||||
if (offset >= data().length())
|
||||
offset = ArrayUtil.sumLong(offsets);
|
||||
|
||||
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering());
|
||||
}
|
||||
|
||||
protected INDArray create(DataBuffer buffer) {
|
||||
return Nd4j.create(buffer);
|
||||
}
|
||||
|
@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return Nd4j.getExecutioner().exec(new AMin(this, dimension));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sum along the specified dimension(s) of this ndarray
|
||||
*
|
||||
* @param dimension the dimension to getScalar the sum along
|
||||
* @return the sum along the specified dimension of this ndarray
|
||||
*/
|
||||
@Override
|
||||
public INDArray sum(int... dimension) {
|
||||
validateNumericalArray("sum", true);
|
||||
return Nd4j.getExecutioner().exec(new Sum(this, dimension));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sum along the last dimension of this ndarray
|
||||
*
|
||||
* @param dimension the dimension to getScalar the sum along
|
||||
* @return the sum along the specified dimension of this ndarray
|
||||
*/
|
||||
@Override
|
||||
public INDArray sum(boolean keepDim, int... dimension) {
|
||||
validateNumericalArray("sum", true);
|
||||
return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns entropy along dimension
|
||||
* @param dimension
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray entropy(int... dimension) {
|
||||
validateNumericalArray("entropy", false);
|
||||
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns non-normalized Shannon entropy along dimension
|
||||
* @param dimension
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray shannonEntropy(int... dimension) {
|
||||
validateNumericalArray("shannonEntropy", false);
|
||||
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns log entropy along dimension
|
||||
* @param dimension
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray logEntropy(int... dimension) {
|
||||
validateNumericalArray("logEntropy", false);
|
||||
|
|
|
@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStride(long... stride) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setShape(long... shape) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray putScalar(long row, long col, double value) {
|
||||
return null;
|
||||
|
@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
|
||||
@Override
|
||||
public void setShapeAndStride(int[] shape, int[] stride) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setOrder(char order) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
|||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number entropyNumber() {
|
||||
return entropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns non-normalized Shannon entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number shannonEntropyNumber() {
|
||||
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns log entropy value for this INDArray
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Number logEntropyNumber() {
|
||||
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns entropy along dimension
|
||||
* @param dimension
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray entropy(int... dimension) {
|
||||
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns non-normalized Shannon entropy along dimension
|
||||
* @param dimension
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public INDArray shannonEntropy(int... dimension) {
|
||||
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
||||
|
|
|
@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
|||
return extendedFlags;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Returns the underlying indices of the element of the given index
|
||||
* such as there really are in the original ndarray
|
||||
|
@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStride(long... stride) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setShape(long... shape) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns true if this INDArray is special case: no-value INDArray
|
||||
*
|
||||
|
|
|
@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
|
|||
return shapeInformation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
//TODO use op
|
||||
|
|
|
@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
|
||||
/**
|
||||
* Returns entropy value for this INDArray
|
||||
* @return
|
||||
* @return entropy value
|
||||
*/
|
||||
Number entropyNumber();
|
||||
|
||||
/**
|
||||
* Returns non-normalized Shannon entropy value for this INDArray
|
||||
* @return
|
||||
* @return non-normalized Shannon entropy
|
||||
*/
|
||||
Number shannonEntropyNumber();
|
||||
|
||||
/**
|
||||
* Returns log entropy value for this INDArray
|
||||
* @return
|
||||
* @return log entropy value
|
||||
*/
|
||||
Number logEntropyNumber();
|
||||
|
||||
/**
|
||||
* Returns entropy value for this INDArray along specified dimension(s)
|
||||
* @return
|
||||
* @param dimension specified dimension(s)
|
||||
* @return entropy value
|
||||
*/
|
||||
INDArray entropy(int... dimension);
|
||||
|
||||
/**
|
||||
* Returns entropy value for this INDArray along specified dimension(s)
|
||||
* @return
|
||||
* Returns Shannon entropy value for this INDArray along specified dimension(s)
|
||||
* @param dimension specified dimension(s)
|
||||
* @return Shannon entropy
|
||||
*/
|
||||
INDArray shannonEntropy(int... dimension);
|
||||
|
||||
/**
|
||||
* Returns entropy value for this INDArray along specified dimension(s)
|
||||
* @return
|
||||
* Returns log entropy value for this INDArray along specified dimension(s)
|
||||
* @param dimension specified dimension(s)
|
||||
* @return log entropy value
|
||||
*/
|
||||
INDArray logEntropy(int... dimension);
|
||||
|
||||
|
||||
/**
|
||||
* stride setter
|
||||
* @param stride
|
||||
* @deprecated, use {@link #reshape(int...) }
|
||||
*/
|
||||
@Deprecated
|
||||
void setStride(long... stride);
|
||||
|
||||
/**
|
||||
* Shape setter
|
||||
* @param shape
|
||||
* @deprecated, use {@link #reshape(int...) }
|
||||
*/
|
||||
|
||||
|
||||
@Deprecated
|
||||
void setShape(long... shape);
|
||||
|
||||
/**
|
||||
* Shape and stride setter
|
||||
* @param shape
|
||||
* @param stride
|
||||
* @param shape new value for shape
|
||||
* @param stride new value for stride
|
||||
*/
|
||||
void setShapeAndStride(int[] shape, int[] stride);
|
||||
|
||||
|
@ -1919,15 +1903,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
* @param order the ordering to set
|
||||
*/
|
||||
void setOrder(char order);
|
||||
|
||||
/**
|
||||
* @param offsets
|
||||
* @param shape
|
||||
* @param stride
|
||||
* @return
|
||||
*/
|
||||
INDArray subArray(long[] offsets, int[] shape, int[] stride);
|
||||
|
||||
|
||||
/**
|
||||
* Returns the elements at the specified indices
|
||||
*
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
||||
if (arrayInput != null) {
|
||||
addInputArgument(arrayInput);
|
||||
}
|
||||
if (arrayOutput != null) {
|
||||
addOutputArgument(arrayOutput);
|
||||
}
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||
super(sameDiff, new SDVariable[]{input});
|
||||
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
super(new INDArray[]{input}, wrapOrNull(output));
|
||||
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||
|
||||
this.sameDiff = sameDiff;
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
|
|||
protected Conv1DConfig config;
|
||||
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv1D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
INDArray[] inputArrays, INDArray[] outputs,
|
||||
Conv1DConfig config) {
|
||||
super(null, inputArrays, outputs);
|
||||
this.sameDiff = sameDiff;
|
||||
super(sameDiff, inputFunctions);
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
private void initConfig(Conv1DConfig config){
|
||||
this.config = config;
|
||||
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
||||
addArgs();
|
||||
sameDiff.putOpForId(this.getOwnName(), this);
|
||||
sameDiff.addArgsFor(inputFunctions, this);
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
|
|||
protected Conv2DConfig config;
|
||||
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
INDArray[] inputArrays, INDArray[] outputs,
|
||||
Conv2DConfig config) {
|
||||
super(null, inputArrays, outputs);
|
||||
this.sameDiff = sameDiff;
|
||||
super(sameDiff, inputFunctions);
|
||||
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
protected void initConfig(Conv2DConfig config){
|
||||
this.config = config;
|
||||
|
||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||
INVALID_CONFIGURATION,
|
||||
config.getSH(), config.getPH(), config.getDW());
|
||||
INVALID_CONFIGURATION,
|
||||
config.getSH(), config.getPH(), config.getDW());
|
||||
addArgs();
|
||||
if(sameDiff != null) {
|
||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
||||
sameDiff.addArgsFor(inputFunctions, this);
|
||||
}
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
|
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
|
|||
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
|
||||
.sameDiff(sameDiff)
|
||||
.config(config)
|
||||
.outputs(outputArguments())
|
||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||
.build();
|
||||
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());
|
||||
|
|
|
@ -37,8 +37,8 @@ import java.util.List;
|
|||
public class Conv2DDerivative extends Conv2D {
|
||||
|
||||
@Builder(builderMethodName = "derivativeBuilder")
|
||||
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) {
|
||||
super(sameDiff, inputFunctions, inputArrays, outputs, config);
|
||||
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
|
||||
super(sameDiff, inputFunctions, config);
|
||||
}
|
||||
|
||||
public Conv2DDerivative() {}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
|
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
|
|||
public Conv3D() {
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs,
|
||||
Conv3DConfig conv3DConfig) {
|
||||
super(null, sameDiff, inputFunctions, false);
|
||||
setSameDiff(sameDiff);
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
|
||||
super(sameDiff, inputFunctions);
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
if (inputs != null)
|
||||
addInputArgument(inputs);
|
||||
if (outputs != null)
|
||||
addOutputArgument(outputs);
|
||||
this.config = conv3DConfig;
|
||||
public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
|
||||
super(inputs, outputs);
|
||||
initConfig(config);
|
||||
}
|
||||
|
||||
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
private void initConfig(Conv3DConfig config){
|
||||
this.config = config;
|
||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||
INVALID_CONFIGURATION,
|
||||
config.getSW(), config.getPH(), config.getDW());
|
||||
INVALID_CONFIGURATION,
|
||||
config.getSW(), config.getPH(), config.getDW());
|
||||
addArgs();
|
||||
|
||||
|
||||
//for (val arg: iArgs())
|
||||
// System.out.println(getIArgument(arg));
|
||||
}
|
||||
|
||||
|
||||
|
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
|
|||
inputs.add(f1.get(0));
|
||||
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
|
||||
.conv3DConfig(config)
|
||||
.inputFunctions(args())
|
||||
.outputs(outputArguments())
|
||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||
.sameDiff(sameDiff)
|
||||
.build();
|
||||
|
|
|
@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
|
|||
public Conv3DDerivative() {}
|
||||
|
||||
@Builder(builderMethodName = "derivativeBuilder")
|
||||
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) {
|
||||
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
|
||||
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
|
||||
super(sameDiff, inputFunctions, conv3DConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
|
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
|
||||
protected DeConv2DConfig config;
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public DeConv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputs,
|
||||
INDArray[] inputArrays, INDArray[] outputs,
|
||||
DeConv2DConfig config) {
|
||||
super(null, inputArrays, outputs);
|
||||
this.sameDiff = sameDiff;
|
||||
super(sameDiff, inputs);
|
||||
this.config = config;
|
||||
|
||||
if (inputArrays != null) {
|
||||
addInputArgument(inputArrays);
|
||||
}
|
||||
if (outputs != null) {
|
||||
addOutputArgument(outputs);
|
||||
}
|
||||
|
||||
addArgs();
|
||||
sameDiff.putOpForId(this.getOwnName(), this);
|
||||
sameDiff.addArgsFor(inputs, this);
|
||||
}
|
||||
|
||||
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
|
|||
public DeConv2DDerivative() {}
|
||||
|
||||
@Builder(builderMethodName = "derivativeBuilder")
|
||||
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) {
|
||||
super(sameDiff, inputs, inputArrays, outputs, config);
|
||||
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
|
||||
super(sameDiff, inputs, config);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
|
|||
|
||||
protected DeConv2DConfig config;
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public DeConv2DTF(SameDiff sameDiff,
|
||||
SDVariable[] inputs,
|
||||
INDArray[] inputArrays, INDArray[] outputs,
|
||||
DeConv2DConfig config) {
|
||||
super(null, inputArrays, outputs);
|
||||
this.sameDiff = sameDiff;
|
||||
super(sameDiff, inputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
|
||||
if (inputArrays != null) {
|
||||
addInputArgument(inputArrays);
|
||||
}
|
||||
if (outputs != null) {
|
||||
addOutputArgument(outputs);
|
||||
}
|
||||
|
||||
addArgs();
|
||||
sameDiff.putOpForId(this.getOwnName(), this);
|
||||
sameDiff.addArgsFor(inputs, this);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
|
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
|
|||
|
||||
protected DeConv3DConfig config;
|
||||
|
||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||
super(sameDiff, toArr(input, weights, bias));
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
||||
if(bias != null){
|
||||
return new SDVariable[]{input, weights, bias};
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
|
||||
protected Conv2DConfig config;
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public DepthwiseConv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
INDArray[] inputArrays, INDArray[] outputs,
|
||||
Conv2DConfig config) {
|
||||
super(null, inputArrays, outputs);
|
||||
this.sameDiff = sameDiff;
|
||||
super(sameDiff, inputFunctions);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
||||
sameDiff.addArgsFor(inputFunctions, this);
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public DepthwiseConv2D() {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
|||
protected LocalResponseNormalizationConfig config;
|
||||
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions,
|
||||
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
|
||||
LocalResponseNormalizationConfig config) {
|
||||
super(null,sameDiff, inputFunctions, inPlace);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
|
||||
super(new INDArray[]{input}, wrapOrNull(output));
|
||||
|
||||
this.config = config;
|
||||
if(inputs != null) {
|
||||
addInputArgument(inputs);
|
||||
}
|
||||
if(outputs!= null) {
|
||||
addOutputArgument(outputs);
|
||||
}
|
||||
addArgs();
|
||||
}
|
||||
|
||||
|
|
|
@ -33,8 +33,8 @@ import java.util.List;
|
|||
@Slf4j
|
||||
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
|
||||
@Builder(builderMethodName = "derivativeBuilder")
|
||||
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) {
|
||||
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config);
|
||||
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
|
||||
super(sameDiff, inputFunctions, inPlace, config);
|
||||
}
|
||||
|
||||
public LocalResponseNormalizationDerivative() {}
|
||||
|
|
|
@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
public MaxPooling2D() {
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
@SuppressWarnings("Used in lombok")
|
||||
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
||||
public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
||||
if (arrayInput != null) {
|
||||
addInputArgument(arrayInput);
|
||||
}
|
||||
|
||||
if (arrayOutput != null) {
|
||||
addOutputArgument(arrayOutput);
|
||||
}
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
|
||||
this.config = config;
|
||||
this.sameDiff = sameDiff;
|
||||
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output});
|
||||
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
|
||||
this.config = config;
|
||||
|
|
|
@ -16,8 +16,14 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
|
|||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.*;
|
||||
|
||||
|
||||
/**
|
||||
* Pooling2D operation
|
||||
|
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
|
|||
|
||||
public Pooling2D() {}
|
||||
|
||||
@Builder(builderMethodName = "builder")
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
@SuppressWarnings("Used in lombok")
|
||||
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) {
|
||||
super(null,sameDiff, inputs, false);
|
||||
if(arrayInputs != null) {
|
||||
addInputArgument(arrayInputs);
|
||||
}
|
||||
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
|
||||
Pooling2DConfig config) {
|
||||
super(null, sameDiff, inputs, false);
|
||||
|
||||
if(arrayOutputs != null) {
|
||||
addOutputArgument(arrayOutputs);
|
||||
}
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
this.config = config;
|
||||
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
super(new INDArray[]{input}, wrapOrNull(output));
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -36,8 +37,12 @@ import java.util.List;
|
|||
@Slf4j
|
||||
public class Pooling2DDerivative extends Pooling2D {
|
||||
@Builder(builderMethodName = "derivativeBuilder")
|
||||
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) {
|
||||
super(sameDiff, inputs, arrayInputs, arrayOutputs, config);
|
||||
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
|
||||
super(sameDiff, inputs, config);
|
||||
}
|
||||
|
||||
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
|
||||
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public Pooling2DDerivative() {}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -39,9 +40,17 @@ import java.util.List;
|
|||
@Slf4j
|
||||
public class SConv2D extends Conv2D {
|
||||
|
||||
@Builder(builderMethodName = "sBuilder")
|
||||
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
||||
@Builder(builderMethodName = "sameDiffSBuilder")
|
||||
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||
super(sameDiff, inputFunctions, conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||
super(inputs, outputs, config);
|
||||
}
|
||||
|
||||
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public SConv2D() {}
|
||||
|
|
|
@ -38,8 +38,8 @@ import java.util.List;
|
|||
public class SConv2DDerivative extends SConv2D {
|
||||
|
||||
@Builder(builderMethodName = "sDerviativeBuilder")
|
||||
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
||||
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||
super(sameDiff, inputFunctions, conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2DDerivative() {}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class BitsHammingDistance extends DynamicCustomOp {
|
||||
|
||||
public BitsHammingDistance(){ }
|
||||
|
||||
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
|
||||
super(sd, new SDVariable[]{x, y});
|
||||
}
|
||||
|
||||
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
|
||||
super(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bits_hamming_distance";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
|
||||
return Collections.singletonList(DataType.LONG);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise AND operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseAnd extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseAnd(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseAnd() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_and";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_and";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise OR operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseOr extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseOr(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseOr() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_or";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_or";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise XOR operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseXor extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseXor(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseXor() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_xor";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_xor";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -235,24 +235,20 @@ public class Convolution {
|
|||
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
|
||||
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
|
||||
double extra, int virtualHeight, int virtualWidth, INDArray out) {
|
||||
Pooling2D pooling = Pooling2D.builder()
|
||||
.arrayInputs(new INDArray[]{img})
|
||||
.arrayOutputs(new INDArray[]{out})
|
||||
.config(Pooling2DConfig.builder()
|
||||
.dH(dh)
|
||||
.dW(dw)
|
||||
.extra(extra)
|
||||
.kH(kh)
|
||||
.kW(kw)
|
||||
.pH(ph)
|
||||
.pW(pw)
|
||||
.isSameMode(isSameMode)
|
||||
.sH(sy)
|
||||
.sW(sx)
|
||||
.type(type)
|
||||
.divisor(divisor)
|
||||
.build())
|
||||
.build();
|
||||
Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
|
||||
.dH(dh)
|
||||
.dW(dw)
|
||||
.extra(extra)
|
||||
.kH(kh)
|
||||
.kW(kw)
|
||||
.pH(ph)
|
||||
.pW(pw)
|
||||
.isSameMode(isSameMode)
|
||||
.sH(sy)
|
||||
.sW(sx)
|
||||
.type(type)
|
||||
.divisor(divisor)
|
||||
.build());
|
||||
Nd4j.getExecutioner().execAndReturn(pooling);
|
||||
return out;
|
||||
}
|
||||
|
|
|
@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex {
|
|||
return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shape and stride for
|
||||
* new axes based dimensions
|
||||
* @param arr the array to update
|
||||
* the shape/strides for
|
||||
* @param indexes the indexes to update based on
|
||||
*/
|
||||
public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) {
|
||||
int numNewAxes = NDArrayIndex.numNewAxis(indexes);
|
||||
if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) {
|
||||
List<Long> newShape = new ArrayList<>();
|
||||
List<Long> newStrides = new ArrayList<>();
|
||||
int currDimension = 0;
|
||||
for (int i = 0; i < indexes.length; i++) {
|
||||
if (indexes[i] instanceof NewAxis) {
|
||||
newShape.add(1L);
|
||||
newStrides.add(0L);
|
||||
} else {
|
||||
newShape.add(arr.size(currDimension));
|
||||
newStrides.add(arr.size(currDimension));
|
||||
currDimension++;
|
||||
}
|
||||
}
|
||||
|
||||
while (currDimension < arr.rank()) {
|
||||
newShape.add((long) currDimension);
|
||||
newStrides.add((long) currDimension);
|
||||
currDimension++;
|
||||
}
|
||||
|
||||
long[] newShapeArr = Longs.toArray(newShape);
|
||||
long[] newStrideArr = Longs.toArray(newStrides);
|
||||
|
||||
// FIXME: this is wrong, it breaks shapeInfo immutability
|
||||
arr.setShape(newShapeArr);
|
||||
arr.setStride(newStrideArr);
|
||||
|
||||
|
||||
} else {
|
||||
if (numNewAxes > 0) {
|
||||
long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape());
|
||||
long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride());
|
||||
arr.setShape(newShape);
|
||||
arr.setStride(newStrides);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Compute the offset given an array of offsets.
|
||||
* The offset is computed(for both fortran an d c ordering) as:
|
||||
|
|
|
@ -54,7 +54,7 @@ public class DeallocatorService {
|
|||
deallocatorThreads = new Thread[numThreads];
|
||||
queues = new ReferenceQueue[numThreads];
|
||||
for (int e = 0; e < numThreads; e++) {
|
||||
log.debug("Starting deallocator thread {}", e + 1);
|
||||
log.trace("Starting deallocator thread {}", e + 1);
|
||||
queues[e] = new ReferenceQueue<>();
|
||||
|
||||
int deviceId = e % numDevices;
|
||||
|
|
|
@ -1151,4 +1151,6 @@ public interface NativeOps {
|
|||
|
||||
int lastErrorCode();
|
||||
String lastErrorMessage();
|
||||
|
||||
boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@ public class NativeOpsHolder {
|
|||
}
|
||||
//deviceNativeOps.setOmpNumThreads(4);
|
||||
|
||||
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
|
||||
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
|
||||
} catch (Exception | Error e) {
|
||||
throw new RuntimeException(
|
||||
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
|
||||
|
|
|
@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
|
|||
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
|
||||
setMaxThreads(numThreads);
|
||||
}
|
||||
log.info("Number of threads used for BLAS: {}", getMaxThreads());
|
||||
|
||||
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
|
|||
throw new RuntimeException("No CUDA devices were found in system");
|
||||
}
|
||||
Loader.load(org.bytedeco.cuda.global.cublas.class);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
|||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
/*
|
||||
val major = new int[1];
|
||||
val minor = new int[1];
|
||||
val build = new int[1];
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
|
||||
|
||||
val pew = new int[100];
|
||||
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
|
||||
|
||||
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
*/
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
|||
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
||||
import org.nd4j.jita.conf.CudaEnvironment;
|
||||
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.factory.DataTypeValidation;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.jcublas.CublasPointer;
|
||||
|
@ -113,8 +115,13 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
@Override
|
||||
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
|
||||
INDArray B, int ldb, float beta, INDArray C, int ldc) {
|
||||
//A = Shape.toOffsetZero(A);
|
||||
//B = Shape.toOffsetZero(B);
|
||||
/*
|
||||
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||
val handle = ctx.getCublasHandle();
|
||||
synchronized (handle) {
|
||||
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
|
||||
}
|
||||
*/
|
||||
|
||||
Nd4j.getExecutioner().push();
|
||||
|
||||
|
@ -141,6 +148,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
}
|
||||
|
||||
allocator.registerAction(ctx, C, A, B);
|
||||
|
||||
OpExecutionerUtil.checkForAny(C);
|
||||
}
|
||||
|
||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
|||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public Environment(Pointer p) { super(p); }
|
||||
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||
|
||||
public static native Environment getInstance();
|
||||
|
||||
public native @Cast("bool") boolean isVerbose();
|
||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
|||
public native void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
|||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public Environment(Pointer p) { super(p); }
|
||||
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||
|
||||
public static native Environment getInstance();
|
||||
|
||||
public native @Cast("bool") boolean isVerbose();
|
||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
|||
public native void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise AND
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_and)
|
||||
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_and(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_and position(long position) {
|
||||
return (bitwise_and)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_and() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise OR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_or)
|
||||
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_or(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_or position(long position) {
|
||||
return (bitwise_or)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_or() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise XOR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_xor)
|
||||
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_xor(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_xor position(long position) {
|
||||
return (bitwise_xor)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_xor() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation returns hamming distance based on bits
|
||||
*
|
||||
|
|
|
@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
.build();
|
||||
|
||||
INDArray input = Nd4j.create(inSize);
|
||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
||||
.arrayInput(input)
|
||||
.config(conf)
|
||||
.build();
|
||||
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||
|
||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||
|
||||
|
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
//Test backprop:
|
||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
||||
.arrayInputs(new INDArray[]{input, grad})
|
||||
.config(conf)
|
||||
.build();
|
||||
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
|
||||
|
||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||
assertEquals(1, outSizesBP.size());
|
||||
|
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
.build();
|
||||
|
||||
INDArray input = Nd4j.create(inSize);
|
||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
||||
.arrayInput(input)
|
||||
.config(conf)
|
||||
.build();
|
||||
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||
|
||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||
assertEquals(1, outSizes.size());
|
||||
|
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
INDArray grad = Nd4j.create(exp);
|
||||
|
||||
//Test backprop:
|
||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
||||
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
|
||||
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
|
||||
.config(conf)
|
||||
.build();
|
||||
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
|
||||
|
||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||
assertEquals(1, outSizesBP.size());
|
||||
|
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
.isSameMode(false)
|
||||
.build();
|
||||
|
||||
SDVariable out = sd.cnn().conv2d(vars, c);
|
||||
SDVariable out = sd.cnn().conv2d("conv", vars, c);
|
||||
out = sd.nn().tanh("out", out);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
|
|
Loading…
Reference in New Issue