[WIP] latest update (#8145)

* [WIP] maxpool2d_bp fix (#160)

* one test for maxpool2d_bp

Signed-off-by: raver119 <raver119@gmail.com>

* - maxpool2d_bp cuda fix for NaNs
- streamSync after each custom op execution

Signed-off-by: raver119 <raver119@gmail.com>

* MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Upgrade protobuf version (#162)

* First steps for protobuf version upgrade

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Phase 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update imports to shaded protobuf

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Version fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Switch to single execution for protobuf codegen to work around plugin bug

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Automatically delete old PB generated files after name change

Signed-off-by: Alex Black <blacka101@gmail.com>

* - string NDArray flat serde impl + tests (#163)

- string NDArray equalsTo impl

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of context variable

Signed-off-by: raver119 <raver119@gmail.com>

* lup context fix (#164)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-24 16:59:30 +03:00 committed by GitHub
parent 95b2686ce5
commit d871eab2e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
93 changed files with 891 additions and 459 deletions

View File

@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Throwable t = null;
try {
for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]];
@ -2436,6 +2437,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
}
}
}
} catch (Throwable t2){
t = t2;
} finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
//Though if stopIndex < numLayers, some might still be open
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Edge case here: seems that scoping out can increase the tagScope of the current WS
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
// number of times to actually close it, in all cases
try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
}
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
@ -2581,6 +2592,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try {
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
boolean hitFrozen = false;
@ -2732,8 +2744,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
tempList.addFirst(new Triple<>(newName, entry.getValue(),
g.flatteningOrderForVariable(origName)));
}
for (Triple<String, INDArray, Character> t : tempList)
gradients.addFirst(t);
for (Triple<String, INDArray, Character> triple : tempList)
gradients.addFirst(triple);
}
//Close any activation gradient workspaces that we no longer require
@ -2752,19 +2764,28 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
}
}
} catch (Throwable t2){
t = t2;
} finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
}
//Now, add the gradients in the order we need them in for flattening (same as params order)
Gradient gradient = new DefaultGradient(flattenedGradients);
for (Triple<String, INDArray, Character> t : gradients) {
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
for (Triple<String, INDArray, Character> tr : gradients) {
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
}
this.gradient = gradient;

View File

@ -1242,6 +1242,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try {
for (int i = 0; i <= layerIndex; i++) {
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
@ -1328,17 +1329,34 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
}
}
} catch (Throwable t2){
t = t2;
} finally {
if(wsActCloseNext != null){
try {
wsActCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
if(temp != null){
//Should only be non-null on exception
while(temp.isScopeActive()){
//For safety, should never occur in theory: a single close() call may not be sufficient, if
// workspace scope was borrowed and not properly closed when exception occurred
try{
temp.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
}
@ -1871,6 +1889,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try {
for (int i = layers.length - 1; i >= 0; i--) {
if (layers[i] instanceof FrozenLayer) {
@ -1961,13 +1980,31 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
}
}
} catch (Throwable thr ){
t = thr;
} finally {
if(wsActGradCloseNext != null){
try {
wsActGradCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
if(wsActGradTemp != null) {
//Should only be non-null on exception
try {
wsActGradTemp.close();
} catch (Throwable t2) {
if (t != null) {
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
}

View File

@ -476,20 +476,37 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
////////////////////////////////////////////////////////////////////////
std::vector<int8_t> NDArray::asByteVector() {
if (isS()) {
// string data type requires special treatment
syncToHost();
auto numWords = this->lengthOf();
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
auto dataLength = offsetsBuffer[numWords];
std::vector<int8_t> result(headerLength + dataLength);
memcpy(result.data(), getBuffer(), headerLength + dataLength);
return result;
} else {
// all other types are linear
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
if (this->isView()) {
auto tmp = this->dup(this->ordering());
syncToHost();
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
delete tmp;
}
else {
} else {
syncToHost();
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
}
return result;
}
}
//////////////////////////////////////////////////////////////////////////
void NDArray::linspace(const double start) {
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
//////////////////////////////////////////////////////////////////////////
template <typename T>
T* NDArray::bufferAsT() const {
if (isS())
throw std::runtime_error("You can't use this method on String array");
// FIXME: do we REALLY want sync here?
syncToHost();
return reinterpret_cast<T*>(getBuffer());
@ -3202,12 +3217,30 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
return false;
if (isS()) {
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
for (int e = 0; e < this->lengthOf(); e++) {
auto s1 = this->e<std::string>(e);
auto s2 = other->e<std::string>(e);
if (s1 != s2)
return false;
}
return true;
} else {
// regular numeric types
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
ExtraArguments extras({eps});
NDArray::prepareSpecialUse({&tmp}, {this, other});
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
getSpecialBuffer(), getSpecialShapeInfo(),
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
other->getShapeInfo(), other->getSpecialBuffer(),
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
tmp.specialBuffer(), tmp.specialShapeInfo());
NDArray::registerSpecialUse({&tmp}, {this, other});
synchronize("NDArray::equalsTo");
@ -3217,6 +3250,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
return true;
}
}
//////////////////////////////////////////////////////////////////////////
template <>

View File

@ -54,6 +54,7 @@
#include <graph/ExecutionResult.h>
#include <exceptions/graph_execution_exception.h>
#include <exceptions/no_results_exception.h>
#include <graph/FlatUtils.h>
namespace nd4j{
namespace graph {
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
continue;
NDArray* array = var->getNDArray();
auto byteVector = array->asByteVector();
auto array = var->getNDArray();
auto fBuffer = builder.CreateVector(byteVector);
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
auto fArray = FlatUtils::toFlatArray(builder, *array);
auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index());

View File

@ -866,9 +866,10 @@ void initializeFunctions(Nd4jPointer *functions) {
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
Nd4jPointer pointer;
// cudaHostAllocMapped |cudaHostAllocPortable
cudaError_t res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
if (res != 0)
pointer = 0L;
throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res);
return pointer;
}
@ -884,7 +885,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
Nd4jPointer pointer;
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
if (res != 0)
pointer = 0L;
throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res);
return pointer;
}
@ -894,9 +895,9 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
* @param pointer pointer that'll be freed
*/
int freeHost(Nd4jPointer pointer) {
cudaError_t res = cudaFreeHost(reinterpret_cast<void *>(pointer));
auto res = cudaFreeHost(reinterpret_cast<void *>(pointer));
if (res != 0)
pointer = 0L;
throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res);
return 1L;
}
@ -907,9 +908,10 @@ int freeHost(Nd4jPointer pointer) {
* @param ptrToDeviceId pointer to deviceId.
*/
int freeDevice(Nd4jPointer pointer, int deviceId) {
cudaError_t res = cudaFree(reinterpret_cast<void *>(pointer));
auto res = cudaFree(reinterpret_cast<void *>(pointer));
if (res != 0)
pointer = 0L;
throw nd4j::cuda_exception::build("cudaFree(...) failed", res);
return 1L;
}
@ -934,7 +936,7 @@ Nd4jPointer createStream() {
auto stream = new cudaStream_t();
auto dZ = cudaStreamCreate(stream);
if (dZ != 0)
throw std::runtime_error("cudaStreamCreate(...) failed");
throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ);
return stream;
}
@ -944,23 +946,21 @@ Nd4jPointer createEvent() {
CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t));
cudaError_t dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
checkCudaErrors(dZ);
auto dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
if (dZ != 0)
throw std::runtime_error("cudaEventCreateWithFlags(...) failed");
throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ);
return nativeEvent;
}
int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
cudaError_t dZ = cudaEventRecord(*pEvent, *pStream);
checkCudaErrors(dZ);
auto dZ = cudaEventRecord(*pEvent, *pStream);
if (dZ != 0)
throw std::runtime_error("cudaEventRecord(...) failed");
throw nd4j::cuda_exception::build("cudaEventRecord(...) failed", dZ);
return 1;
}
@ -1065,53 +1065,48 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
}
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaError_t dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
checkCudaErrors(dZ);
auto dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
if (dZ != 0)
throw std::runtime_error("cudaMemset(...) failed");
throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ);
return 1;
}
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
auto pStream = reinterpret_cast<cudaStream_t *>(reserved);
cudaError_t dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
checkCudaErrors(dZ);
auto dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
if (dZ != 0)
throw std::runtime_error("cudaMemsetAsync(...) failed");
throw nd4j::cuda_exception::build("cudaMemsetAsync(...) failed", dZ);
return 1;
}
int destroyEvent(Nd4jPointer event) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaError_t dZ = cudaEventDestroy(*pEvent);
checkCudaErrors(dZ);
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
auto dZ = cudaEventDestroy(*pEvent);
if (dZ != 0)
throw std::runtime_error("cudaEvenDestroy(...) failed");
throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ);
return 1;
}
int streamSynchronize(Nd4jPointer stream) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
cudaError_t dZ = cudaStreamSynchronize(*pStream);
checkCudaErrors(dZ);
auto dZ = cudaStreamSynchronize(*pStream);
if (dZ != 0)
throw std::runtime_error("cudaStreamSynchronize(...) failed");
throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ);
return 1L;
}
int eventSynchronize(Nd4jPointer event) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaError_t dZ = cudaEventSynchronize(*pEvent);
checkCudaErrors(dZ);
auto dZ = cudaEventSynchronize(*pEvent);
if (dZ != 0)
throw std::runtime_error("cudaEventSynchronize(...) failed");
throw nd4j::cuda_exception::build("cudaEventSynchronize(...) failed", dZ);
return 1L;
}
@ -2697,13 +2692,16 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte
auto result = op->execute(context);
// FIXME: remove once CUDA backend is 100% ready
auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream());
if (res != 0)
throw nd4j::cuda_exception::build("customOp execution failed", res);
for (auto v:context->fastpath_in()) {
v->makeBothActual();
v->syncToDevice();
}
for (auto v:context->fastpath_out()) {
v->makeBothActual();
v->syncToDevice();
}
return result;

