commit
b7226bdd7a
|
@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder()
|
Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf);
|
||||||
.config(conf)
|
|
||||||
.arrayInputs(new INDArray[]{input, epsilon})
|
|
||||||
.arrayOutputs(new INDArray[]{gradAtInput})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Nd4j.exec(d);
|
Nd4j.exec(d);
|
||||||
return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput);
|
return new Pair<Gradient,INDArray>(new DefaultGradient(), gradAtInput);
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_BLASVERSIONHELPER_H
|
||||||
|
#define SAMEDIFF_BLASVERSIONHELPER_H
|
||||||
|
|
||||||
|
#include <dll.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
class ND4J_EXPORT BlasVersionHelper {
|
||||||
|
public:
|
||||||
|
int _blasMajorVersion = 0;
|
||||||
|
int _blasMinorVersion = 0;
|
||||||
|
int _blasPatchVersion = 0;
|
||||||
|
|
||||||
|
BlasVersionHelper();
|
||||||
|
~BlasVersionHelper() = default;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_BLASVERSIONHELPER_H
|
|
@ -253,20 +253,20 @@ if(CUDA_BLAS)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
||||||
|
|
||||||
if (NOT BUILD_TESTS)
|
if (NOT BUILD_TESTS)
|
||||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
else()
|
else()
|
||||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
||||||
|
|
||||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@
|
||||||
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#include "BlasVersionHelper.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -66,6 +66,13 @@ namespace nd4j {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
|
BlasVersionHelper ver;
|
||||||
|
_blasMajorVersion = ver._blasMajorVersion;
|
||||||
|
_blasMinorVersion = ver._blasMinorVersion;
|
||||||
|
_blasPatchVersion = ver._blasPatchVersion;
|
||||||
|
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
int devCnt = 0;
|
int devCnt = 0;
|
||||||
cudaGetDeviceCount(&devCnt);
|
cudaGetDeviceCount(&devCnt);
|
||||||
auto devProperties = new cudaDeviceProp[devCnt];
|
auto devProperties = new cudaDeviceProp[devCnt];
|
||||||
|
|
|
@ -56,6 +56,13 @@ namespace nd4j{
|
||||||
Environment();
|
Environment();
|
||||||
~Environment();
|
~Environment();
|
||||||
public:
|
public:
|
||||||
|
/**
|
||||||
|
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||||
|
*/
|
||||||
|
int _blasMajorVersion = 0;
|
||||||
|
int _blasMinorVersion = 0;
|
||||||
|
int _blasPatchVersion = 0;
|
||||||
|
|
||||||
static Environment* getInstance();
|
static Environment* getInstance();
|
||||||
|
|
||||||
bool isVerbose();
|
bool isVerbose();
|
||||||
|
|
|
@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
|
||||||
ND4J_EXPORT void setOmpMinThreads(int threads);
|
ND4J_EXPORT void setOmpMinThreads(int threads);
|
||||||
|
|
||||||
|
|
||||||
|
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
||||||
auto fName = builder.CreateString(*(var->getName()));
|
auto fName = builder.CreateString(*(var->getName()));
|
||||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||||
|
|
||||||
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
auto fv = CreateFlatVariable(builder, id, fName, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||||
|
|
||||||
variables_vector.push_back(fv);
|
variables_vector.push_back(fv);
|
||||||
arrays++;
|
arrays++;
|
||||||
|
|
|
@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isBlasVersionMatches(int major, int minor, int build) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "../BlasVersionHelper.h"
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
BlasVersionHelper::BlasVersionHelper() {
|
||||||
|
_blasMajorVersion = __CUDACC_VER_MAJOR__;
|
||||||
|
_blasMinorVersion = __CUDACC_VER_MINOR__;
|
||||||
|
_blasPatchVersion = __CUDACC_VER_BUILD__;
|
||||||
|
}
|
||||||
|
}
|
|
@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
|
||||||
delete ptr;
|
delete ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isBlasVersionMatches(int major, int minor, int build) {
|
||||||
|
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
|
||||||
|
|
||||||
|
if (!result) {
|
||||||
|
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
|
||||||
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
|
||||||
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ namespace nd4j {
|
||||||
public:
|
public:
|
||||||
static int asInt(DataType type);
|
static int asInt(DataType type);
|
||||||
static DataType fromInt(int dtype);
|
static DataType fromInt(int dtype);
|
||||||
static DataType fromFlatDataType(nd4j::graph::DataType dtype);
|
static DataType fromFlatDataType(nd4j::graph::DType dtype);
|
||||||
FORCEINLINE static std::string asString(DataType dataType);
|
FORCEINLINE static std::string asString(DataType dataType);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
||||||
return (DataType) val;
|
return (DataType) val;
|
||||||
}
|
}
|
||||||
|
|
||||||
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) {
|
DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) {
|
||||||
return (DataType) dtype;
|
return (DataType) dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) {
|
||||||
return EnumNamesByteOrder()[index];
|
return EnumNamesByteOrder()[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
enum DataType {
|
enum DType {
|
||||||
DataType_INHERIT = 0,
|
DType_INHERIT = 0,
|
||||||
DataType_BOOL = 1,
|
DType_BOOL = 1,
|
||||||
DataType_FLOAT8 = 2,
|
DType_FLOAT8 = 2,
|
||||||
DataType_HALF = 3,
|
DType_HALF = 3,
|
||||||
DataType_HALF2 = 4,
|
DType_HALF2 = 4,
|
||||||
DataType_FLOAT = 5,
|
DType_FLOAT = 5,
|
||||||
DataType_DOUBLE = 6,
|
DType_DOUBLE = 6,
|
||||||
DataType_INT8 = 7,
|
DType_INT8 = 7,
|
||||||
DataType_INT16 = 8,
|
DType_INT16 = 8,
|
||||||
DataType_INT32 = 9,
|
DType_INT32 = 9,
|
||||||
DataType_INT64 = 10,
|
DType_INT64 = 10,
|
||||||
DataType_UINT8 = 11,
|
DType_UINT8 = 11,
|
||||||
DataType_UINT16 = 12,
|
DType_UINT16 = 12,
|
||||||
DataType_UINT32 = 13,
|
DType_UINT32 = 13,
|
||||||
DataType_UINT64 = 14,
|
DType_UINT64 = 14,
|
||||||
DataType_QINT8 = 15,
|
DType_QINT8 = 15,
|
||||||
DataType_QINT16 = 16,
|
DType_QINT16 = 16,
|
||||||
DataType_BFLOAT16 = 17,
|
DType_BFLOAT16 = 17,
|
||||||
DataType_UTF8 = 50,
|
DType_UTF8 = 50,
|
||||||
DataType_MIN = DataType_INHERIT,
|
DType_MIN = DType_INHERIT,
|
||||||
DataType_MAX = DataType_UTF8
|
DType_MAX = DType_UTF8
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const DataType (&EnumValuesDataType())[19] {
|
inline const DType (&EnumValuesDType())[19] {
|
||||||
static const DataType values[] = {
|
static const DType values[] = {
|
||||||
DataType_INHERIT,
|
DType_INHERIT,
|
||||||
DataType_BOOL,
|
DType_BOOL,
|
||||||
DataType_FLOAT8,
|
DType_FLOAT8,
|
||||||
DataType_HALF,
|
DType_HALF,
|
||||||
DataType_HALF2,
|
DType_HALF2,
|
||||||
DataType_FLOAT,
|
DType_FLOAT,
|
||||||
DataType_DOUBLE,
|
DType_DOUBLE,
|
||||||
DataType_INT8,
|
DType_INT8,
|
||||||
DataType_INT16,
|
DType_INT16,
|
||||||
DataType_INT32,
|
DType_INT32,
|
||||||
DataType_INT64,
|
DType_INT64,
|
||||||
DataType_UINT8,
|
DType_UINT8,
|
||||||
DataType_UINT16,
|
DType_UINT16,
|
||||||
DataType_UINT32,
|
DType_UINT32,
|
||||||
DataType_UINT64,
|
DType_UINT64,
|
||||||
DataType_QINT8,
|
DType_QINT8,
|
||||||
DataType_QINT16,
|
DType_QINT16,
|
||||||
DataType_BFLOAT16,
|
DType_BFLOAT16,
|
||||||
DataType_UTF8
|
DType_UTF8
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char * const *EnumNamesDataType() {
|
inline const char * const *EnumNamesDType() {
|
||||||
static const char * const names[] = {
|
static const char * const names[] = {
|
||||||
"INHERIT",
|
"INHERIT",
|
||||||
"BOOL",
|
"BOOL",
|
||||||
|
@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() {
|
||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char *EnumNameDataType(DataType e) {
|
inline const char *EnumNameDType(DType e) {
|
||||||
const size_t index = static_cast<int>(e);
|
const size_t index = static_cast<int>(e);
|
||||||
return EnumNamesDataType()[index];
|
return EnumNamesDType()[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
const flatbuffers::Vector<int8_t> *buffer() const {
|
const flatbuffers::Vector<int8_t> *buffer() const {
|
||||||
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_BUFFER);
|
||||||
}
|
}
|
||||||
DataType dtype() const {
|
DType dtype() const {
|
||||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||||
}
|
}
|
||||||
ByteOrder byteOrder() const {
|
ByteOrder byteOrder() const {
|
||||||
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
return static_cast<ByteOrder>(GetField<int8_t>(VT_BYTEORDER, 0));
|
||||||
|
@ -192,7 +192,7 @@ struct FlatArrayBuilder {
|
||||||
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
void add_buffer(flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer) {
|
||||||
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
fbb_.AddOffset(FlatArray::VT_BUFFER, buffer);
|
||||||
}
|
}
|
||||||
void add_dtype(DataType dtype) {
|
void add_dtype(DType dtype) {
|
||||||
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
fbb_.AddElement<int8_t>(FlatArray::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||||
}
|
}
|
||||||
void add_byteOrder(ByteOrder byteOrder) {
|
void add_byteOrder(ByteOrder byteOrder) {
|
||||||
|
@ -214,7 +214,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArray(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int8_t>> buffer = 0,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder_LE) {
|
ByteOrder byteOrder = ByteOrder_LE) {
|
||||||
FlatArrayBuilder builder_(_fbb);
|
FlatArrayBuilder builder_(_fbb);
|
||||||
builder_.add_buffer(buffer);
|
builder_.add_buffer(buffer);
|
||||||
|
@ -228,7 +228,7 @@ inline flatbuffers::Offset<FlatArray> CreateFlatArrayDirect(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
const std::vector<int64_t> *shape = nullptr,
|
const std::vector<int64_t> *shape = nullptr,
|
||||||
const std::vector<int8_t> *buffer = nullptr,
|
const std::vector<int8_t> *buffer = nullptr,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder_LE) {
|
ByteOrder byteOrder = ByteOrder_LE) {
|
||||||
return nd4j::graph::CreateFlatArray(
|
return nd4j::graph::CreateFlatArray(
|
||||||
_fbb,
|
_fbb,
|
||||||
|
|
|
@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = {
|
||||||
/**
|
/**
|
||||||
* @enum
|
* @enum
|
||||||
*/
|
*/
|
||||||
nd4j.graph.DataType = {
|
nd4j.graph.DType = {
|
||||||
INHERIT: 0,
|
INHERIT: 0,
|
||||||
BOOL: 1,
|
BOOL: 1,
|
||||||
FLOAT8: 2,
|
FLOAT8: 2,
|
||||||
|
@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() {
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatArray.prototype.dtype = function() {
|
nd4j.graph.FlatArray.prototype.dtype = function() {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {nd4j.graph.DataType} dtype
|
* @param {nd4j.graph.DType} dtype
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
nd4j.graph.FlatArray.addDtype = function(builder, dtype) {
|
||||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
namespace nd4j.graph
|
namespace nd4j.graph
|
||||||
{
|
{
|
||||||
|
|
||||||
public enum DataType : sbyte
|
public enum DType : sbyte
|
||||||
{
|
{
|
||||||
INHERIT = 0,
|
INHERIT = 0,
|
||||||
BOOL = 1,
|
BOOL = 1,
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
package nd4j.graph;
|
package nd4j.graph;
|
||||||
|
|
||||||
public final class DataType {
|
public final class DType {
|
||||||
private DataType() { }
|
private DType() { }
|
||||||
public static final byte INHERIT = 0;
|
public static final byte INHERIT = 0;
|
||||||
public static final byte BOOL = 1;
|
public static final byte BOOL = 1;
|
||||||
public static final byte FLOAT8 = 2;
|
public static final byte FLOAT8 = 2;
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
# namespace: graph
|
# namespace: graph
|
||||||
|
|
||||||
class DataType(object):
|
class DType(object):
|
||||||
INHERIT = 0
|
INHERIT = 0
|
||||||
BOOL = 1
|
BOOL = 1
|
||||||
FLOAT8 = 2
|
FLOAT8 = 2
|
|
@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
public ArraySegment<byte>? GetBufferBytes() { return __p.__vector_as_arraysegment(6); }
|
||||||
#endif
|
#endif
|
||||||
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
public sbyte[] GetBufferArray() { return __p.__vector_as_array<sbyte>(6); }
|
||||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||||
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } }
|
||||||
|
|
||||||
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
public static Offset<FlatArray> CreateFlatArray(FlatBufferBuilder builder,
|
||||||
VectorOffset shapeOffset = default(VectorOffset),
|
VectorOffset shapeOffset = default(VectorOffset),
|
||||||
VectorOffset bufferOffset = default(VectorOffset),
|
VectorOffset bufferOffset = default(VectorOffset),
|
||||||
DataType dtype = DataType.INHERIT,
|
DType dtype = DType.INHERIT,
|
||||||
ByteOrder byteOrder = ByteOrder.LE) {
|
ByteOrder byteOrder = ByteOrder.LE) {
|
||||||
builder.StartObject(4);
|
builder.StartObject(4);
|
||||||
FlatArray.AddBuffer(builder, bufferOffset);
|
FlatArray.AddBuffer(builder, bufferOffset);
|
||||||
|
@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject
|
||||||
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||||
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||||
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); }
|
||||||
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
public static Offset<FlatArray> EndFlatArray(FlatBufferBuilder builder) {
|
||||||
int o = builder.EndObject();
|
int o = builder.EndObject();
|
||||||
|
|
|
@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
public ArraySegment<byte>? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); }
|
||||||
#endif
|
#endif
|
||||||
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
public byte[] GetOpNameArray() { return __p.__vector_as_array<byte>(36); }
|
||||||
public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; }
|
public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
|
||||||
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
#if ENABLE_SPAN_T
|
#if ENABLE_SPAN_T
|
||||||
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
public Span<byte> GetOutputTypesBytes() { return __p.__vector_as_span(38); }
|
||||||
#else
|
#else
|
||||||
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
public ArraySegment<byte>? GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); }
|
||||||
#endif
|
#endif
|
||||||
public DataType[] GetOutputTypesArray() { return __p.__vector_as_array<DataType>(38); }
|
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
|
||||||
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||||
|
|
||||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||||
|
@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject
|
||||||
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||||
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); }
|
||||||
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); }
|
||||||
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||||
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||||
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
||||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||||
|
|
|
@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
public ArraySegment<byte>? GetNameBytes() { return __p.__vector_as_arraysegment(6); }
|
||||||
#endif
|
#endif
|
||||||
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
|
||||||
public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
|
public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
|
||||||
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
|
||||||
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||||
#if ENABLE_SPAN_T
|
#if ENABLE_SPAN_T
|
||||||
|
@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
||||||
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
||||||
StringOffset nameOffset = default(StringOffset),
|
StringOffset nameOffset = default(StringOffset),
|
||||||
DataType dtype = DataType.INHERIT,
|
DType dtype = DType.INHERIT,
|
||||||
VectorOffset shapeOffset = default(VectorOffset),
|
VectorOffset shapeOffset = default(VectorOffset),
|
||||||
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
||||||
int device = 0,
|
int device = 0,
|
||||||
|
@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject
|
||||||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
||||||
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
||||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||||
public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||||
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); }
|
||||||
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
|
||||||
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }
|
||||||
|
|
|
@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {number} index
|
* @param {number} index
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
nd4j.graph.FlatNode.prototype.outputTypes = function(index) {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 38);
|
var offset = this.bb.__offset(this.bb_pos, 38);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0);
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {Array.<nd4j.graph.DataType>} data
|
* @param {Array.<nd4j.graph.DType>} data
|
||||||
* @returns {flatbuffers.Offset}
|
* @returns {flatbuffers.Offset}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) {
|
||||||
|
|
|
@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
const flatbuffers::String *name() const {
|
const flatbuffers::String *name() const {
|
||||||
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
return GetPointer<const flatbuffers::String *>(VT_NAME);
|
||||||
}
|
}
|
||||||
DataType dtype() const {
|
DType dtype() const {
|
||||||
return static_cast<DataType>(GetField<int8_t>(VT_DTYPE, 0));
|
return static_cast<DType>(GetField<int8_t>(VT_DTYPE, 0));
|
||||||
}
|
}
|
||||||
const flatbuffers::Vector<int64_t> *shape() const {
|
const flatbuffers::Vector<int64_t> *shape() const {
|
||||||
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
|
||||||
|
@ -106,7 +106,7 @@ struct FlatVariableBuilder {
|
||||||
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
void add_name(flatbuffers::Offset<flatbuffers::String> name) {
|
||||||
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
fbb_.AddOffset(FlatVariable::VT_NAME, name);
|
||||||
}
|
}
|
||||||
void add_dtype(DataType dtype) {
|
void add_dtype(DType dtype) {
|
||||||
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
fbb_.AddElement<int8_t>(FlatVariable::VT_DTYPE, static_cast<int8_t>(dtype), 0);
|
||||||
}
|
}
|
||||||
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
|
||||||
|
@ -137,7 +137,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<IntPair> id = 0,
|
flatbuffers::Offset<IntPair> id = 0,
|
||||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||||
int32_t device = 0,
|
int32_t device = 0,
|
||||||
|
@ -157,7 +157,7 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
flatbuffers::Offset<IntPair> id = 0,
|
flatbuffers::Offset<IntPair> id = 0,
|
||||||
const char *name = nullptr,
|
const char *name = nullptr,
|
||||||
DataType dtype = DataType_INHERIT,
|
DType dtype = DType_INHERIT,
|
||||||
const std::vector<int64_t> *shape = nullptr,
|
const std::vector<int64_t> *shape = nullptr,
|
||||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||||
int32_t device = 0,
|
int32_t device = 0,
|
||||||
|
|
|
@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) {
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @returns {nd4j.graph.DataType}
|
* @returns {nd4j.graph.DType}
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
nd4j.graph.FlatVariable.prototype.dtype = function() {
|
||||||
var offset = this.bb.__offset(this.bb_pos, 8);
|
var offset = this.bb.__offset(this.bb_pos, 8);
|
||||||
return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
|
return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param {flatbuffers.Builder} builder
|
* @param {flatbuffers.Builder} builder
|
||||||
* @param {nd4j.graph.DataType} dtype
|
* @param {nd4j.graph.DType} dtype
|
||||||
*/
|
*/
|
||||||
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
nd4j.graph.FlatVariable.addDtype = function(builder, dtype) {
|
||||||
builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT);
|
builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT);
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -111,7 +111,7 @@ namespace nd4j {
|
||||||
|
|
||||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||||
|
|
||||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DType>(array.dataType()), bo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -219,7 +219,7 @@ namespace nd4j {
|
||||||
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
throw std::runtime_error("CONSTANT variable must have NDArray bundled");
|
||||||
|
|
||||||
auto ar = flatVariable->ndarray();
|
auto ar = flatVariable->ndarray();
|
||||||
if (ar->dtype() == DataType_UTF8) {
|
if (ar->dtype() == DType_UTF8) {
|
||||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||||
} else {
|
} else {
|
||||||
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
_ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar);
|
||||||
|
@ -320,7 +320,7 @@ namespace nd4j {
|
||||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||||
|
|
||||||
// packing array
|
// packing array
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType());
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType());
|
||||||
|
|
||||||
// packing id/index of this var
|
// packing id/index of this var
|
||||||
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
auto fVid = CreateIntPair(builder, this->_id, this->_index);
|
||||||
|
@ -331,7 +331,7 @@ namespace nd4j {
|
||||||
stringId = builder.CreateString(this->_name);
|
stringId = builder.CreateString(this->_name);
|
||||||
|
|
||||||
// returning array
|
// returning array
|
||||||
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DataType>(array->dataType()), 0, fArray);
|
return CreateFlatVariable(builder, fVid, stringId, static_cast<nd4j::graph::DType>(array->dataType()), 0, fArray);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ enum ByteOrder:byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataType for arrays/buffers
|
// DataType for arrays/buffers
|
||||||
enum DataType:byte {
|
enum DType:byte {
|
||||||
INHERIT,
|
INHERIT,
|
||||||
BOOL,
|
BOOL,
|
||||||
FLOAT8,
|
FLOAT8,
|
||||||
|
@ -49,7 +49,7 @@ enum DataType:byte {
|
||||||
table FlatArray {
|
table FlatArray {
|
||||||
shape:[long]; // shape in Nd4j format
|
shape:[long]; // shape in Nd4j format
|
||||||
buffer:[byte]; // byte buffer with data
|
buffer:[byte]; // byte buffer with data
|
||||||
dtype:DataType; // data type of actual data within buffer
|
dtype:DType; // data type of actual data within buffer
|
||||||
byteOrder:ByteOrder; // byte order of buffer
|
byteOrder:ByteOrder; // byte order of buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ table FlatNode {
|
||||||
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
opName:string; //Used to help resolving the class. In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability
|
||||||
|
|
||||||
// output data types (optional)
|
// output data types (optional)
|
||||||
outputTypes:[DataType];
|
outputTypes:[DType];
|
||||||
|
|
||||||
//Scalar value - used for scalar ops. Should be single value only.
|
//Scalar value - used for scalar ops. Should be single value only.
|
||||||
scalar:FlatArray;
|
scalar:FlatArray;
|
||||||
|
|
|
@ -51,7 +51,7 @@ table UIVariable {
|
||||||
id:IntPair; //Existing IntPair class
|
id:IntPair; //Existing IntPair class
|
||||||
name:string;
|
name:string;
|
||||||
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER
|
||||||
datatype:DataType;
|
datatype:DType;
|
||||||
shape:[long];
|
shape:[long];
|
||||||
controlDeps:[string]; //Input control dependencies: variable x -> this
|
controlDeps:[string]; //Input control dependencies: variable x -> this
|
||||||
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||||
|
|
|
@ -30,7 +30,7 @@ enum VarType:byte {
|
||||||
table FlatVariable {
|
table FlatVariable {
|
||||||
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node
|
||||||
name:string; // symbolic ID of the Variable (if defined)
|
name:string; // symbolic ID of the Variable (if defined)
|
||||||
dtype:DataType;
|
dtype:DType;
|
||||||
|
|
||||||
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
shape:[long]; // shape is absolutely optional. either shape or ndarray might be set
|
||||||
ndarray:FlatArray;
|
ndarray:FlatArray;
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_and)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(bitwise_and) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_or)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(bitwise_or) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_xor)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/shift.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
BROADCAST_CHECK_EMPTY(x,y,z);
|
||||||
|
|
||||||
|
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(bitwise_xor) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -29,21 +29,26 @@ namespace ops {
|
||||||
CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) {
|
||||||
int numOfData = block.width();
|
int numOfData = block.width();
|
||||||
// int k = 0;
|
// int k = 0;
|
||||||
|
// checking input data size
|
||||||
REQUIRE_TRUE(numOfData % 2 == 0, 0,
|
REQUIRE_TRUE(numOfData % 2 == 0, 0,
|
||||||
"dynamic_stitch: The input params should contains"
|
"dynamic_stitch: The input params should contains"
|
||||||
" both indeces and data lists with same length.");
|
" both indeces and data lists with same length.");
|
||||||
|
// split input data list on two equal parts
|
||||||
numOfData /= 2;
|
numOfData /= 2;
|
||||||
|
|
||||||
|
// form input lists to use with helpers - both indices and float data inputs
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
std::vector<NDArray*> inputs(numOfData);
|
std::vector<NDArray*> inputs(numOfData);
|
||||||
std::vector<NDArray*> indices(numOfData);
|
std::vector<NDArray*> indices(numOfData);
|
||||||
|
|
||||||
for (int e = 0; e < numOfData; e++) {
|
for (int e = 0; e < numOfData; e++) {
|
||||||
auto data = INPUT_VARIABLE(numOfData + e);
|
auto data = INPUT_VARIABLE(numOfData + e);
|
||||||
auto index = INPUT_VARIABLE(e);
|
auto index = INPUT_VARIABLE(e);
|
||||||
|
|
||||||
inputs[e] = data;
|
inputs[e] = data;
|
||||||
indices[e] = index;
|
indices[e] = index;
|
||||||
}
|
}
|
||||||
|
// run helper
|
||||||
return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output);
|
return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,17 +64,17 @@ namespace ops {
|
||||||
numOfData /= 2; // only index part it's needed to review
|
numOfData /= 2; // only index part it's needed to review
|
||||||
auto restShape = inputShape->at(numOfData);
|
auto restShape = inputShape->at(numOfData);
|
||||||
auto firstShape = inputShape->at(0);
|
auto firstShape = inputShape->at(0);
|
||||||
|
// check up inputs to avoid non-int indices and calculate max value from indices to output shape length
|
||||||
for(int i = 0; i < numOfData; i++) {
|
for(int i = 0; i < numOfData; i++) {
|
||||||
auto input = INPUT_VARIABLE(i);
|
auto input = INPUT_VARIABLE(i);
|
||||||
REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() );
|
REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() );
|
||||||
// FIXME: we have reduce::Max, cinsider using it instead
|
|
||||||
auto maxV = input->reduceNumber(reduce::Max);
|
auto maxV = input->reduceNumber(reduce::Max);
|
||||||
if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0);
|
if (maxV.e<Nd4jLong>(0) > maxValue) maxValue = maxV.e<Nd4jLong>(0);
|
||||||
}
|
}
|
||||||
|
// calculate output rank - difference between indices shape and data shape
|
||||||
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1;
|
int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor
|
||||||
std::vector<Nd4jLong> outShape(outRank);
|
std::vector<Nd4jLong> outShape(outRank);
|
||||||
|
// fill up output shape template: the first to max index, and rests - to vals from the first data input
|
||||||
outShape[0] = maxValue + 1;
|
outShape[0] = maxValue + 1;
|
||||||
for(int i = 1; i < outRank; ++i)
|
for(int i = 1; i < outRank; ++i)
|
||||||
outShape[i] = shape::sizeAt(restShape, i);
|
outShape[i] = shape::sizeAt(restShape, i);
|
||||||
|
|
|
@ -33,12 +33,13 @@ namespace nd4j {
|
||||||
* 0: 1D row-vector (or with shape (1, m))
|
* 0: 1D row-vector (or with shape (1, m))
|
||||||
* 1: 1D integer vector with slice nums
|
* 1: 1D integer vector with slice nums
|
||||||
* 2: 1D float-point values vector with same shape as above
|
* 2: 1D float-point values vector with same shape as above
|
||||||
|
* 3: 2D float-point matrix with data to search
|
||||||
*
|
*
|
||||||
* Int args:
|
* Int args:
|
||||||
* 0: N - number of slices
|
* 0: N - number of slices
|
||||||
*
|
*
|
||||||
* Output:
|
* Output:
|
||||||
* 0: 1D vector with edge forces for input and values
|
* 0: 2D matrix with the same shape and type as the 3th argument
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_barnes_edge_forces)
|
#if NOT_EXCLUDED(OP_barnes_edge_forces)
|
||||||
DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1);
|
DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1);
|
||||||
|
@ -54,7 +55,9 @@ namespace nd4j {
|
||||||
* 2: 1D float vector with values
|
* 2: 1D float vector with values
|
||||||
*
|
*
|
||||||
* Output:
|
* Output:
|
||||||
* 0: symmetric 2D matrix with given values on given places
|
* 0: 1D int result row-vector
|
||||||
|
* 1: 1D int result col-vector
|
||||||
|
* 2: a float-point tensor with shape 1xN, with values from the last input vector
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_barnes_symmetrized)
|
#if NOT_EXCLUDED(OP_barnes_symmetrized)
|
||||||
DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1);
|
DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1);
|
||||||
|
|
|
@ -81,6 +81,39 @@ namespace nd4j {
|
||||||
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
|
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise AND
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_and)
|
||||||
|
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise OR
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_or)
|
||||||
|
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise XOR
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_bitwise_xor)
|
||||||
|
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation returns hamming distance based on bits
|
* This operation returns hamming distance based on bits
|
||||||
*
|
*
|
||||||
|
|
|
@ -120,7 +120,7 @@ namespace nd4j {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation unstacks given NDArray into NDArrayList
|
* This operation unstacks given NDArray into NDArrayList by the first dimension
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_unstack_list)
|
#if NOT_EXCLUDED(OP_unstack_list)
|
||||||
DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0);
|
DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0);
|
||||||
|
|
|
@ -594,21 +594,46 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
|
||||||
|
* of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension
|
||||||
|
* are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input
|
||||||
|
* block size and how the data is moved.
|
||||||
|
* Input:
|
||||||
|
* 0 - 4D tensor on given type
|
||||||
|
* Output:
|
||||||
|
* 0 - 4D tensor of given type and proper shape
|
||||||
*
|
*
|
||||||
*
|
* Int arguments:
|
||||||
*
|
* 0 - block size
|
||||||
|
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
|
||||||
|
* 1 ("NCHW"): shape{ batch, channels, height, width }
|
||||||
|
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
|
||||||
|
* optional (default 0)
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_depth_to_space)
|
#if NOT_EXCLUDED(OP_depth_to_space)
|
||||||
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2);
|
DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor
|
||||||
|
* where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates
|
||||||
|
* the input block size.
|
||||||
*
|
*
|
||||||
|
* Input:
|
||||||
|
* - 4D tensor of given type
|
||||||
|
* Output:
|
||||||
|
* - 4D tensor
|
||||||
*
|
*
|
||||||
|
* Int arguments:
|
||||||
|
* 0 - block size
|
||||||
|
* 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels }
|
||||||
|
* 1 ("NCHW"): shape{ batch, channels, height, width }
|
||||||
|
* 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 }
|
||||||
|
* optional (default 0)
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_space_to_depth)
|
#if NOT_EXCLUDED(OP_space_to_depth)
|
||||||
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2);
|
DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -622,13 +647,42 @@ namespace nd4j {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op
|
||||||
|
* outputs a copy of the input tensor where values from the height and width dimensions are moved to the
|
||||||
|
* batch dimension. After the zero-padding, both height and width of the input must be divisible by the block
|
||||||
|
* size.
|
||||||
*
|
*
|
||||||
|
* Inputs:
|
||||||
|
* 0 - input tensor
|
||||||
|
* 1 - 2D paddings tensor (shape {M, 2})
|
||||||
|
*
|
||||||
|
* Output:
|
||||||
|
* - result tensor
|
||||||
|
*
|
||||||
|
* Int args:
|
||||||
|
* 0 - block size (M)
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_space_to_batch)
|
#if NOT_EXCLUDED(OP_space_to_batch)
|
||||||
DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
|
DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape
|
||||||
|
* block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output,
|
||||||
|
* the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension
|
||||||
|
* combines both the position within a spatial block and the original batch position. Prior to division into
|
||||||
|
* blocks, the spatial dimensions of the input are optionally zero padded according to paddings.
|
||||||
|
*
|
||||||
|
* Inputs:
|
||||||
|
* 0 - input (N-D tensor)
|
||||||
|
* 1 - block_shape - int 1D tensor with M length
|
||||||
|
* 2 - paddings - int 2D tensor with shape {M, 2}
|
||||||
|
*
|
||||||
|
* Output:
|
||||||
|
* - N-D tensor with the same type as input 0.
|
||||||
|
*
|
||||||
|
* */
|
||||||
#if NOT_EXCLUDED(OP_space_to_batch_nd)
|
#if NOT_EXCLUDED(OP_space_to_batch_nd)
|
||||||
DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
@ -973,7 +1027,7 @@ namespace nd4j {
|
||||||
* return value:
|
* return value:
|
||||||
* tensor with min values according to indices sets.
|
* tensor with min values according to indices sets.
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_segment_min_bp)
|
#if NOT_EXCLUDED(OP_segment_min)
|
||||||
DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
#if NOT_EXCLUDED(OP_segment_min_bp)
|
#if NOT_EXCLUDED(OP_segment_min_bp)
|
||||||
|
|
|
@ -118,19 +118,19 @@ namespace nd4j {
|
||||||
|
|
||||||
PointersManager pm(context, "dynamicPartition");
|
PointersManager pm(context, "dynamicPartition");
|
||||||
|
|
||||||
if (sourceDimsLen) {
|
if (sourceDimsLen) { // non-linear case
|
||||||
std::vector<int> sourceDims(sourceDimsLen);
|
std::vector<int> sourceDims(sourceDimsLen);
|
||||||
|
|
||||||
for (int i = sourceDimsLen; i > 0; i--)
|
for (int i = sourceDimsLen; i > 0; i--)
|
||||||
sourceDims[sourceDimsLen - i] = input->rankOf() - i;
|
sourceDims[sourceDimsLen - i] = input->rankOf() - i;
|
||||||
|
//compute tad array for given dimensions
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims);
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims);
|
||||||
|
|
||||||
std::vector<void *> outBuffers(outSize);
|
std::vector<void *> outBuffers(outSize);
|
||||||
std::vector<Nd4jLong *> tadShapes(outSize);
|
std::vector<Nd4jLong *> tadShapes(outSize);
|
||||||
std::vector<Nd4jLong *> tadOffsets(outSize);
|
std::vector<Nd4jLong *> tadOffsets(outSize);
|
||||||
std::vector<Nd4jLong> numTads(outSize);
|
std::vector<Nd4jLong> numTads(outSize);
|
||||||
|
// fill up dimensions array for before kernel
|
||||||
for (unsigned int i = 0; i < outSize; i++) {
|
for (unsigned int i = 0; i < outSize; i++) {
|
||||||
outputs[i].first = outputList[i];
|
outputs[i].first = outputList[i];
|
||||||
std::vector<int> outDims(outputs[i].first->rankOf() - 1);
|
std::vector<int> outDims(outputs[i].first->rankOf() - 1);
|
||||||
|
@ -151,10 +151,10 @@ namespace nd4j {
|
||||||
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
||||||
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
|
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
|
||||||
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
|
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
|
||||||
|
// run kernel on device
|
||||||
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
|
dynamicPartitionTadKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize);
|
||||||
|
|
||||||
} else {
|
} else { // linear case
|
||||||
auto numThreads = 256;
|
auto numThreads = 256;
|
||||||
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
|
auto shmemSize = numThreads * sizeof(Y) * 2 + 1024;
|
||||||
|
|
||||||
|
@ -169,7 +169,6 @@ namespace nd4j {
|
||||||
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
||||||
auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *)));
|
auto dOutShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *)));
|
||||||
|
|
||||||
|
|
||||||
dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize);
|
dynamicPartitionScalarKernel<X,Y><<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
||||||
|
|
||||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
|
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
nd4j::ops::adjust_saturation op;
|
nd4j::ops::adjust_saturation op;
|
||||||
auto results = op.execute({&input}, {10}, {2});
|
auto results = op.execute({&input}, {10}, {2});
|
||||||
|
@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
auto result = results->at(0);
|
auto result = results->at(0);
|
||||||
// result->printIndexedBuffer();
|
// result->printIndexedBuffer("Result2");
|
||||||
|
// exp.printIndexedBuffer("Expect2");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(result));
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
ASSERT_TRUE(exp.equalsTo(result));
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
|
@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
|
||||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||||
auto fBuffer = builder.CreateVector(array->asByteVector());
|
auto fBuffer = builder.CreateVector(array->asByteVector());
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||||
auto fVid = CreateIntPair(builder, -1);
|
auto fVid = CreateIntPair(builder, -1);
|
||||||
|
|
||||||
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||||
|
|
||||||
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
std::vector<int> outputs1, outputs2, inputs1, inputs2;
|
||||||
outputs1.push_back(2);
|
outputs1.push_back(2);
|
||||||
|
@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
|
||||||
|
|
||||||
auto name1 = builder.CreateString("wow1");
|
auto name1 = builder.CreateString("wow1");
|
||||||
|
|
||||||
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
|
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT);
|
||||||
|
|
||||||
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
||||||
variables_vector.push_back(fXVar);
|
variables_vector.push_back(fXVar);
|
||||||
|
|
|
@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) {
|
||||||
auto fBuffer = builder.CreateVector(vec);
|
auto fBuffer = builder.CreateVector(vec);
|
||||||
auto fVid = CreateIntPair(builder, 1, 12);
|
auto fVid = CreateIntPair(builder, 1, 12);
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE);
|
auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) {
|
||||||
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
auto fShape = builder.CreateVector(original.getShapeAsFlatVector());
|
||||||
auto fVid = CreateIntPair(builder, 37, 12);
|
auto fVid = CreateIntPair(builder, 37, 12);
|
||||||
|
|
||||||
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER);
|
||||||
|
|
||||||
builder.Finish(flatVar);
|
builder.Finish(flatVar);
|
||||||
|
|
||||||
|
|
|
@ -469,7 +469,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
|
public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
|
||||||
LocalResponseNormalization lrn = LocalResponseNormalization.builder()
|
LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder()
|
||||||
.inputFunctions(new SDVariable[]{input})
|
.inputFunctions(new SDVariable[]{input})
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(lrnConfig)
|
.config(lrnConfig)
|
||||||
|
@ -487,7 +487,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
||||||
Conv1D conv1D = Conv1D.builder()
|
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||||
.inputFunctions(new SDVariable[]{input, weights})
|
.inputFunctions(new SDVariable[]{input, weights})
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(conv1DConfig)
|
.config(conv1DConfig)
|
||||||
|
@ -496,6 +496,34 @@ public class DifferentialFunctionFactory {
|
||||||
return conv1D.outputVariable();
|
return conv1D.outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Conv1d operation.
|
||||||
|
*
|
||||||
|
* @param input the inputs to conv1d
|
||||||
|
* @param weights conv1d weights
|
||||||
|
* @param bias conv1d bias
|
||||||
|
* @param conv1DConfig the configuration
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) {
|
||||||
|
|
||||||
|
SDVariable[] args;
|
||||||
|
|
||||||
|
if(bias == null){
|
||||||
|
args = new SDVariable[]{input, weights};
|
||||||
|
} else {
|
||||||
|
args = new SDVariable[]{input, weights, bias};
|
||||||
|
}
|
||||||
|
|
||||||
|
Conv1D conv1D = Conv1D.sameDiffBuilder()
|
||||||
|
.inputFunctions(args)
|
||||||
|
.sameDiff(sameDiff())
|
||||||
|
.config(conv1DConfig)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return conv1D.outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conv2d operation.
|
* Conv2d operation.
|
||||||
*
|
*
|
||||||
|
@ -504,7 +532,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||||
Conv2D conv2D = Conv2D.builder()
|
Conv2D conv2D = Conv2D.sameDiffBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(conv2DConfig)
|
.config(conv2DConfig)
|
||||||
|
@ -530,7 +558,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder()
|
||||||
.input(input)
|
.input(input)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(pooling2DConfig)
|
.config(pooling2DConfig)
|
||||||
|
@ -547,7 +575,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
||||||
MaxPooling2D maxPooling2D = MaxPooling2D.builder()
|
MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder()
|
||||||
.input(input)
|
.input(input)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(pooling2DConfig)
|
.config(pooling2DConfig)
|
||||||
|
@ -590,7 +618,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
||||||
SConv2D sconv2D = SConv2D.sBuilder()
|
SConv2D sconv2D = SConv2D.sameDiffSBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.conv2DConfig(conv2DConfig)
|
.conv2DConfig(conv2DConfig)
|
||||||
|
@ -609,7 +637,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
||||||
SConv2D depthWiseConv2D = SConv2D.sBuilder()
|
SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.conv2DConfig(depthConv2DConfig)
|
.conv2DConfig(depthConv2DConfig)
|
||||||
|
@ -627,7 +655,7 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
||||||
DeConv2D deconv2D = DeConv2D.builder()
|
DeConv2D deconv2D = DeConv2D.sameDiffBuilder()
|
||||||
.inputs(inputs)
|
.inputs(inputs)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.config(deconv2DConfig)
|
.config(deconv2DConfig)
|
||||||
|
@ -654,9 +682,9 @@ public class DifferentialFunctionFactory {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
|
||||||
Conv3D conv3D = Conv3D.builder()
|
Conv3D conv3D = Conv3D.sameDiffBuilder()
|
||||||
.inputFunctions(inputs)
|
.inputFunctions(inputs)
|
||||||
.conv3DConfig(conv3DConfig)
|
.config(conv3DConfig)
|
||||||
.sameDiff(sameDiff())
|
.sameDiff(sameDiff())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -1260,6 +1288,22 @@ public class DifferentialFunctionFactory {
|
||||||
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) {
|
||||||
|
return new BitsHammingDistance(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseAnd(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseAnd(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseOr(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseOr(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitwiseXor(SDVariable x, SDVariable y){
|
||||||
|
return new BitwiseXor(sameDiff(), x, y).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
public SDVariable eq(SDVariable iX, SDVariable i_y) {
|
||||||
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps {
|
||||||
*/
|
*/
|
||||||
public final SDImage image = new SDImage(this);
|
public final SDImage image = new SDImage(this);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Op creator object for bitwise operations
|
||||||
|
*/
|
||||||
|
public final SDBitwise bitwise = new SDBitwise(this);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Op creator object for math operations
|
* Op creator object for math operations
|
||||||
*/
|
*/
|
||||||
|
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
|
||||||
return image;
|
return image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Op creator object for bitwise operations
|
||||||
|
*/
|
||||||
|
public SDBitwise bitwise(){
|
||||||
|
return bitwise;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For import, many times we have variables
|
* For import, many times we have variables
|
||||||
|
|
|
@ -0,0 +1,205 @@
|
||||||
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public class SDBitwise extends SDOps {
|
||||||
|
public SDBitwise(SameDiff sameDiff) {
|
||||||
|
super(sameDiff);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #leftShift(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
|
||||||
|
return leftShift(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left shift operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise left shift", x);
|
||||||
|
validateInteger("bitwise left shift", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().shift(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #rightShift(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable rightShift(SDVariable x, SDVariable y){
|
||||||
|
return rightShift(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right shift operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise right shift", x);
|
||||||
|
validateInteger("bitwise right shift", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rshift(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
|
||||||
|
return leftShiftCyclic(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise left cyclical shift operation. Supports broadcasting.
|
||||||
|
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||||
|
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise cyclic shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise left shift (cyclic)", x);
|
||||||
|
validateInteger("bitwise left shift (cyclic)", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rotl(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
|
||||||
|
return rightShiftCyclic(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise right cyclical shift operation. Supports broadcasting.
|
||||||
|
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||||
|
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x Input to be bit shifted (must be an integer type)
|
||||||
|
* @param y Amount to shift elements of x array (must be an integer type)
|
||||||
|
* @return Bitwise cyclic shifted input x
|
||||||
|
*/
|
||||||
|
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise right shift (cyclic)", x);
|
||||||
|
validateInteger("bitwise right shift (cyclic)", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().rotr(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
|
||||||
|
return bitsHammingDistance(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||||
|
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise hamming distance", x);
|
||||||
|
validateInteger("bitwise hamming distance", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseHammingDist(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #and(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable and(SDVariable x, SDVariable y){
|
||||||
|
return and(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise AND operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise AND array
|
||||||
|
*/
|
||||||
|
public SDVariable and(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise AND", x);
|
||||||
|
validateInteger("bitwise AND", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseAnd(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #or(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable or(SDVariable x, SDVariable y){
|
||||||
|
return or(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise OR operation. Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise OR array
|
||||||
|
*/
|
||||||
|
public SDVariable or(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise OR", x);
|
||||||
|
validateInteger("bitwise OR", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseOr(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #xor(String, SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
public SDVariable xor(SDVariable x, SDVariable y){
|
||||||
|
return xor(null, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
|
||||||
|
*
|
||||||
|
* @param name Name of the output variable. May be null.
|
||||||
|
* @param x First input array. Must be integer type.
|
||||||
|
* @param y First input array. Must be integer type, same type as x
|
||||||
|
* @return Bitwise XOR array
|
||||||
|
*/
|
||||||
|
public SDVariable xor(String name, SDVariable x, SDVariable y){
|
||||||
|
validateInteger("bitwise XOR", x);
|
||||||
|
validateInteger("bitwise XOR", y);
|
||||||
|
|
||||||
|
SDVariable ret = f().bitwiseXor(x, y);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||||
|
@ -38,14 +39,9 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - average pooling 2d
|
* See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param pooling2DConfig the configuration for
|
|
||||||
* @return Result after applying average pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
return avgPooling2d(null, input, pooling2DConfig);
|
return avgPooling2d(null, input, pooling2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,22 +54,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling2DConfig the configuration
|
* @param pooling2DConfig the configuration
|
||||||
* @return Result after applying average pooling on the input
|
* @return Result after applying average pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
validateFloatingPoint("avgPooling2d", input);
|
validateFloatingPoint("avgPooling2d", input);
|
||||||
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
|
SDVariable ret = f().avgPooling2d(input, pooling2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - average pooling 3d
|
* See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param pooling3DConfig the configuration
|
|
||||||
* @return Result after applying average pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
return avgPooling3d(null, input, pooling3DConfig);
|
return avgPooling3d(null, input, pooling3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +77,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling3DConfig the configuration
|
* @param pooling3DConfig the configuration
|
||||||
* @return Result after applying average pooling on the input
|
* @return Result after applying average pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
validateFloatingPoint("avgPooling3d", input);
|
validateFloatingPoint("avgPooling3d", input);
|
||||||
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
|
SDVariable ret = f().avgPooling3d(input, pooling3DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -96,7 +86,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) {
|
public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||||
return batchToSpace(null, x, blocks, crops);
|
return batchToSpace(null, x, blocks, crops);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +101,7 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) {
|
public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) {
|
||||||
validateNumerical("batchToSpace", x);
|
validateNumerical("batchToSpace", x);
|
||||||
SDVariable ret = f().batchToSpace(x, blocks, crops);
|
SDVariable ret = f().batchToSpace(x, blocks, crops);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -119,14 +109,9 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
|
* See {@link #col2Im(String, SDVariable, Conv2DConfig)}.
|
||||||
* [minibatch, inputChannels, height, width]
|
|
||||||
*
|
|
||||||
* @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
|
||||||
* @param config Convolution configuration for the col2im operation
|
|
||||||
* @return Col2Im output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable col2Im(SDVariable in, Conv2DConfig config) {
|
public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
return col2Im(null, in, config);
|
return col2Im(null, in, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,33 +124,22 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Convolution configuration for the col2im operation
|
* @param config Convolution configuration for the col2im operation
|
||||||
* @return Col2Im output variable
|
* @return Col2Im output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) {
|
public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
SDVariable ret = f().col2Im(in, config);
|
SDVariable ret = f().col2Im(in, config);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 1D Convolution layer operation - Conv1d
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input array/activations for the conv1d op
|
|
||||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
|
||||||
* @param conv1DConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
return conv1d(null, input, weights, conv1DConfig);
|
return conv1d((String) null, input, weights, conv1DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Conv1d operation.
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param name name of the operation in SameDiff
|
|
||||||
* @param input the inputs to conv1d
|
|
||||||
* @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels]
|
|
||||||
* @param conv1DConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
|
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
validateFloatingPoint("conv1d", input);
|
validateFloatingPoint("conv1d", input);
|
||||||
validateFloatingPoint("conv1d", weights);
|
validateFloatingPoint("conv1d", weights);
|
||||||
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
|
SDVariable ret = f().conv1d(input, weights, conv1DConfig);
|
||||||
|
@ -173,21 +147,55 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation (without bias)
|
* See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) {
|
public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
|
return conv1d(null, input, weights, bias, conv1DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Conv1d operation.
|
||||||
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
|
* @param input the inputs to conv1d
|
||||||
|
* @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels]
|
||||||
|
* @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null.
|
||||||
|
* @param conv1DConfig the configuration
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||||
|
validateFloatingPoint("conv1d", input);
|
||||||
|
validateFloatingPoint("conv1d", weights);
|
||||||
|
validateFloatingPoint("conv1d", bias);
|
||||||
|
SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig);
|
||||||
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||||
return conv2d(layerInput, weights, null, config);
|
return conv2d(layerInput, weights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) {
|
||||||
|
return conv2d(name, layerInput, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return conv2d(null, layerInput, weights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation with optional bias
|
* 2D Convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
* @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels]
|
||||||
|
@ -195,7 +203,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of conv2d op
|
* @return result of conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) {
|
public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("conv2d", "input", layerInput);
|
validateFloatingPoint("conv2d", "input", layerInput);
|
||||||
validateFloatingPoint("conv2d", "weights", weights);
|
validateFloatingPoint("conv2d", "weights", weights);
|
||||||
validateFloatingPoint("conv2d", "bias", bias);
|
validateFloatingPoint("conv2d", "bias", bias);
|
||||||
|
@ -204,18 +212,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = weights;
|
arr[1] = weights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return conv2d(arr, config);
|
return conv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution operation with optional bias
|
* See {@link #conv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
|
|
||||||
* described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
|
public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||||
return conv2d(null, inputs, config);
|
return conv2d(null, inputs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,7 +231,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of convolution 2d operation
|
* @return result of convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
|
public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateNumerical("conv2d", v);
|
validateNumerical("conv2d", v);
|
||||||
SDVariable ret = f().conv2d(inputs, config);
|
SDVariable ret = f().conv2d(inputs, config);
|
||||||
|
@ -236,19 +239,26 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation without bias
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
return conv3d(null, input, weights, null, conv3DConfig);
|
return conv3d(null, input, weights, null, conv3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
|
return conv3d(name, input, weights, null, conv3DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
|
return conv3d(null, input, weights, bias, conv3DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias
|
* Convolution 3D operation with optional bias
|
||||||
*
|
*
|
||||||
|
@ -261,7 +271,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param conv3DConfig the configuration
|
* @param conv3DConfig the configuration
|
||||||
* @return Conv3d output variable
|
* @return Conv3d output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||||
validateFloatingPoint("conv3d", "input", input);
|
validateFloatingPoint("conv3d", "input", input);
|
||||||
validateFloatingPoint("conv3d", "weights", weights);
|
validateFloatingPoint("conv3d", "weights", weights);
|
||||||
validateFloatingPoint("conv3d", "bias", bias);
|
validateFloatingPoint("conv3d", "bias", bias);
|
||||||
|
@ -276,51 +286,30 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convolution 3D operation with optional bias
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null.
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
|
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
return conv3d(null, input, weights, bias, conv3DConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convolution 3D operation without bias
|
|
||||||
*
|
|
||||||
* @param name Name of the output variable
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
|
|
||||||
* @param conv3DConfig the configuration
|
|
||||||
* @return Conv3d output variable
|
|
||||||
*/
|
|
||||||
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
|
|
||||||
return conv3d(name, input, weights, null, conv3DConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D deconvolution operation without bias
|
|
||||||
*
|
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
|
||||||
* @param deconv2DConfig DeConv2DConfig configuration
|
|
||||||
* @return result of deconv2d op
|
|
||||||
*/
|
|
||||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
|
|
||||||
return deconv2d(layerInput, weights, null, deconv2DConfig);
|
return deconv2d(layerInput, weights, null, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
|
return deconv2d(name, layerInput, weights, null, deconv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
|
return deconv2d(null, layerInput, weights, bias, deconv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with optional bias
|
* 2D deconvolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth].
|
||||||
|
@ -328,7 +317,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param deconv2DConfig DeConv2DConfig configuration
|
* @param deconv2DConfig DeConv2DConfig configuration
|
||||||
* @return result of deconv2d op
|
* @return result of deconv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
validateFloatingPoint("deconv2d", "input", layerInput);
|
validateFloatingPoint("deconv2d", "input", layerInput);
|
||||||
validateFloatingPoint("deconv2d", "weights", weights);
|
validateFloatingPoint("deconv2d", "weights", weights);
|
||||||
validateFloatingPoint("deconv2d", "bias", bias);
|
validateFloatingPoint("deconv2d", "bias", bias);
|
||||||
|
@ -337,18 +326,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = weights;
|
arr[1] = weights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return deconv2d(arr, deconv2DConfig);
|
return deconv2d(name, arr, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D deconvolution operation with or without optional bias
|
* See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
|
|
||||||
* or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)}
|
|
||||||
* @param deconv2DConfig the configuration
|
|
||||||
* @return result of deconv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
return deconv2d(null, inputs, deconv2DConfig);
|
return deconv2d(null, inputs, deconv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,13 +345,34 @@ public class SDCNN extends SDOps {
|
||||||
* @param deconv2DConfig the configuration
|
* @param deconv2DConfig the configuration
|
||||||
* @return result of deconv2d op
|
* @return result of deconv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
|
public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateNumerical("deconv2d", v);
|
validateNumerical("deconv2d", v);
|
||||||
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
|
SDVariable ret = f().deconv2d(inputs, deconv2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(input, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(name, input, weights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
|
return deconv3d(null, input, weights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D CNN deconvolution operation with or without optional bias
|
* 3D CNN deconvolution operation with or without optional bias
|
||||||
*
|
*
|
||||||
|
@ -377,7 +382,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
||||||
* @param config Configuration
|
* @param config Configuration
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
validateFloatingPoint("conv3d", input);
|
validateFloatingPoint("conv3d", input);
|
||||||
validateFloatingPoint("conv3d", weights);
|
validateFloatingPoint("conv3d", weights);
|
||||||
validateFloatingPoint("conv3d", bias);
|
validateFloatingPoint("conv3d", bias);
|
||||||
|
@ -386,41 +391,9 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D CNN deconvolution operation with or without optional bias
|
* See {@link #depthToSpace(String, SDVariable, int, String)}.
|
||||||
*
|
|
||||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
|
||||||
* @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels]
|
|
||||||
* @param config Configuration
|
|
||||||
*/
|
*/
|
||||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||||
return deconv3d(null, input, weights, bias, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 3D CNN deconvolution operation with no bias
|
|
||||||
*
|
|
||||||
* @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
* @param weights Weights array - shape [kD, kH, kW, oC, iC]
|
|
||||||
* @param config Configuration
|
|
||||||
*/
|
|
||||||
public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) {
|
|
||||||
return deconv3d(input, weights, null, config);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convolution 2d layer batch to space operation on 4d input.<br>
|
|
||||||
* Reduces input channels dimension by rearranging data into a larger spatial dimensions<br>
|
|
||||||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
|
|
||||||
* = [mb, 2, 4, 4]
|
|
||||||
*
|
|
||||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param blockSize Block size, in the height/width dimension
|
|
||||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
|
|
||||||
return depthToSpace(null, x, blockSize, dataFormat);
|
return depthToSpace(null, x, blockSize, dataFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -438,27 +411,36 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #depthToSpace(String, SDVariable, int, String)
|
* @see #depthToSpace(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) {
|
||||||
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
|
SDVariable ret = f().depthToSpace(x, blockSize, dataFormat);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise 2D convolution operation without bias
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) {
|
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||||
return depthWiseConv2d(layerInput, depthWeights, null, config);
|
return depthWiseConv2d(layerInput, depthWeights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) {
|
||||||
|
return depthWiseConv2d(name, layerInput, depthWeights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return depthWiseConv2d(null, layerInput, depthWeights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise 2D convolution operation with optional bias
|
* Depth-wise 2D convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
* @param name name of the operation in SameDiff
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
||||||
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
* @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
||||||
|
@ -466,7 +448,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of depthwise conv2d op
|
* @return result of depthwise conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) {
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
|
validateFloatingPoint("depthwiseConv2d", "input", layerInput);
|
||||||
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
|
validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights);
|
||||||
validateFloatingPoint("depthwiseConv2d", "bias", bias);
|
validateFloatingPoint("depthwiseConv2d", "bias", bias);
|
||||||
|
@ -475,19 +457,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[1] = depthWeights;
|
arr[1] = depthWeights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[2] = bias;
|
arr[2] = bias;
|
||||||
return depthWiseConv2d(arr, config);
|
return depthWiseConv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Depth-wise convolution 2D operation.
|
* See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
|
|
||||||
* or 3 elements (layerInput, depthWeights, bias) as described in
|
|
||||||
* {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param depthConv2DConfig the configuration
|
|
||||||
* @return result of depthwise conv2d op
|
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||||
return depthWiseConv2d(null, inputs, depthConv2DConfig);
|
return depthWiseConv2d(null, inputs, depthConv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -501,7 +477,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param depthConv2DConfig the configuration
|
* @param depthConv2DConfig the configuration
|
||||||
* @return result of depthwise conv2d op
|
* @return result of depthwise conv2d op
|
||||||
*/
|
*/
|
||||||
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
|
public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateFloatingPoint("depthWiseConv2d", v);
|
validateFloatingPoint("depthWiseConv2d", v);
|
||||||
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
|
SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig);
|
||||||
|
@ -509,17 +485,10 @@ public class SDCNN extends SDOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO doc string
|
* See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}.
|
||||||
*
|
|
||||||
* @param df
|
|
||||||
* @param weights
|
|
||||||
* @param strides
|
|
||||||
* @param rates
|
|
||||||
* @param isSameMode
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
|
public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||||
int[] rates, boolean isSameMode) {
|
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||||
return dilation2D(null, df, weights, strides, rates, isSameMode);
|
return dilation2D(null, df, weights, strides, rates, isSameMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,8 +503,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param isSameMode
|
* @param isSameMode
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
|
public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides,
|
||||||
int[] rates, boolean isSameMode) {
|
@NonNull int[] rates, @NonNull boolean isSameMode) {
|
||||||
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
|
SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
@ -555,21 +524,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param sameMode If true: use same mode padding. If false
|
* @param sameMode If true: use same mode padding. If false
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||||
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
|
SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
|
* See {@link #im2Col(String, SDVariable, Conv2DConfig)}.
|
||||||
* [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
|
|
||||||
*
|
|
||||||
* @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width]
|
|
||||||
* @param config Convolution configuration for the im2col operation
|
|
||||||
* @return Im2Col output variable
|
|
||||||
*/
|
*/
|
||||||
public SDVariable im2Col(SDVariable in, Conv2DConfig config) {
|
public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
return im2Col(null, in, config);
|
return im2Col(null, in, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -582,20 +546,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Convolution configuration for the im2col operation
|
* @param config Convolution configuration for the im2col operation
|
||||||
* @return Im2Col output variable
|
* @return Im2Col output variable
|
||||||
*/
|
*/
|
||||||
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) {
|
public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) {
|
||||||
SDVariable ret = f().im2Col(in, config);
|
SDVariable ret = f().im2Col(in, config);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D convolution layer operation - local response normalization
|
* See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to lrn
|
|
||||||
* @param lrnConfig the configuration
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
|
public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||||
return localResponseNormalization(null, inputs, lrnConfig);
|
return localResponseNormalization(null, inputs, lrnConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -607,8 +567,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param lrnConfig the configuration
|
* @param lrnConfig the configuration
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable localResponseNormalization(String name, SDVariable input,
|
public SDVariable localResponseNormalization(String name, @NonNull SDVariable input,
|
||||||
LocalResponseNormalizationConfig lrnConfig) {
|
@NonNull LocalResponseNormalizationConfig lrnConfig) {
|
||||||
validateFloatingPoint("local response normalization", input);
|
validateFloatingPoint("local response normalization", input);
|
||||||
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
|
SDVariable ret = f().localResponseNormalization(input, lrnConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -616,14 +576,9 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - max pooling 2d
|
* See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param pooling2DConfig the configuration
|
|
||||||
* @return Result after applying max pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
return maxPooling2d(null, input, pooling2DConfig);
|
return maxPooling2d(null, input, pooling2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -636,22 +591,16 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling2DConfig the configuration
|
* @param pooling2DConfig the configuration
|
||||||
* @return Result after applying max pooling on the input
|
* @return Result after applying max pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
|
public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||||
validateNumerical("maxPooling2d", input);
|
validateNumerical("maxPooling2d", input);
|
||||||
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
|
SDVariable ret = f().maxPooling2d(input, pooling2DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 3D convolution layer operation - max pooling 3d operation.
|
* See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}.
|
||||||
*
|
|
||||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
|
||||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
|
||||||
* (shape [minibatch, depth, height, width, channels])
|
|
||||||
* @param pooling3DConfig the configuration
|
|
||||||
* @return Result after applying max pooling on the input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
return maxPooling3d(null, input, pooling3DConfig);
|
return maxPooling3d(null, input, pooling3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,7 +614,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param pooling3DConfig the configuration
|
* @param pooling3DConfig the configuration
|
||||||
* @return Result after applying max pooling on the input
|
* @return Result after applying max pooling on the input
|
||||||
*/
|
*/
|
||||||
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
|
public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) {
|
||||||
validateNumerical("maxPooling3d", input);
|
validateNumerical("maxPooling3d", input);
|
||||||
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
|
SDVariable ret = f().maxPooling3d(input, pooling3DConfig);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
|
@ -673,21 +622,30 @@ public class SDCNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation without bias
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
*
|
|
||||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
|
||||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
|
|
||||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]
|
|
||||||
* May be null
|
|
||||||
* @param config Conv2DConfig configuration
|
|
||||||
* @return result of separable convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
Conv2DConfig config) {
|
@NonNull Conv2DConfig config) {
|
||||||
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias.
|
||||||
|
*/
|
||||||
|
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
|
@NonNull Conv2DConfig config) {
|
||||||
|
return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}.
|
||||||
|
*/
|
||||||
|
public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
|
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
|
return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with optional bias
|
* Separable 2D convolution operation with optional bias
|
||||||
*
|
*
|
||||||
|
@ -700,8 +658,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param config Conv2DConfig configuration
|
* @param config Conv2DConfig configuration
|
||||||
* @return result of separable convolution 2d operation
|
* @return result of separable convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
|
public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights,
|
||||||
SDVariable bias, Conv2DConfig config) {
|
SDVariable bias, @NonNull Conv2DConfig config) {
|
||||||
validateFloatingPoint("separableConv2d", "input", layerInput);
|
validateFloatingPoint("separableConv2d", "input", layerInput);
|
||||||
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
|
validateFloatingPoint("separableConv2d", "depthWeights", depthWeights);
|
||||||
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
|
validateFloatingPoint("separableConv2d", "pointWeights", pointWeights);
|
||||||
|
@ -712,18 +670,13 @@ public class SDCNN extends SDOps {
|
||||||
arr[2] = pointWeights;
|
arr[2] = pointWeights;
|
||||||
if (bias != null)
|
if (bias != null)
|
||||||
arr[3] = bias;
|
arr[3] = bias;
|
||||||
return sconv2d(arr, config);
|
return sconv2d(name, arr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Separable 2D convolution operation with/without optional bias
|
* See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}.
|
||||||
*
|
|
||||||
* @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights)
|
|
||||||
* or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
|
|
||||||
* @param conv2DConfig the configuration
|
|
||||||
* @return result of separable convolution 2d operation
|
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||||
return sconv2d(null, inputs, conv2DConfig);
|
return sconv2d(null, inputs, conv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -736,7 +689,7 @@ public class SDCNN extends SDOps {
|
||||||
* @param conv2DConfig the configuration
|
* @param conv2DConfig the configuration
|
||||||
* @return result of separable convolution 2d operation
|
* @return result of separable convolution 2d operation
|
||||||
*/
|
*/
|
||||||
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
|
public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) {
|
||||||
for(SDVariable v : inputs)
|
for(SDVariable v : inputs)
|
||||||
validateFloatingPoint("sconv2d", v);
|
validateFloatingPoint("sconv2d", v);
|
||||||
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
|
SDVariable ret = f().sconv2d(inputs, conv2DConfig);
|
||||||
|
@ -747,7 +700,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
* @see #spaceToBatch(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) {
|
public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||||
return spaceToBatch(null, x, blocks, padding);
|
return spaceToBatch(null, x, blocks, padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,7 +715,7 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
* @see #batchToSpace(String, SDVariable, int[], int[][])
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) {
|
public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) {
|
||||||
SDVariable ret = f().spaceToBatch(x, blocks, padding);
|
SDVariable ret = f().spaceToBatch(x, blocks, padding);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
@ -770,7 +723,7 @@ public class SDCNN extends SDOps {
|
||||||
/**
|
/**
|
||||||
* @see #spaceToDepth(String, SDVariable, int, String)
|
* @see #spaceToDepth(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||||
return spaceToDepth(null, x, blockSize, dataFormat);
|
return spaceToDepth(null, x, blockSize, dataFormat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -788,23 +741,39 @@ public class SDCNN extends SDOps {
|
||||||
* @return Output variable
|
* @return Output variable
|
||||||
* @see #depthToSpace(String, SDVariable, int, String)
|
* @see #depthToSpace(String, SDVariable, int, String)
|
||||||
*/
|
*/
|
||||||
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) {
|
public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) {
|
||||||
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
|
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||||
|
* scale is used for both height and width dimensions.
|
||||||
*
|
*
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
* @param scale The scale for both height and width dimensions.
|
||||||
* @param scale Scale to upsample in both H and W dimensions
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
*/
|
||||||
public SDVariable upsampling2d(SDVariable input, int scale) {
|
public SDVariable upsampling2d(@NonNull SDVariable input, int scale) {
|
||||||
return upsampling2d(null, input, true, scale, scale);
|
return upsampling2d(null, input, true, scale, scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)},
|
||||||
|
* scale is used for both height and width dimensions.
|
||||||
|
*
|
||||||
|
* @param scale The scale for both height and width dimensions.
|
||||||
|
*/
|
||||||
|
public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) {
|
||||||
|
return upsampling2d(name, input, true, scale, scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #upsampling2d(String, SDVariable, boolean, int, int)}.
|
||||||
|
*/
|
||||||
|
public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||||
|
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 2D Convolution layer operation - Upsampling 2d
|
* 2D Convolution layer operation - Upsampling 2d
|
||||||
*
|
*
|
||||||
|
@ -814,33 +783,8 @@ public class SDCNN extends SDOps {
|
||||||
* @param scaleW Scale to upsample in width dimension
|
* @param scaleW Scale to upsample in width dimension
|
||||||
* @return Upsampled input
|
* @return Upsampled input
|
||||||
*/
|
*/
|
||||||
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
||||||
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
|
SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW);
|
||||||
return updateVariableNameAndReference(ret, name);
|
return updateVariableNameAndReference(ret, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format.
|
|
||||||
*
|
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
|
||||||
* @param scale Scale to upsample in both H and W dimensions
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
|
||||||
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
|
|
||||||
return upsampling2d(name, input, true, scale, scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 2D Convolution layer operation - Upsampling 2d
|
|
||||||
*
|
|
||||||
* @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
|
|
||||||
* or NHWC format (shape [minibatch, height, width, channels])
|
|
||||||
* @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
|
|
||||||
* @param scaleH Scale to upsample in height dimension
|
|
||||||
* @param scaleW Scale to upsample in width dimension
|
|
||||||
* @return Upsampled input
|
|
||||||
*/
|
|
||||||
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
|
|
||||||
return upsampling2d(null, input, nchw, scaleH, scaleW);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.VariableType;
|
import org.nd4j.autodiff.samediff.VariableType;
|
||||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.graph.DataType;
|
import org.nd4j.graph.DType;
|
||||||
import org.nd4j.graph.FlatArray;
|
import org.nd4j.graph.FlatArray;
|
||||||
import org.nd4j.graph.FlatNode;
|
import org.nd4j.graph.FlatNode;
|
||||||
import org.nd4j.graph.FlatProperties;
|
import org.nd4j.graph.FlatProperties;
|
||||||
|
@ -66,33 +66,33 @@ public class FlatBuffersMapper {
|
||||||
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return DataType.FLOAT;
|
return DType.FLOAT;
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return DataType.DOUBLE;
|
return DType.DOUBLE;
|
||||||
case HALF:
|
case HALF:
|
||||||
return DataType.HALF;
|
return DType.HALF;
|
||||||
case INT:
|
case INT:
|
||||||
return DataType.INT32;
|
return DType.INT32;
|
||||||
case LONG:
|
case LONG:
|
||||||
return DataType.INT64;
|
return DType.INT64;
|
||||||
case BOOL:
|
case BOOL:
|
||||||
return DataType.BOOL;
|
return DType.BOOL;
|
||||||
case SHORT:
|
case SHORT:
|
||||||
return DataType.INT16;
|
return DType.INT16;
|
||||||
case BYTE:
|
case BYTE:
|
||||||
return DataType.INT8;
|
return DType.INT8;
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
return DataType.UINT8;
|
return DType.UINT8;
|
||||||
case UTF8:
|
case UTF8:
|
||||||
return DataType.UTF8;
|
return DType.UTF8;
|
||||||
case UINT16:
|
case UINT16:
|
||||||
return DataType.UINT16;
|
return DType.UINT16;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
return DataType.UINT32;
|
return DType.UINT32;
|
||||||
case UINT64:
|
case UINT64:
|
||||||
return DataType.UINT64;
|
return DType.UINT64;
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
return DataType.BFLOAT16;
|
return DType.BFLOAT16;
|
||||||
default:
|
default:
|
||||||
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
|
||||||
}
|
}
|
||||||
|
@ -102,33 +102,33 @@ public class FlatBuffersMapper {
|
||||||
* This method converts enums for DataType
|
* This method converts enums for DataType
|
||||||
*/
|
*/
|
||||||
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
|
||||||
if (val == DataType.FLOAT) {
|
if (val == DType.FLOAT) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
} else if (val == DataType.DOUBLE) {
|
} else if (val == DType.DOUBLE) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
||||||
} else if (val == DataType.HALF) {
|
} else if (val == DType.HALF) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
return org.nd4j.linalg.api.buffer.DataType.HALF;
|
||||||
} else if (val == DataType.INT32) {
|
} else if (val == DType.INT32) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.INT;
|
return org.nd4j.linalg.api.buffer.DataType.INT;
|
||||||
} else if (val == DataType.INT64) {
|
} else if (val == DType.INT64) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
return org.nd4j.linalg.api.buffer.DataType.LONG;
|
||||||
} else if (val == DataType.INT8) {
|
} else if (val == DType.INT8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
return org.nd4j.linalg.api.buffer.DataType.BYTE;
|
||||||
} else if (val == DataType.BOOL) {
|
} else if (val == DType.BOOL) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
return org.nd4j.linalg.api.buffer.DataType.BOOL;
|
||||||
} else if (val == DataType.UINT8) {
|
} else if (val == DType.UINT8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
|
||||||
} else if (val == DataType.INT16) {
|
} else if (val == DType.INT16) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
return org.nd4j.linalg.api.buffer.DataType.SHORT;
|
||||||
} else if (val == DataType.UTF8) {
|
} else if (val == DType.UTF8) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
return org.nd4j.linalg.api.buffer.DataType.UTF8;
|
||||||
} else if (val == DataType.UINT16) {
|
} else if (val == DType.UINT16) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
return org.nd4j.linalg.api.buffer.DataType.UINT16;
|
||||||
} else if (val == DataType.UINT32) {
|
} else if (val == DType.UINT32) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
return org.nd4j.linalg.api.buffer.DataType.UINT32;
|
||||||
} else if (val == DataType.UINT64) {
|
} else if (val == DType.UINT64) {
|
||||||
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
return org.nd4j.linalg.api.buffer.DataType.UINT64;
|
||||||
} else if (val == DataType.BFLOAT16){
|
} else if (val == DType.BFLOAT16){
|
||||||
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
return org.nd4j.linalg.api.buffer.DataType.BFLOAT16;
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Unknown datatype: " + val);
|
throw new RuntimeException("Unknown datatype: " + val);
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
package org.nd4j.graph;
|
package org.nd4j.graph;
|
||||||
|
|
||||||
public final class DataType {
|
public final class DType {
|
||||||
private DataType() { }
|
private DType() { }
|
||||||
public static final byte INHERIT = 0;
|
public static final byte INHERIT = 0;
|
||||||
public static final byte BOOL = 1;
|
public static final byte BOOL = 1;
|
||||||
public static final byte FLOAT8 = 2;
|
public static final byte FLOAT8 = 2;
|
|
@ -353,6 +353,10 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
|
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
|
||||||
|
|
|
@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty()));
|
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setShape(long[] shape) {
|
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty()));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setStride(long[] stride) {
|
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty()));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setShapeAndStride(int[] shape, int[] stride) {
|
public void setShapeAndStride(int[] shape, int[] stride) {
|
||||||
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
|
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false));
|
||||||
|
@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return scalar.getDouble(0);
|
return scalar.getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number entropyNumber() {
|
public Number entropyNumber() {
|
||||||
return entropy(Integer.MAX_VALUE).getDouble(0);
|
return entropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns non-normalized Shannon entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number shannonEntropyNumber() {
|
public Number shannonEntropyNumber() {
|
||||||
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns log entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number logEntropyNumber() {
|
public Number logEntropyNumber() {
|
||||||
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
|
@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return size(0);
|
return size(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
|
||||||
Nd4j.getCompressor().autoDecompress(this);
|
|
||||||
int n = shape.length;
|
|
||||||
|
|
||||||
// FIXME: shapeInfo should be used here
|
|
||||||
if (shape.length < 1)
|
|
||||||
return create(Nd4j.createBufferDetached(shape));
|
|
||||||
if (offsets.length != n)
|
|
||||||
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
|
|
||||||
if (stride.length != n)
|
|
||||||
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
|
|
||||||
|
|
||||||
if (Shape.contentEquals(shape, shapeOf())) {
|
|
||||||
if (ArrayUtil.isZero(offsets)) {
|
|
||||||
return this;
|
|
||||||
} else {
|
|
||||||
throw new IllegalArgumentException("Invalid subArray offsets");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
long[] dotProductOffsets = offsets;
|
|
||||||
int[] dotProductStride = stride;
|
|
||||||
|
|
||||||
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
|
|
||||||
if (offset >= data().length())
|
|
||||||
offset = ArrayUtil.sumLong(offsets);
|
|
||||||
|
|
||||||
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering());
|
|
||||||
}
|
|
||||||
|
|
||||||
protected INDArray create(DataBuffer buffer) {
|
protected INDArray create(DataBuffer buffer) {
|
||||||
return Nd4j.create(buffer);
|
return Nd4j.create(buffer);
|
||||||
}
|
}
|
||||||
|
@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return Nd4j.getExecutioner().exec(new AMin(this, dimension));
|
return Nd4j.getExecutioner().exec(new AMin(this, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the sum along the specified dimension(s) of this ndarray
|
|
||||||
*
|
|
||||||
* @param dimension the dimension to getScalar the sum along
|
|
||||||
* @return the sum along the specified dimension of this ndarray
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray sum(int... dimension) {
|
public INDArray sum(int... dimension) {
|
||||||
validateNumericalArray("sum", true);
|
validateNumericalArray("sum", true);
|
||||||
return Nd4j.getExecutioner().exec(new Sum(this, dimension));
|
return Nd4j.getExecutioner().exec(new Sum(this, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the sum along the last dimension of this ndarray
|
|
||||||
*
|
|
||||||
* @param dimension the dimension to getScalar the sum along
|
|
||||||
* @return the sum along the specified dimension of this ndarray
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray sum(boolean keepDim, int... dimension) {
|
public INDArray sum(boolean keepDim, int... dimension) {
|
||||||
validateNumericalArray("sum", true);
|
validateNumericalArray("sum", true);
|
||||||
return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension));
|
return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns entropy along dimension
|
|
||||||
* @param dimension
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray entropy(int... dimension) {
|
public INDArray entropy(int... dimension) {
|
||||||
validateNumericalArray("entropy", false);
|
validateNumericalArray("entropy", false);
|
||||||
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns non-normalized Shannon entropy along dimension
|
|
||||||
* @param dimension
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray shannonEntropy(int... dimension) {
|
public INDArray shannonEntropy(int... dimension) {
|
||||||
validateNumericalArray("shannonEntropy", false);
|
validateNumericalArray("shannonEntropy", false);
|
||||||
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns log entropy along dimension
|
|
||||||
* @param dimension
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray logEntropy(int... dimension) {
|
public INDArray logEntropy(int... dimension) {
|
||||||
validateNumericalArray("logEntropy", false);
|
validateNumericalArray("logEntropy", false);
|
||||||
|
|
|
@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setStride(long... stride) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setShape(long... shape) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray putScalar(long row, long col, double value) {
|
public INDArray putScalar(long row, long col, double value) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setShapeAndStride(int[] shape, int[] stride) {
|
public void setShapeAndStride(int[] shape, int[] stride) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setOrder(char order) {
|
public void setOrder(char order) {
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number entropyNumber() {
|
public Number entropyNumber() {
|
||||||
return entropy(Integer.MAX_VALUE).getDouble(0);
|
return entropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns non-normalized Shannon entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number shannonEntropyNumber() {
|
public Number shannonEntropyNumber() {
|
||||||
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns log entropy value for this INDArray
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Number logEntropyNumber() {
|
public Number logEntropyNumber() {
|
||||||
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
return logEntropy(Integer.MAX_VALUE).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns entropy along dimension
|
|
||||||
* @param dimension
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray entropy(int... dimension) {
|
public INDArray entropy(int... dimension) {
|
||||||
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
return Nd4j.getExecutioner().exec(new Entropy(this, dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns non-normalized Shannon entropy along dimension
|
|
||||||
* @param dimension
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray shannonEntropy(int... dimension) {
|
public INDArray shannonEntropy(int... dimension) {
|
||||||
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension));
|
||||||
|
|
|
@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
||||||
return extendedFlags;
|
return extendedFlags;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the underlying indices of the element of the given index
|
* Returns the underlying indices of the element of the given index
|
||||||
* such as there really are in the original ndarray
|
* such as there really are in the original ndarray
|
||||||
|
@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setStride(long... stride) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setShape(long... shape) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns true if this INDArray is special case: no-value INDArray
|
* This method returns true if this INDArray is special case: no-value INDArray
|
||||||
*
|
*
|
||||||
|
|
|
@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
|
||||||
return shapeInformation;
|
return shapeInformation;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
//TODO use op
|
//TODO use op
|
||||||
|
|
|
@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns entropy value for this INDArray
|
* Returns entropy value for this INDArray
|
||||||
* @return
|
* @return entropy value
|
||||||
*/
|
*/
|
||||||
Number entropyNumber();
|
Number entropyNumber();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns non-normalized Shannon entropy value for this INDArray
|
* Returns non-normalized Shannon entropy value for this INDArray
|
||||||
* @return
|
* @return non-normalized Shannon entropy
|
||||||
*/
|
*/
|
||||||
Number shannonEntropyNumber();
|
Number shannonEntropyNumber();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns log entropy value for this INDArray
|
* Returns log entropy value for this INDArray
|
||||||
* @return
|
* @return log entropy value
|
||||||
*/
|
*/
|
||||||
Number logEntropyNumber();
|
Number logEntropyNumber();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns entropy value for this INDArray along specified dimension(s)
|
* Returns entropy value for this INDArray along specified dimension(s)
|
||||||
* @return
|
* @param dimension specified dimension(s)
|
||||||
|
* @return entropy value
|
||||||
*/
|
*/
|
||||||
INDArray entropy(int... dimension);
|
INDArray entropy(int... dimension);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns entropy value for this INDArray along specified dimension(s)
|
* Returns Shannon entropy value for this INDArray along specified dimension(s)
|
||||||
* @return
|
* @param dimension specified dimension(s)
|
||||||
|
* @return Shannon entropy
|
||||||
*/
|
*/
|
||||||
INDArray shannonEntropy(int... dimension);
|
INDArray shannonEntropy(int... dimension);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns entropy value for this INDArray along specified dimension(s)
|
* Returns log entropy value for this INDArray along specified dimension(s)
|
||||||
* @return
|
* @param dimension specified dimension(s)
|
||||||
|
* @return log entropy value
|
||||||
*/
|
*/
|
||||||
INDArray logEntropy(int... dimension);
|
INDArray logEntropy(int... dimension);
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* stride setter
|
|
||||||
* @param stride
|
|
||||||
* @deprecated, use {@link #reshape(int...) }
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
void setStride(long... stride);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Shape setter
|
|
||||||
* @param shape
|
|
||||||
* @deprecated, use {@link #reshape(int...) }
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
void setShape(long... shape);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shape and stride setter
|
* Shape and stride setter
|
||||||
* @param shape
|
* @param shape new value for shape
|
||||||
* @param stride
|
* @param stride new value for stride
|
||||||
*/
|
*/
|
||||||
void setShapeAndStride(int[] shape, int[] stride);
|
void setShapeAndStride(int[] shape, int[] stride);
|
||||||
|
|
||||||
|
@ -1920,14 +1904,6 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
void setOrder(char order);
|
void setOrder(char order);
|
||||||
|
|
||||||
/**
|
|
||||||
* @param offsets
|
|
||||||
* @param shape
|
|
||||||
* @param stride
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
INDArray subArray(long[] offsets, int[] shape, int[] stride);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the elements at the specified indices
|
* Returns the elements at the specified indices
|
||||||
*
|
*
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
super(sameDiff, new SDVariable[]{input});
|
||||||
if (arrayInput != null) {
|
|
||||||
addInputArgument(arrayInput);
|
|
||||||
}
|
|
||||||
if (arrayOutput != null) {
|
|
||||||
addOutputArgument(arrayOutput);
|
|
||||||
}
|
|
||||||
config.setType(Pooling2D.Pooling2DType.AVG);
|
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
config.setType(Pooling2D.Pooling2DType.AVG);
|
||||||
|
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp {
|
||||||
protected Conv1DConfig config;
|
protected Conv1DConfig config;
|
||||||
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv1D(SameDiff sameDiff,
|
public Conv1D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv1DConfig config) {
|
Conv1DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initConfig(Conv1DConfig config){
|
||||||
this.config = config;
|
this.config = config;
|
||||||
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp {
|
||||||
protected Conv2DConfig config;
|
protected Conv2DConfig config;
|
||||||
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv2D(SameDiff sameDiff,
|
public Conv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv2DConfig config) {
|
Conv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void initConfig(Conv2DConfig config){
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||||
INVALID_CONFIGURATION,
|
INVALID_CONFIGURATION,
|
||||||
config.getSH(), config.getPH(), config.getDW());
|
config.getSH(), config.getPH(), config.getDW());
|
||||||
addArgs();
|
addArgs();
|
||||||
if(sameDiff != null) {
|
|
||||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
|
@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp {
|
||||||
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
|
Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder()
|
||||||
.sameDiff(sameDiff)
|
.sameDiff(sameDiff)
|
||||||
.config(config)
|
.config(config)
|
||||||
.outputs(outputArguments())
|
|
||||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||||
.build();
|
.build();
|
||||||
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());
|
List<SDVariable> ret = Arrays.asList(conv2DDerivative.outputVariables());
|
||||||
|
|
|
@ -37,8 +37,8 @@ import java.util.List;
|
||||||
public class Conv2DDerivative extends Conv2D {
|
public class Conv2DDerivative extends Conv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) {
|
public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, config);
|
super(sameDiff, inputFunctions, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Conv2DDerivative() {}
|
public Conv2DDerivative() {}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp {
|
||||||
public Conv3D() {
|
public Conv3D() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs,
|
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
|
||||||
Conv3DConfig conv3DConfig) {
|
super(sameDiff, inputFunctions);
|
||||||
super(null, sameDiff, inputFunctions, false);
|
initConfig(config);
|
||||||
setSameDiff(sameDiff);
|
}
|
||||||
|
|
||||||
if (inputs != null)
|
public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){
|
||||||
addInputArgument(inputs);
|
super(inputs, outputs);
|
||||||
if (outputs != null)
|
initConfig(config);
|
||||||
addOutputArgument(outputs);
|
}
|
||||||
this.config = conv3DConfig;
|
|
||||||
|
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initConfig(Conv3DConfig config){
|
||||||
|
this.config = config;
|
||||||
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1,
|
||||||
INVALID_CONFIGURATION,
|
INVALID_CONFIGURATION,
|
||||||
config.getSW(), config.getPH(), config.getDW());
|
config.getSW(), config.getPH(), config.getDW());
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
||||||
|
|
||||||
//for (val arg: iArgs())
|
|
||||||
// System.out.println(getIArgument(arg));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp {
|
||||||
inputs.add(f1.get(0));
|
inputs.add(f1.get(0));
|
||||||
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
|
Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder()
|
||||||
.conv3DConfig(config)
|
.conv3DConfig(config)
|
||||||
.inputFunctions(args())
|
|
||||||
.outputs(outputArguments())
|
|
||||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
||||||
.sameDiff(sameDiff)
|
.sameDiff(sameDiff)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D {
|
||||||
public Conv3DDerivative() {}
|
public Conv3DDerivative() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) {
|
public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
|
super(sameDiff, inputFunctions, conv3DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv2DConfig config;
|
protected DeConv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DeConv2D(SameDiff sameDiff,
|
public DeConv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputs,
|
SDVariable[] inputs,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
DeConv2DConfig config) {
|
DeConv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputs);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
if (inputArrays != null) {
|
addArgs();
|
||||||
addInputArgument(inputArrays);
|
|
||||||
}
|
|
||||||
if (outputs != null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
}
|
||||||
sameDiff.addArgsFor(inputs, this);
|
|
||||||
|
public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D {
|
||||||
public DeConv2DDerivative() {}
|
public DeConv2DDerivative() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) {
|
public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) {
|
||||||
super(sameDiff, inputs, inputArrays, outputs, config);
|
super(sameDiff, inputs, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,25 +53,21 @@ public class DeConv2DTF extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv2DConfig config;
|
protected DeConv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DeConv2DTF(SameDiff sameDiff,
|
public DeConv2DTF(SameDiff sameDiff,
|
||||||
SDVariable[] inputs,
|
SDVariable[] inputs,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
DeConv2DConfig config) {
|
DeConv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputs);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
||||||
if (inputArrays != null) {
|
|
||||||
addInputArgument(inputArrays);
|
|
||||||
}
|
|
||||||
if (outputs != null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this);
|
|
||||||
sameDiff.addArgsFor(inputs, this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected DeConv3DConfig config;
|
protected DeConv3DConfig config;
|
||||||
|
|
||||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) {
|
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||||
super(sameDiff, toArr(input, weights, bias));
|
super(sameDiff, toArr(input, weights, bias));
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
||||||
if(bias != null){
|
if(bias != null){
|
||||||
return new SDVariable[]{input, weights, bias};
|
return new SDVariable[]{input, weights, bias};
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
||||||
|
|
||||||
protected Conv2DConfig config;
|
protected Conv2DConfig config;
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public DepthwiseConv2D(SameDiff sameDiff,
|
public DepthwiseConv2D(SameDiff sameDiff,
|
||||||
SDVariable[] inputFunctions,
|
SDVariable[] inputFunctions,
|
||||||
INDArray[] inputArrays, INDArray[] outputs,
|
|
||||||
Conv2DConfig config) {
|
Conv2DConfig config) {
|
||||||
super(null, inputArrays, outputs);
|
super(sameDiff, inputFunctions);
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point
|
}
|
||||||
sameDiff.addArgsFor(inputFunctions, this);
|
|
||||||
|
public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public DepthwiseConv2D() {
|
public DepthwiseConv2D() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
||||||
protected LocalResponseNormalizationConfig config;
|
protected LocalResponseNormalizationConfig config;
|
||||||
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions,
|
public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace,
|
||||||
INDArray[] inputs, INDArray[] outputs,boolean inPlace,
|
|
||||||
LocalResponseNormalizationConfig config) {
|
LocalResponseNormalizationConfig config) {
|
||||||
super(null,sameDiff, inputFunctions, inPlace);
|
super(null,sameDiff, inputFunctions, inPlace);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
if(inputs != null) {
|
addArgs();
|
||||||
addInputArgument(inputs);
|
|
||||||
}
|
|
||||||
if(outputs!= null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,8 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
|
public class LocalResponseNormalizationDerivative extends LocalResponseNormalization {
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) {
|
public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
|
||||||
super(sameDiff, inputFunctions, inputs, outputs, inPlace, config);
|
super(sameDiff, inputFunctions, inPlace, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LocalResponseNormalizationDerivative() {}
|
public LocalResponseNormalizationDerivative() {}
|
||||||
|
|
|
@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp {
|
||||||
public MaxPooling2D() {
|
public MaxPooling2D() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
@SuppressWarnings("Used in lombok")
|
@SuppressWarnings("Used in lombok")
|
||||||
public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) {
|
public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
super(null, sameDiff, new SDVariable[]{input}, false);
|
||||||
if (arrayInput != null) {
|
|
||||||
addInputArgument(arrayInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (arrayOutput != null) {
|
|
||||||
addOutputArgument(arrayOutput);
|
|
||||||
}
|
|
||||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
this.sameDiff = sameDiff;
|
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output});
|
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
|
|
@ -16,8 +16,14 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pooling2D operation
|
* Pooling2D operation
|
||||||
|
@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
|
|
||||||
public Pooling2D() {}
|
public Pooling2D() {}
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
@Builder(builderMethodName = "sameDiffBuilder")
|
||||||
@SuppressWarnings("Used in lombok")
|
@SuppressWarnings("Used in lombok")
|
||||||
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) {
|
public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,
|
||||||
|
Pooling2DConfig config) {
|
||||||
super(null, sameDiff, inputs, false);
|
super(null, sameDiff, inputs, false);
|
||||||
if(arrayInputs != null) {
|
|
||||||
addInputArgument(arrayInputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(arrayOutputs != null) {
|
|
||||||
addOutputArgument(arrayOutputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input}, wrapOrNull(output));
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,8 +37,12 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class Pooling2DDerivative extends Pooling2D {
|
public class Pooling2DDerivative extends Pooling2D {
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
@Builder(builderMethodName = "derivativeBuilder")
|
||||||
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) {
|
public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) {
|
||||||
super(sameDiff, inputs, arrayInputs, arrayOutputs, config);
|
super(sameDiff, inputs, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){
|
||||||
|
super(new INDArray[]{input, grad}, wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Pooling2DDerivative() {}
|
public Pooling2DDerivative() {}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -39,9 +40,17 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SConv2D extends Conv2D {
|
public class SConv2D extends Conv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "sBuilder")
|
@Builder(builderMethodName = "sameDiffSBuilder")
|
||||||
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
super(sameDiff, inputFunctions, conv2DConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||||
|
super(inputs, outputs, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){
|
||||||
|
this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SConv2D() {}
|
public SConv2D() {}
|
||||||
|
|
|
@ -38,8 +38,8 @@ import java.util.List;
|
||||||
public class SConv2DDerivative extends SConv2D {
|
public class SConv2DDerivative extends SConv2D {
|
||||||
|
|
||||||
@Builder(builderMethodName = "sDerviativeBuilder")
|
@Builder(builderMethodName = "sDerviativeBuilder")
|
||||||
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) {
|
public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) {
|
||||||
super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig);
|
super(sameDiff, inputFunctions, conv2DConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SConv2DDerivative() {}
|
public SConv2DDerivative() {}
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class BitsHammingDistance extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public BitsHammingDistance(){ }
|
||||||
|
|
||||||
|
public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){
|
||||||
|
super(sd, new SDVariable[]{x, y});
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){
|
||||||
|
super(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bits_hamming_distance";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes);
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes);
|
||||||
|
return Collections.singletonList(DataType.LONG);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bit-wise AND operation, broadcastable
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class BitwiseAnd extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
|
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||||
|
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
|
||||||
|
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseAnd(INDArray x, INDArray y) {
|
||||||
|
this(x, y,x.ulike());
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseAnd() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bitwise_and";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String onnxName() {
|
||||||
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "bitwise_and";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||||
|
return Collections.singletonList(dataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bit-wise OR operation, broadcastable
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class BitwiseOr extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
|
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||||
|
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
|
||||||
|
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseOr(INDArray x, INDArray y) {
|
||||||
|
this(x, y,x.ulike());
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseOr() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bitwise_or";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String onnxName() {
|
||||||
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "bitwise_or";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||||
|
return Collections.singletonList(dataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bit-wise XOR operation, broadcastable
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class BitwiseXor extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
|
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||||
|
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
|
||||||
|
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseXor(INDArray x, INDArray y) {
|
||||||
|
this(x, y,x.ulike());
|
||||||
|
}
|
||||||
|
|
||||||
|
public BitwiseXor() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "bitwise_xor";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String onnxName() {
|
||||||
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "bitwise_xor";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||||
|
return Collections.singletonList(dataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -235,10 +235,7 @@ public class Convolution {
|
||||||
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
|
public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw,
|
||||||
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
|
int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor,
|
||||||
double extra, int virtualHeight, int virtualWidth, INDArray out) {
|
double extra, int virtualHeight, int virtualWidth, INDArray out) {
|
||||||
Pooling2D pooling = Pooling2D.builder()
|
Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder()
|
||||||
.arrayInputs(new INDArray[]{img})
|
|
||||||
.arrayOutputs(new INDArray[]{out})
|
|
||||||
.config(Pooling2DConfig.builder()
|
|
||||||
.dH(dh)
|
.dH(dh)
|
||||||
.dW(dw)
|
.dW(dw)
|
||||||
.extra(extra)
|
.extra(extra)
|
||||||
|
@ -251,8 +248,7 @@ public class Convolution {
|
||||||
.sW(sx)
|
.sW(sx)
|
||||||
.type(type)
|
.type(type)
|
||||||
.divisor(divisor)
|
.divisor(divisor)
|
||||||
.build())
|
.build());
|
||||||
.build();
|
|
||||||
Nd4j.getExecutioner().execAndReturn(pooling);
|
Nd4j.getExecutioner().execAndReturn(pooling);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex {
|
||||||
return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
|
return offset(arr.stride(), Indices.offsets(arr.shape(), indices));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the shape and stride for
|
|
||||||
* new axes based dimensions
|
|
||||||
* @param arr the array to update
|
|
||||||
* the shape/strides for
|
|
||||||
* @param indexes the indexes to update based on
|
|
||||||
*/
|
|
||||||
public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) {
|
|
||||||
int numNewAxes = NDArrayIndex.numNewAxis(indexes);
|
|
||||||
if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) {
|
|
||||||
List<Long> newShape = new ArrayList<>();
|
|
||||||
List<Long> newStrides = new ArrayList<>();
|
|
||||||
int currDimension = 0;
|
|
||||||
for (int i = 0; i < indexes.length; i++) {
|
|
||||||
if (indexes[i] instanceof NewAxis) {
|
|
||||||
newShape.add(1L);
|
|
||||||
newStrides.add(0L);
|
|
||||||
} else {
|
|
||||||
newShape.add(arr.size(currDimension));
|
|
||||||
newStrides.add(arr.size(currDimension));
|
|
||||||
currDimension++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while (currDimension < arr.rank()) {
|
|
||||||
newShape.add((long) currDimension);
|
|
||||||
newStrides.add((long) currDimension);
|
|
||||||
currDimension++;
|
|
||||||
}
|
|
||||||
|
|
||||||
long[] newShapeArr = Longs.toArray(newShape);
|
|
||||||
long[] newStrideArr = Longs.toArray(newStrides);
|
|
||||||
|
|
||||||
// FIXME: this is wrong, it breaks shapeInfo immutability
|
|
||||||
arr.setShape(newShapeArr);
|
|
||||||
arr.setStride(newStrideArr);
|
|
||||||
|
|
||||||
|
|
||||||
} else {
|
|
||||||
if (numNewAxes > 0) {
|
|
||||||
long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape());
|
|
||||||
long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride());
|
|
||||||
arr.setShape(newShape);
|
|
||||||
arr.setStride(newStrides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the offset given an array of offsets.
|
* Compute the offset given an array of offsets.
|
||||||
* The offset is computed(for both fortran an d c ordering) as:
|
* The offset is computed(for both fortran an d c ordering) as:
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class DeallocatorService {
|
||||||
deallocatorThreads = new Thread[numThreads];
|
deallocatorThreads = new Thread[numThreads];
|
||||||
queues = new ReferenceQueue[numThreads];
|
queues = new ReferenceQueue[numThreads];
|
||||||
for (int e = 0; e < numThreads; e++) {
|
for (int e = 0; e < numThreads; e++) {
|
||||||
log.debug("Starting deallocator thread {}", e + 1);
|
log.trace("Starting deallocator thread {}", e + 1);
|
||||||
queues[e] = new ReferenceQueue<>();
|
queues[e] = new ReferenceQueue<>();
|
||||||
|
|
||||||
int deviceId = e % numDevices;
|
int deviceId = e % numDevices;
|
||||||
|
|
|
@ -1151,4 +1151,6 @@ public interface NativeOps {
|
||||||
|
|
||||||
int lastErrorCode();
|
int lastErrorCode();
|
||||||
String lastErrorMessage();
|
String lastErrorMessage();
|
||||||
|
|
||||||
|
boolean isBlasVersionMatches(int major, int minor, int build);
|
||||||
}
|
}
|
||||||
|
|
|
@ -101,7 +101,7 @@ public class NativeOpsHolder {
|
||||||
}
|
}
|
||||||
//deviceNativeOps.setOmpNumThreads(4);
|
//deviceNativeOps.setOmpNumThreads(4);
|
||||||
|
|
||||||
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
|
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
|
||||||
} catch (Exception | Error e) {
|
} catch (Exception | Error e) {
|
||||||
throw new RuntimeException(
|
throw new RuntimeException(
|
||||||
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
|
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
|
||||||
|
|
|
@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
|
||||||
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
|
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
|
||||||
setMaxThreads(numThreads);
|
setMaxThreads(numThreads);
|
||||||
}
|
}
|
||||||
log.info("Number of threads used for BLAS: {}", getMaxThreads());
|
|
||||||
|
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
|
||||||
throw new RuntimeException("No CUDA devices were found in system");
|
throw new RuntimeException("No CUDA devices were found in system");
|
||||||
}
|
}
|
||||||
Loader.load(org.bytedeco.cuda.global.cublas.class);
|
Loader.load(org.bytedeco.cuda.global.cublas.class);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
/*
|
||||||
|
val major = new int[1];
|
||||||
|
val minor = new int[1];
|
||||||
|
val build = new int[1];
|
||||||
|
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
|
||||||
|
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
|
||||||
|
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
|
||||||
|
|
||||||
|
val pew = new int[100];
|
||||||
|
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
|
||||||
|
|
||||||
|
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
|
||||||
|
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
|
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
|
||||||
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.factory.DataTypeValidation;
|
import org.nd4j.linalg.factory.DataTypeValidation;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.CublasPointer;
|
import org.nd4j.linalg.jcublas.CublasPointer;
|
||||||
|
@ -113,8 +115,13 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
@Override
|
@Override
|
||||||
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
|
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
|
||||||
INDArray B, int ldb, float beta, INDArray C, int ldc) {
|
INDArray B, int ldb, float beta, INDArray C, int ldc) {
|
||||||
//A = Shape.toOffsetZero(A);
|
/*
|
||||||
//B = Shape.toOffsetZero(B);
|
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||||
|
val handle = ctx.getCublasHandle();
|
||||||
|
synchronized (handle) {
|
||||||
|
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
|
@ -141,6 +148,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
}
|
}
|
||||||
|
|
||||||
allocator.registerAction(ctx, C, A, B);
|
allocator.registerAction(ctx, C, A, B);
|
||||||
|
|
||||||
OpExecutionerUtil.checkForAny(C);
|
OpExecutionerUtil.checkForAny(C);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
public Environment(Pointer p) { super(p); }
|
public Environment(Pointer p) { super(p); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||||
|
*/
|
||||||
|
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||||
|
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||||
|
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||||
|
|
||||||
public static native Environment getInstance();
|
public static native Environment getInstance();
|
||||||
|
|
||||||
public native @Cast("bool") boolean isVerbose();
|
public native @Cast("bool") boolean isVerbose();
|
||||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
||||||
public native void setOmpMinThreads(int threads);
|
public native void setOmpMinThreads(int threads);
|
||||||
|
|
||||||
|
|
||||||
|
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
public Environment(Pointer p) { super(p); }
|
public Environment(Pointer p) { super(p); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||||
|
*/
|
||||||
|
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||||
|
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||||
|
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||||
|
|
||||||
public static native Environment getInstance();
|
public static native Environment getInstance();
|
||||||
|
|
||||||
public native @Cast("bool") boolean isVerbose();
|
public native @Cast("bool") boolean isVerbose();
|
||||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
||||||
public native void setOmpMinThreads(int threads);
|
public native void setOmpMinThreads(int threads);
|
||||||
|
|
||||||
|
|
||||||
|
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise AND
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* \tparam T
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_bitwise_and)
|
||||||
|
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public bitwise_and(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public bitwise_and position(long position) {
|
||||||
|
return (bitwise_and)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public bitwise_and() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise OR
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* \tparam T
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_bitwise_or)
|
||||||
|
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public bitwise_or(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public bitwise_or position(long position) {
|
||||||
|
return (bitwise_or)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public bitwise_or() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation applies bitwise XOR
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* \tparam T
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_bitwise_xor)
|
||||||
|
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public bitwise_xor(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public bitwise_xor position(long position) {
|
||||||
|
return (bitwise_xor)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public bitwise_xor() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation returns hamming distance based on bits
|
* This operation returns hamming distance based on bits
|
||||||
*
|
*
|
||||||
|
|
|
@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
INDArray input = Nd4j.create(inSize);
|
INDArray input = Nd4j.create(inSize);
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||||
.arrayInput(input)
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||||
|
|
||||||
|
@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
//Test backprop:
|
//Test backprop:
|
||||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf);
|
||||||
.arrayInputs(new INDArray[]{input, grad})
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||||
assertEquals(1, outSizesBP.size());
|
assertEquals(1, outSizesBP.size());
|
||||||
|
@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
INDArray input = Nd4j.create(inSize);
|
INDArray input = Nd4j.create(inSize);
|
||||||
AvgPooling2D avgPooling2D = AvgPooling2D.builder()
|
AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf);
|
||||||
.arrayInput(input)
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D);
|
||||||
assertEquals(1, outSizes.size());
|
assertEquals(1, outSizes.size());
|
||||||
|
@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
INDArray grad = Nd4j.create(exp);
|
INDArray grad = Nd4j.create(exp);
|
||||||
|
|
||||||
//Test backprop:
|
//Test backprop:
|
||||||
Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder()
|
Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf);
|
||||||
.arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output)
|
|
||||||
.arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input
|
|
||||||
.config(conf)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv);
|
||||||
assertEquals(1, outSizesBP.size());
|
assertEquals(1, outSizesBP.size());
|
||||||
|
@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.isSameMode(false)
|
.isSameMode(false)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable out = sd.cnn().conv2d(vars, c);
|
SDVariable out = sd.cnn().conv2d("conv", vars, c);
|
||||||
out = sd.nn().tanh("out", out);
|
out = sd.nn().tanh("out", out);
|
||||||
|
|
||||||
INDArray outArr = sd.execAndEndResult();
|
INDArray outArr = sd.execAndEndResult();
|
||||||
|
|
Loading…
Reference in New Issue