View File

@ -36,6 +36,8 @@ namespace nd4j {
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
};
}
}

View File

@ -102,5 +102,16 @@ namespace nd4j {
delete[] newShape;
return array;
}
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
auto byteVector = array.asByteVector();
auto fBuffer = builder.CreateVector(byteVector);
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
}
}
}

View File

@ -26,7 +26,6 @@
namespace nd4j {
namespace ops {
namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
template <typename T>
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@ -108,14 +107,14 @@ namespace helpers {
template <typename T>
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
const int rowNum = input->rows();
const int columnNum = input->columns();
NDArray determinant = NDArrayFactory::create<T>(1.f);
NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
permutationMatrix.setIdentity();
T pivotValue; // = T(0.0);
@ -161,46 +160,43 @@ namespace helpers {
return determinant;
}
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
template <typename T>
static int determinant_(NDArray* input, NDArray* output) {
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
matrix.p(row, input->e<T>(k));
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
}
return Status::OK();
}
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
int logAbsDeterminant_(NDArray* input, NDArray* output) {
int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k));
}
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
if (det.e<T>(0) != 0.f)
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
}
@ -208,25 +204,23 @@ template <typename T>
return ND4J_STATUS_OK;
}
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
static int inverse_(NDArray* input, NDArray* output) {
static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
for (int e = 0; e < totalCount; e++) {
if (e)
@ -235,7 +229,7 @@ template <typename T>
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
matrix.p(row++, input->e<T>(k));
}
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
// FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@ -268,8 +262,7 @@ template <typename T>
}
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
}
template <typename T>
@ -296,14 +289,13 @@ template <typename T>
return true;
}
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
}
template <typename T>
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
@ -311,8 +303,8 @@ template <typename T>
if (!inplace)
output->assign(0.f); // fill up output tensor with zeros only inplace=false
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
for (int e = 0; e < totalCount; e++) {
@ -346,14 +338,13 @@ template <typename T>
}
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
}
template <typename T>
int logdetFunctor_(NDArray* input, NDArray* output) {
int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
std::unique_ptr<NDArray> tempOutput(input->dup());
int res = cholesky_<T>(input, tempOutput.get(), false);
int res = cholesky_<T>(context, input, tempOutput.get(), false);
if (res != ND4J_STATUS_OK)
return res;
auto n = input->sizeAt(-1);
@ -372,7 +363,7 @@ template <typename T>
}
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
}
}

View File

@ -907,6 +907,8 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
/*** max ***/
case 0: {
coord2 = hstart;
coord3 = hend;
T max = -DataTypeUtils::max<T>();
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {

View File

@ -31,8 +31,6 @@
namespace nd4j {
namespace ops {
namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
// template <typename T>
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
// if (theFirst != theSecond) {
@ -198,36 +196,33 @@ namespace helpers {
}
template<typename T>
static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return;
auto stream = defaultContext->getCudaStream();
auto stream = context->getCudaStream();
// invert main diagonal
upvertKernel<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invert the second diagonal
invertKernelLow<T> << < 1, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<< n, n, 512, *stream >> >
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
}
template<typename T>
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
auto stream = defaultContext->getCudaStream();
auto stream = context->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return;
}
@ -237,13 +232,12 @@ namespace helpers {
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertedMatrix->tickWriteDevice();
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
}
@ -392,7 +386,6 @@ namespace helpers {
auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr;
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
defaultContext = context;
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status);
}
@ -528,24 +521,19 @@ namespace helpers {
input->tickWriteDevice();
}
BUILD_SINGLE_TEMPLATE(template void lup_,
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
FLOAT_NATIVE);
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
template<typename T>
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
// DataType dtype = input->dataType();
// if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32;
defaultContext = context;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
defaultContext); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
@ -554,8 +542,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -578,7 +565,6 @@ namespace helpers {
}
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input});
@ -586,19 +572,16 @@ namespace helpers {
template<typename T>
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, input->rankOf() - 1});
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
defaultContext); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
@ -607,8 +590,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -620,8 +602,7 @@ namespace helpers {
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
// if (matrix.dataType() == input->dataType())
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
(inputBuf, outputBuf, n);
determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
// else
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
}
@ -633,7 +614,6 @@ namespace helpers {
}
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input});
@ -696,17 +676,16 @@ namespace helpers {
template<typename T>
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
auto n = input->sizeAt(-1);
auto n2 = n * n;
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
// if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2,
input->rankOf() - 1});
@ -716,20 +695,17 @@ namespace helpers {
auto stream = context->getCudaStream();
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
i * n2, n);
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
matrix.tickWriteDevice();
compound.assign(matrix);
lup_<T>(context, &compound, nullptr, nullptr);
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
matrix.assign(0);
invertUpperMatrix(&upper, &matrix); // U^{-1}
invertUpperMatrix(context, &upper, &matrix); // U^{-1}
matrix.tickWriteDevice();
// matrix.printIndexedBuffer("Upper Inverted");
compound.assign(0);
invertLowerMatrix(&lower, &compound); // L{-1}
invertLowerMatrix(context, &lower, &compound); // L{-1}
compound.tickWriteDevice();
// compound.printIndexedBuffer("Lower Inverted");
// matrix.tickWriteDevice();
@ -737,15 +713,12 @@ namespace helpers {
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
upper.tickWriteDevice();
// upper.printIndexedBuffer("Full inverted");
returnMatrix<T> << < 1, n2, 1024, *stream >> >
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
i * n2, n);
returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
}
return Status::OK();
}
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input});
@ -788,7 +761,6 @@ namespace helpers {
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
if (!inplace)
output->assign(input);
defaultContext = context;
std::unique_ptr<NDArray> tempOutput(output->dup());
cusolverDnHandle_t handle = nullptr;
auto n = input->sizeAt(-1);
@ -868,7 +840,6 @@ namespace helpers {
// template <typename T>
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input});
if (input->dataType() == DataType::DOUBLE)
cholesky__<double>(context, input, output, inplace);
@ -876,8 +847,7 @@ namespace helpers {
cholesky__<float>(context, input, output, inplace);
else {
std::unique_ptr<NDArray> tempOutput(
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
defaultContext));
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get());
@ -888,7 +858,6 @@ namespace helpers {
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
defaultContext = context;
return cholesky_(context, input, output, inplace);
}
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
@ -927,7 +896,6 @@ namespace helpers {
template<typename T>
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input});
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream();
@ -957,7 +925,6 @@ namespace helpers {
}
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
}

View File

@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <NDArray.h>
#include <NDArrayFactory.h>
#include "testlayers.h"
#include <graph/Stash.h>
#include <FlatUtils.h>
using namespace nd4j;
class FlatUtilsTests : public testing::Test {
public:
};
TEST_F(FlatUtilsTests, flat_float_serde_1) {
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_int_serde_1) {
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_string_serde_1) {
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}

View File

@ -24,7 +24,6 @@
#include "testlayers.h"
#include <graph/Stash.h>
using namespace nd4j;
using namespace nd4j;
class StringTests : public testing::Test {

View File

@ -31,10 +31,35 @@
<build>
<plugins>
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
change will run into a compilation error. This can be removed after a few weeks.-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-antrun-plugin</artifactId>
<version>1.8</version>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<target>
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
</target>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>3.5.1.1</version>
<version>3.8.0</version>
<executions>
<execution>
<id>tensorflow</id>
@ -43,30 +68,14 @@
<goal>run</goal>
</goals>
<configuration>
<type>java-shaded</type>
<protocVersion>3.5.1</protocVersion>
<protocVersion>3.8.0</protocVersion>
<extension>.proto</extension>
<includeDirectories>
<include>src/main/protobuf/tf</include>
<include>src/main/protobuf/onnx</include>
</includeDirectories>
<inputDirectories>
<include>src/main/protobuf/tf/tensorflow</include>
</inputDirectories>
<addSources>main</addSources>
<cleanOutputFolder>false</cleanOutputFolder>
<outputDirectory>src/main/java/</outputDirectory>
</configuration>
</execution>
<execution>
<id>onnx</id>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<type>java-shaded</type>
<extension>.proto3</extension>
<protocVersion>3.5.1</protocVersion>
<inputDirectories>
<include>src/main/protobuf/onnx</include>
</inputDirectories>
<addSources>main</addSources>
@ -76,6 +85,32 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.google.code.maven-replacer-plugin</groupId>
<artifactId>replacer</artifactId>
<version>1.5.3</version>
<configuration>
<includes>
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
<include>${project.build.sourceDirectory}/tensorflow/**</include>
<include>${project.build.sourceDirectory}/onnx/**</include>
</includes>
<token>com.google.protobuf.</token>
<value>org.nd4j.shade.protobuf.</value>
</configuration>
<executions>
<execution>
<id>replace-imports</id>
<phase>generate-sources</phase>
<goals>
<goal>replace</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
@ -148,20 +183,15 @@
<version>${flatbuffers.version}</version>
</dependency>
<!-- Note that this is shaded flatbuffers, see the protoc declaration above
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging
their own older protobuf versions-->
<!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
their own older (incompatible) protobuf versions-->
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-java-shaded-351</artifactId>
<version>0.9</version>
</dependency>
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-java-util-shaded-351</artifactId>
<version>0.9</version>
<groupId>org.nd4j</groupId>
<artifactId>protobuf</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId>

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
/**
* Initialize the function from the given
* {@link onnx.OnnxProto3.NodeProto}
* {@link onnx.Onnx.NodeProto}
* @param node
*/
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
this.sameDiff = sameDiff;
setInstanceId();
initFromOnnx(node, sameDiff, attributesForNode, graph);
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
/**
* Iniitialize the function from the given
* {@link onnx.OnnxProto3.NodeProto}
* {@link onnx.Onnx.NodeProto}
* @param node
* @param initWith
* @param attributesForNode
* @param graph
*/
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph);
public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);

View File

@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
import java.util.Objects;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.descriptors.tensorflow;
import com.github.os72.protobuf351.TextFormat;
import org.nd4j.shade.protobuf.TextFormat;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.io.ClassPathResource;
import org.tensorflow.framework.OpDef;

View File

@ -16,8 +16,8 @@
package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message;
import com.github.os72.protobuf351.TextFormat;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;

View File

@ -16,13 +16,13 @@
package org.nd4j.imports.graphmapper.onnx;
import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -52,7 +52,7 @@ import java.util.*;
*
* @author Adam Gibson
*/
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> {
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile);
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString() + "\n");
}
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param node
* @param graph
*/
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
public boolean isOpIgnoreException(Onnx.NodeProto node) {
return false;
}
@Override
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
return function.opName();
}
@Override
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
/**
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i);
if(node.getName().equals(name))
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
return false;
}
@Override
public List<String> getControlDependencies(OnnxProto3.NodeProto node) {
public List<String> getControlDependencies(Onnx.NodeProto node) {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString());
}
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
/**
* Need to figure out why
* gpu_0/conv1_1 isn't present in VGG
*/
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
for(int i = 0; i < graphProto.getInputCount(); i++) {
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
}
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
return null;
}
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
newBuilder()
.setDimValue(-1)
.build();
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder()
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
.setShape(
OnnxProto3.TensorShapeProto.newBuilder()
Onnx.TensorShapeProto.newBuilder()
.addDim(dim)
.addDim(dim).build())
.build();
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public Message.Builder getNewGraphBuilder() {
return OnnxProto3.GraphProto.newBuilder();
return Onnx.GraphProto.newBuilder();
}
@Override
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState,
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride,
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) {
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
if(differentialFunction == null) {
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) {
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
return nd4jTypeFromOnnxType(tensorProto.getElemType());
}
@Override
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) {
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING;
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
}
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param dataType the data type to convert
* @return the nd4j type for the onnx type
*/
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
switch (dataType) {
case DOUBLE: return DataType.DOUBLE;
case FLOAT: return DataType.FLOAT;
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
if(attributeProto.getName().equals(key)) {
return attributeProto.getS().toString();
}
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
return Longs.toArray(attributeProto.getT().getDimsList());
}
@Override
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
DataType type = dataTypeForTensor(tensorProto, 0);
if(!tensorProto.isInitialized()) {
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
}
OnnxProto3.TensorProto tensor = null;
Onnx.TensorProto tensor = null;
for(int i = 0; i < graph.getInitializerCount(); i++) {
val initializer = graph.getInitializer(i);
if(initializer.getName().equals(tensorName)) {
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
return arr;
}
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
if(tensor == null)
return null;
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) {
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
int dimCount = tensorProto.getShape().getDimCount();
if(dimCount >= 2)
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
/**
* Get the shape from a tensor proto.
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)}
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
* @param tensorProto the tensor to get the shape from
* @return
*/
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
int dimCount = tensorProto.getDimsCount();
if(dimCount >= 2)
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
public String getInputFromNode(Onnx.NodeProto node, int index) {
return node.getInput(index);
}
@Override
public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
public int numInputsFor(Onnx.NodeProto nodeProto) {
return nodeProto.getInputCount();
}
@Override
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
return Longs.toArray(attr.getT().getDimsList());
}
@Override
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>();
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i);
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
proto.put(attributeProto.getName(),attributeProto);
}
return proto;
}
@Override
public String getName(OnnxProto3.NodeProto nodeProto) {
public String getName(Onnx.NodeProto nodeProto) {
return nodeProto.getName();
}
@Override
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType().contains("Var");
}
@Override
public boolean shouldSkip(OnnxProto3.NodeProto opType) {
public boolean shouldSkip(Onnx.NodeProto opType) {
return false;
}
@Override
public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
public boolean hasShape(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public long[] getShape(OnnxProto3.NodeProto nodeProto) {
public long[] getShape(Onnx.NodeProto nodeProto) {
return null;
}
@Override
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
return null;
}
@Override
public String getOpType(OnnxProto3.NodeProto nodeProto) {
public String getOpType(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType();
}
@Override
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
return graphProto.getNodeList();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j;

View File

@ -1,6 +1,6 @@
package org.nd4j.imports.graphmapper.tf.tensors;
import com.github.os72.protobuf351.Descriptors;
import org.nd4j.shade.protobuf.Descriptors;
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.linalg.api.buffer.DataType;

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -20,7 +20,7 @@ import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("axes")) {
this.dimensions = new int[] { Integer.MAX_VALUE };
}

View File

@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val dilations = attributesForNode.get("dilations");
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val aAlpha = attributesForNode.get("alpha");
val aBeta = attributesForNode.get("beta");
val aBias = attributesForNode.get("bias");

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val isSameNode = paddingVal.equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = attributesForNode.get("pads").getIntsList();

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.EqualsAndHashCode;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val shape = new OnnxGraphMapper().getShape(node);
this.shape = shape;
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxMlProto3;
import onnx.OnnxMl;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import com.google.common.primitives.Ints;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("perm")) {
} else

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
//No op
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl;
import lombok.NonNull;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -9,7 +9,7 @@
syntax = "proto3";
package onnx;
import "onnx.proto3";
import "onnx.proto";
//
// This file contains the proto definitions for OperatorSetProto and

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Ignore;
import org.junit.Rule;

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Ignore;
import org.junit.Test;

View File

@ -732,4 +732,20 @@ public class CustomOpsTests extends BaseNd4jTest {
fail("Failed datatypes: " + failed.toString());
}
}
@Test
public void testMaxPool2Dbp_1() {
val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN);
val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN);
val z = Nd4j.create(DataType.HALF, 2,3,16,16);
val op = DynamicCustomOp.builder("maxpool2d_bp")
.addInputs(x, y)
.addOutputs(z)
.addIntegerArguments(2, 2, 2, 2, 8,8, 1,1,1, 0,0)
.build();
Nd4j.exec(op);
Nd4j.getExecutioner().commit();
}
}

View File

@ -29,6 +29,7 @@
<packaging>pom</packaging>
<modules>
<module>jackson</module>
<module>protobuf</module>
</modules>
<properties>

View File

@ -0,0 +1,228 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nd4j-shade</artifactId>
<groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>protobuf</artifactId>
<properties>
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.8.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.8.0</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>custom-lifecycle</id>
<activation>
<property><name>!skip.custom.lifecycle</name></property>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.portals.jetspeed-2</groupId>
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
<version>2.3.1</version>
<executions>
<execution>
<id>compile-and-pack</id>
<phase>compile</phase>
<goals>
<goal>mvn</goal>
</goals>
</execution>
</executions>
<dependencies>
<dependency>
<groupId>org.apache.maven.shared</groupId>
<artifactId>maven-invoker</artifactId>
<version>2.2</version>
</dependency>
</dependencies>
<configuration>
<targets combine.children="merge">
<target>
<id>create-shaded-jars</id>
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
<goals>clean,compile,package</goals>
<properties>
<skip.custom.lifecycle>true</skip.custom.lifecycle>
</properties>
</target>
</targets>
<defaultTarget>create-shaded-jars</defaultTarget>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<plugins>
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
<plugin>
<groupId>com.lewisd</groupId>
<artifactId>lint-maven-plugin</artifactId>
<version>0.0.11</version>
<executions>
<execution>
<id>pom-lint</id>
<phase>none</phase>
</execution>
</executions>
</plugin>
<!--
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
including this module (org.nd4j.protobuf) as a dependency.
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
<configuration>
<!--
Important configuration options here:
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
original dependencies will be shaded, AND still included as transitive deps
in the final POM. This is not what we want.
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
the original un-shaded JAR being separate. With this being set to false,
the original JAR will be modified, and no extra jar will be produced.
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
to direct dependencies. Without this, we need to manually manage the transitive
dependencies of the shaded artifacts.
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
-->
<shadedArtifactAttached>false</shadedArtifactAttached>
<createDependencyReducedPom>true</createDependencyReducedPom>
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
<artifactSet>
<includes>
<include>com.google.protobuf:*</include>
<include>com.google.protobuf.*:*</include>
</includes>
</artifactSet>
<relocations>
<!-- Protobuf dependencies -->
<relocation>
<pattern>com.google.protobuf</pattern>
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
</relocation>
</relocations>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<forceCreation>true</forceCreation>
</configuration>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
<execution>
<id>empty-sources-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>sources</classifier>
<classesDirectory>${basedir}/src</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>unpack</id>
<phase>package</phase>
<goals>
<goal>unpack</goal>
</goals>
<configuration>
<artifactItems>
<artifactItem>
<groupId>org.nd4j</groupId>
<artifactId>protobuf</artifactId>
<version>${project.version}</version>
<type>jar</type>
<overWrite>false</overWrite>
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
<includes>**/*.class,**/*.xml</includes>
</artifactItem>
</artifactItems>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer;

View File

@ -16,9 +16,9 @@
package org.nd4j.tensorflow.conversion.graphrunner;
import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.InvalidProtocolBufferException;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.util.JsonFormat;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
/**
* Convert a json string written out
* by {@link com.github.os72.protobuf351.util.JsonFormat}
* by {@link org.nd4j.shade.protobuf.util.JsonFormat}
* to a {@link org.bytedeco.tensorflow.ConfigProto}
* @param json the json to read
* @return the config proto to use