[WIP] latest update (#8145)
* [WIP] maxpool2d_bp fix (#160) * one test for maxpool2d_bp Signed-off-by: raver119 <raver119@gmail.com> * - maxpool2d_bp cuda fix for NaNs - streamSync after each custom op execution Signed-off-by: raver119 <raver119@gmail.com> * MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade protobuf version (#162) * First steps for protobuf version upgrade Signed-off-by: AlexDBlack <blacka101@gmail.com> * Phase 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update imports to shaded protobuf Signed-off-by: AlexDBlack <blacka101@gmail.com> * Version fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch to single execution for protobuf codegen to work around plugin bug Signed-off-by: AlexDBlack <blacka101@gmail.com> * Automatically delete old PB generated files after name change Signed-off-by: Alex Black <blacka101@gmail.com> * - string NDArray flat serde impl + tests (#163) - string NDArray equalsTo impl Signed-off-by: raver119 <raver119@gmail.com> * get rid of context variable Signed-off-by: raver119 <raver119@gmail.com> * lup context fix (#164) Signed-off-by: raver119 <raver119@gmail.com>master
parent
95b2686ce5
commit
d871eab2e5
|
@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
||||||
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
||||||
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = 0; i <= stopIndex; i++) {
|
for (int i = 0; i <= stopIndex; i++) {
|
||||||
GraphVertex current = vertices[topologicalOrder[i]];
|
GraphVertex current = vertices[topologicalOrder[i]];
|
||||||
|
@ -2436,6 +2437,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||||
//Though if stopIndex < numLayers, some might still be open
|
//Though if stopIndex < numLayers, some might still be open
|
||||||
|
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
||||||
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
||||||
// number of times to actually close it, in all cases
|
// number of times to actually close it, in all cases
|
||||||
|
try{
|
||||||
ws.close();
|
ws.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
|
@ -2581,6 +2592,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
||||||
boolean hitFrozen = false;
|
boolean hitFrozen = false;
|
||||||
|
@ -2732,8 +2744,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
||||||
g.flatteningOrderForVariable(origName)));
|
g.flatteningOrderForVariable(origName)));
|
||||||
}
|
}
|
||||||
for (Triple<String, INDArray, Character> t : tempList)
|
for (Triple<String, INDArray, Character> triple : tempList)
|
||||||
gradients.addFirst(t);
|
gradients.addFirst(triple);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Close any activation gradient workspaces that we no longer require
|
//Close any activation gradient workspaces that we no longer require
|
||||||
|
@ -2752,19 +2764,28 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||||
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
||||||
|
try{
|
||||||
ws.close();
|
ws.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
||||||
Gradient gradient = new DefaultGradient(flattenedGradients);
|
Gradient gradient = new DefaultGradient(flattenedGradients);
|
||||||
for (Triple<String, INDArray, Character> t : gradients) {
|
for (Triple<String, INDArray, Character> tr : gradients) {
|
||||||
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
|
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
|
||||||
}
|
}
|
||||||
|
|
||||||
this.gradient = gradient;
|
this.gradient = gradient;
|
||||||
|
|
|
@ -1242,6 +1242,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = 0; i <= layerIndex; i++) {
|
for (int i = 0; i <= layerIndex; i++) {
|
||||||
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
||||||
|
@ -1328,17 +1329,34 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
if(wsActCloseNext != null){
|
if(wsActCloseNext != null){
|
||||||
|
try {
|
||||||
wsActCloseNext.close();
|
wsActCloseNext.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(temp != null){
|
if(temp != null){
|
||||||
//Should only be non-null on exception
|
//Should only be non-null on exception
|
||||||
while(temp.isScopeActive()){
|
while(temp.isScopeActive()){
|
||||||
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
||||||
// workspace scope was borrowed and not properly closed when exception occurred
|
// workspace scope was borrowed and not properly closed when exception occurred
|
||||||
|
try{
|
||||||
temp.close();
|
temp.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1871,6 +1889,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = layers.length - 1; i >= 0; i--) {
|
for (int i = layers.length - 1; i >= 0; i--) {
|
||||||
if (layers[i] instanceof FrozenLayer) {
|
if (layers[i] instanceof FrozenLayer) {
|
||||||
|
@ -1961,13 +1980,31 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable thr ){
|
||||||
|
t = thr;
|
||||||
} finally {
|
} finally {
|
||||||
if(wsActGradCloseNext != null){
|
if(wsActGradCloseNext != null){
|
||||||
|
try {
|
||||||
wsActGradCloseNext.close();
|
wsActGradCloseNext.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(wsActGradTemp != null) {
|
if(wsActGradTemp != null) {
|
||||||
//Should only be non-null on exception
|
//Should only be non-null on exception
|
||||||
|
try {
|
||||||
wsActGradTemp.close();
|
wsActGradTemp.close();
|
||||||
|
} catch (Throwable t2) {
|
||||||
|
if (t != null) {
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
}
|
}
|
||||||
|
|
|
@ -476,20 +476,37 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::vector<int8_t> NDArray::asByteVector() {
|
std::vector<int8_t> NDArray::asByteVector() {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if (isS()) {
|
||||||
|
// string data type requires special treatment
|
||||||
|
syncToHost();
|
||||||
|
auto numWords = this->lengthOf();
|
||||||
|
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
|
||||||
|
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
|
||||||
|
auto dataLength = offsetsBuffer[numWords];
|
||||||
|
std::vector<int8_t> result(headerLength + dataLength);
|
||||||
|
|
||||||
|
memcpy(result.data(), getBuffer(), headerLength + dataLength);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
// all other types are linear
|
||||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||||
|
|
||||||
if (this->isView()) {
|
if (this->isView()) {
|
||||||
auto tmp = this->dup(this->ordering());
|
auto tmp = this->dup(this->ordering());
|
||||||
|
syncToHost();
|
||||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||||
|
|
||||||
delete tmp;
|
delete tmp;
|
||||||
}
|
} else {
|
||||||
else {
|
syncToHost();
|
||||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::linspace(const double start) {
|
void NDArray::linspace(const double start) {
|
||||||
|
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* NDArray::bufferAsT() const {
|
T* NDArray::bufferAsT() const {
|
||||||
if (isS())
|
// FIXME: do we REALLY want sync here?
|
||||||
throw std::runtime_error("You can't use this method on String array");
|
|
||||||
|
|
||||||
syncToHost();
|
syncToHost();
|
||||||
|
|
||||||
return reinterpret_cast<T*>(getBuffer());
|
return reinterpret_cast<T*>(getBuffer());
|
||||||
|
@ -3202,12 +3217,30 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
if (isS()) {
|
||||||
|
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||||
|
for (int e = 0; e < this->lengthOf(); e++) {
|
||||||
|
auto s1 = this->e<std::string>(e);
|
||||||
|
auto s2 = other->e<std::string>(e);
|
||||||
|
|
||||||
|
if (s1 != s2)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// regular numeric types
|
||||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||||
|
|
||||||
ExtraArguments extras({eps});
|
ExtraArguments extras({eps});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
|
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
||||||
|
getSpecialBuffer(), getSpecialShapeInfo(),
|
||||||
|
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
|
||||||
|
other->getShapeInfo(), other->getSpecialBuffer(),
|
||||||
|
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
|
||||||
|
tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||||
|
|
||||||
synchronize("NDArray::equalsTo");
|
synchronize("NDArray::equalsTo");
|
||||||
|
@ -3217,6 +3250,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -54,6 +54,7 @@
|
||||||
#include <graph/ExecutionResult.h>
|
#include <graph/ExecutionResult.h>
|
||||||
#include <exceptions/graph_execution_exception.h>
|
#include <exceptions/graph_execution_exception.h>
|
||||||
#include <exceptions/no_results_exception.h>
|
#include <exceptions/no_results_exception.h>
|
||||||
|
#include <graph/FlatUtils.h>
|
||||||
|
|
||||||
namespace nd4j{
|
namespace nd4j{
|
||||||
namespace graph {
|
namespace graph {
|
||||||
|
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
||||||
NDArray* array = var->getNDArray();
|
auto array = var->getNDArray();
|
||||||
auto byteVector = array->asByteVector();
|
|
||||||
|
|
||||||
auto fBuffer = builder.CreateVector(byteVector);
|
auto fArray = FlatUtils::toFlatArray(builder, *array);
|
||||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
|
||||||
|
|
||||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
|
|
||||||
|
|
||||||
auto fName = builder.CreateString(*(var->getName()));
|
auto fName = builder.CreateString(*(var->getName()));
|
||||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||||
|
|
|
@ -866,9 +866,10 @@ void initializeFunctions(Nd4jPointer *functions) {
|
||||||
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
||||||
Nd4jPointer pointer;
|
Nd4jPointer pointer;
|
||||||
// cudaHostAllocMapped |cudaHostAllocPortable
|
// cudaHostAllocMapped |cudaHostAllocPortable
|
||||||
cudaError_t res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
|
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
pointer = 0L;
|
throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res);
|
||||||
|
|
||||||
return pointer;
|
return pointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -884,7 +885,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
||||||
Nd4jPointer pointer;
|
Nd4jPointer pointer;
|
||||||
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
|
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
pointer = 0L;
|
throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res);
|
||||||
return pointer;
|
return pointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -894,9 +895,9 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
||||||
* @param pointer pointer that'll be freed
|
* @param pointer pointer that'll be freed
|
||||||
*/
|
*/
|
||||||
int freeHost(Nd4jPointer pointer) {
|
int freeHost(Nd4jPointer pointer) {
|
||||||
cudaError_t res = cudaFreeHost(reinterpret_cast<void *>(pointer));
|
auto res = cudaFreeHost(reinterpret_cast<void *>(pointer));
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
pointer = 0L;
|
throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res);
|
||||||
return 1L;
|
return 1L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -907,9 +908,10 @@ int freeHost(Nd4jPointer pointer) {
|
||||||
* @param ptrToDeviceId pointer to deviceId.
|
* @param ptrToDeviceId pointer to deviceId.
|
||||||
*/
|
*/
|
||||||
int freeDevice(Nd4jPointer pointer, int deviceId) {
|
int freeDevice(Nd4jPointer pointer, int deviceId) {
|
||||||
cudaError_t res = cudaFree(reinterpret_cast<void *>(pointer));
|
auto res = cudaFree(reinterpret_cast<void *>(pointer));
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
pointer = 0L;
|
throw nd4j::cuda_exception::build("cudaFree(...) failed", res);
|
||||||
|
|
||||||
return 1L;
|
return 1L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -934,7 +936,7 @@ Nd4jPointer createStream() {
|
||||||
auto stream = new cudaStream_t();
|
auto stream = new cudaStream_t();
|
||||||
auto dZ = cudaStreamCreate(stream);
|
auto dZ = cudaStreamCreate(stream);
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaStreamCreate(...) failed");
|
throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ);
|
||||||
|
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
@ -944,23 +946,21 @@ Nd4jPointer createEvent() {
|
||||||
|
|
||||||
CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t));
|
CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t));
|
||||||
|
|
||||||
cudaError_t dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
|
auto dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaEventCreateWithFlags(...) failed");
|
throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ);
|
||||||
|
|
||||||
|
|
||||||
return nativeEvent;
|
return nativeEvent;
|
||||||
}
|
}
|
||||||
|
|
||||||
int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
|
int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
|
||||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
|
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||||
|
|
||||||
cudaError_t dZ = cudaEventRecord(*pEvent, *pStream);
|
auto dZ = cudaEventRecord(*pEvent, *pStream);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaEventRecord(...) failed");
|
throw nd4j::cuda_exception::build("cudaEventRecord(...) failed", dZ);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -1065,53 +1065,48 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
|
||||||
}
|
}
|
||||||
|
|
||||||
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||||
cudaError_t dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
|
auto dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaMemset(...) failed");
|
throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
auto pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
||||||
|
|
||||||
cudaError_t dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
|
auto dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaMemsetAsync(...) failed");
|
throw nd4j::cuda_exception::build("cudaMemsetAsync(...) failed", dZ);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int destroyEvent(Nd4jPointer event) {
|
int destroyEvent(Nd4jPointer event) {
|
||||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||||
cudaError_t dZ = cudaEventDestroy(*pEvent);
|
auto dZ = cudaEventDestroy(*pEvent);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaEvenDestroy(...) failed");
|
throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int streamSynchronize(Nd4jPointer stream) {
|
int streamSynchronize(Nd4jPointer stream) {
|
||||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
|
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||||
|
|
||||||
cudaError_t dZ = cudaStreamSynchronize(*pStream);
|
auto dZ = cudaStreamSynchronize(*pStream);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaStreamSynchronize(...) failed");
|
throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ);
|
||||||
|
|
||||||
return 1L;
|
return 1L;
|
||||||
}
|
}
|
||||||
|
|
||||||
int eventSynchronize(Nd4jPointer event) {
|
int eventSynchronize(Nd4jPointer event) {
|
||||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||||
|
|
||||||
cudaError_t dZ = cudaEventSynchronize(*pEvent);
|
auto dZ = cudaEventSynchronize(*pEvent);
|
||||||
checkCudaErrors(dZ);
|
|
||||||
if (dZ != 0)
|
if (dZ != 0)
|
||||||
throw std::runtime_error("cudaEventSynchronize(...) failed");
|
throw nd4j::cuda_exception::build("cudaEventSynchronize(...) failed", dZ);
|
||||||
|
|
||||||
return 1L;
|
return 1L;
|
||||||
}
|
}
|
||||||
|
@ -2697,13 +2692,16 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte
|
||||||
|
|
||||||
auto result = op->execute(context);
|
auto result = op->execute(context);
|
||||||
|
|
||||||
// FIXME: remove once CUDA backend is 100% ready
|
auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream());
|
||||||
|
if (res != 0)
|
||||||
|
throw nd4j::cuda_exception::build("customOp execution failed", res);
|
||||||
|
|
||||||
for (auto v:context->fastpath_in()) {
|
for (auto v:context->fastpath_in()) {
|
||||||
v->makeBothActual();
|
v->syncToDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto v:context->fastpath_out()) {
|
for (auto v:context->fastpath_out()) {
|
||||||
v->makeBothActual();
|
v->syncToDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -36,6 +36,8 @@ namespace nd4j {
|
||||||
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
||||||
|
|
||||||
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
||||||
|
|
||||||
|
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,5 +102,16 @@ namespace nd4j {
|
||||||
delete[] newShape;
|
delete[] newShape;
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
|
||||||
|
auto byteVector = array.asByteVector();
|
||||||
|
|
||||||
|
auto fBuffer = builder.CreateVector(byteVector);
|
||||||
|
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
|
||||||
|
|
||||||
|
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||||
|
|
||||||
|
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -26,7 +26,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
|
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
|
||||||
|
@ -108,14 +107,14 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
|
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||||
|
|
||||||
const int rowNum = input->rows();
|
const int rowNum = input->rows();
|
||||||
const int columnNum = input->columns();
|
const int columnNum = input->columns();
|
||||||
|
|
||||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
||||||
NDArray compoundMatrix = *input; // copy
|
NDArray compoundMatrix = *input; // copy
|
||||||
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
|
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
|
||||||
permutationMatrix.setIdentity();
|
permutationMatrix.setIdentity();
|
||||||
|
|
||||||
T pivotValue; // = T(0.0);
|
T pivotValue; // = T(0.0);
|
||||||
|
@ -161,46 +160,43 @@ namespace helpers {
|
||||||
return determinant;
|
return determinant;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int determinant_(NDArray* input, NDArray* output) {
|
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
|
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||||
|
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
||||||
matrix.p(row, input->e<T>(k));
|
matrix.p(row, input->e<T>(k));
|
||||||
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
defaultContext = context;
|
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int logAbsDeterminant_(NDArray* input, NDArray* output) {
|
int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
|
|
||||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
||||||
matrix.p(row, input->e<T>(k));
|
matrix.p(row, input->e<T>(k));
|
||||||
}
|
}
|
||||||
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||||
if (det.e<T>(0) != 0.f)
|
if (det.e<T>(0) != 0.f)
|
||||||
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
|
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
|
||||||
}
|
}
|
||||||
|
@ -208,25 +204,23 @@ template <typename T>
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
|
||||||
|
|
||||||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int inverse_(NDArray* input, NDArray* output) {
|
static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto n2 = n * n;
|
auto n2 = n * n;
|
||||||
auto totalCount = output->lengthOf() / n2;
|
auto totalCount = output->lengthOf() / n2;
|
||||||
|
|
||||||
output->assign(0.f); // fill up output tensor with zeros
|
output->assign(0.f); // fill up output tensor with zeros
|
||||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||||
|
|
||||||
for (int e = 0; e < totalCount; e++) {
|
for (int e = 0; e < totalCount; e++) {
|
||||||
if (e)
|
if (e)
|
||||||
|
@ -235,7 +229,7 @@ template <typename T>
|
||||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||||
matrix.p(row++, input->e<T>(k));
|
matrix.p(row++, input->e<T>(k));
|
||||||
}
|
}
|
||||||
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
|
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
|
||||||
|
|
||||||
// FIXME: and how this is going to work on float16?
|
// FIXME: and how this is going to work on float16?
|
||||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||||
|
@ -268,8 +262,7 @@ template <typename T>
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
defaultContext = context;
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -296,14 +289,13 @@ template <typename T>
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
|
|
||||||
|
|
||||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
|
int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
|
||||||
|
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto n2 = n * n;
|
auto n2 = n * n;
|
||||||
|
@ -311,8 +303,8 @@ template <typename T>
|
||||||
if (!inplace)
|
if (!inplace)
|
||||||
output->assign(0.f); // fill up output tensor with zeros only inplace=false
|
output->assign(0.f); // fill up output tensor with zeros only inplace=false
|
||||||
|
|
||||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
|
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
|
||||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
|
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
|
||||||
|
|
||||||
for (int e = 0; e < totalCount; e++) {
|
for (int e = 0; e < totalCount; e++) {
|
||||||
|
|
||||||
|
@ -346,14 +338,13 @@ template <typename T>
|
||||||
}
|
}
|
||||||
|
|
||||||
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
|
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
|
||||||
defaultContext = context;
|
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int logdetFunctor_(NDArray* input, NDArray* output) {
|
int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||||
std::unique_ptr<NDArray> tempOutput(input->dup());
|
std::unique_ptr<NDArray> tempOutput(input->dup());
|
||||||
int res = cholesky_<T>(input, tempOutput.get(), false);
|
int res = cholesky_<T>(context, input, tempOutput.get(), false);
|
||||||
if (res != ND4J_STATUS_OK)
|
if (res != ND4J_STATUS_OK)
|
||||||
return res;
|
return res;
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
|
@ -372,7 +363,7 @@ template <typename T>
|
||||||
}
|
}
|
||||||
|
|
||||||
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -907,6 +907,8 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
|
||||||
|
|
||||||
/*** max ***/
|
/*** max ***/
|
||||||
case 0: {
|
case 0: {
|
||||||
|
coord2 = hstart;
|
||||||
|
coord3 = hend;
|
||||||
|
|
||||||
T max = -DataTypeUtils::max<T>();
|
T max = -DataTypeUtils::max<T>();
|
||||||
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
|
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
|
||||||
|
|
|
@ -31,8 +31,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
|
||||||
|
|
||||||
// template <typename T>
|
// template <typename T>
|
||||||
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||||
// if (theFirst != theSecond) {
|
// if (theFirst != theSecond) {
|
||||||
|
@ -198,36 +196,33 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||||
int n = inputMatrix->rows();
|
int n = inputMatrix->rows();
|
||||||
invertedMatrix->setIdentity();
|
invertedMatrix->setIdentity();
|
||||||
|
|
||||||
if (inputMatrix->isIdentityMatrix()) return;
|
if (inputMatrix->isIdentityMatrix()) return;
|
||||||
|
|
||||||
auto stream = defaultContext->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
// invert main diagonal
|
// invert main diagonal
|
||||||
upvertKernel<T> << < 1, n, 512, *stream >> >
|
upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
|
||||||
// invert the second diagonal
|
// invert the second diagonal
|
||||||
invertKernelLow<T> << < 1, n, 512, *stream >> >
|
invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
|
||||||
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
invertLowKernel<T><<< n, n, 512, *stream >> >
|
invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
|
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
int n = inputMatrix->rows();
|
int n = inputMatrix->rows();
|
||||||
invertedMatrix->setIdentity();
|
invertedMatrix->setIdentity();
|
||||||
auto stream = defaultContext->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -237,13 +232,12 @@ namespace helpers {
|
||||||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
invertedMatrix->tickWriteDevice();
|
invertedMatrix->tickWriteDevice();
|
||||||
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
|
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
|
||||||
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
|
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||||
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -392,7 +386,6 @@ namespace helpers {
|
||||||
auto n = input->rows();
|
auto n = input->rows();
|
||||||
cusolverDnHandle_t cusolverH = nullptr;
|
cusolverDnHandle_t cusolverH = nullptr;
|
||||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||||
defaultContext = context;
|
|
||||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||||
}
|
}
|
||||||
|
@ -528,24 +521,19 @@ namespace helpers {
|
||||||
input->tickWriteDevice();
|
input->tickWriteDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void lup_,
|
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
|
||||||
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
|
|
||||||
FLOAT_NATIVE);
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
std::vector<int> dims();
|
std::vector<int> dims();
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||||
{input->rankOf() - 2, input->rankOf() - 1});
|
|
||||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||||
// DataType dtype = input->dataType();
|
// DataType dtype = input->dataType();
|
||||||
// if (dtype != DataType::DOUBLE)
|
// if (dtype != DataType::DOUBLE)
|
||||||
// dtype = DataType::FLOAT32;
|
// dtype = DataType::FLOAT32;
|
||||||
defaultContext = context;
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
|
||||||
defaultContext); //, block.getWorkspace());
|
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
@ -554,8 +542,7 @@ namespace helpers {
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
Nd4jLong pos = e * n2;
|
Nd4jLong pos = e * n2;
|
||||||
// if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
|
||||||
// else
|
// else
|
||||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
|
@ -578,7 +565,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -586,19 +572,16 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
std::vector<int> dims();
|
std::vector<int> dims();
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||||
{input->rankOf() - 2, input->rankOf() - 1});
|
|
||||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||||
DataType dtype = input->dataType();
|
DataType dtype = input->dataType();
|
||||||
if (dtype != DataType::DOUBLE)
|
if (dtype != DataType::DOUBLE)
|
||||||
dtype = DataType::FLOAT32;
|
dtype = DataType::FLOAT32;
|
||||||
|
|
||||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
|
||||||
defaultContext); //, block.getWorkspace());
|
|
||||||
auto det = NDArrayFactory::create<T>(1);
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
@ -607,8 +590,7 @@ namespace helpers {
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
Nd4jLong pos = e * n2;
|
Nd4jLong pos = e * n2;
|
||||||
// if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
|
||||||
// else
|
// else
|
||||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
|
@ -620,8 +602,7 @@ namespace helpers {
|
||||||
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
|
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
|
||||||
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
|
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
|
||||||
// if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
|
||||||
(inputBuf, outputBuf, n);
|
|
||||||
// else
|
// else
|
||||||
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
}
|
}
|
||||||
|
@ -633,7 +614,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -696,17 +676,16 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto n2 = n * n;
|
auto n2 = n * n;
|
||||||
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
||||||
// if (dtype != DataType::DOUBLE)
|
// if (dtype != DataType::DOUBLE)
|
||||||
// dtype = DataType::FLOAT32;
|
// dtype = DataType::FLOAT32;
|
||||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||||
{input->rankOf() - 2,
|
{input->rankOf() - 2,
|
||||||
input->rankOf() - 1});
|
input->rankOf() - 1});
|
||||||
|
@ -716,20 +695,17 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||||
fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
|
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
|
|
||||||
i * n2, n);
|
|
||||||
matrix.tickWriteDevice();
|
matrix.tickWriteDevice();
|
||||||
compound.assign(matrix);
|
compound.assign(matrix);
|
||||||
lup_<T>(context, &compound, nullptr, nullptr);
|
lup_<T>(context, &compound, nullptr, nullptr);
|
||||||
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
|
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||||
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
|
||||||
matrix.assign(0);
|
matrix.assign(0);
|
||||||
invertUpperMatrix(&upper, &matrix); // U^{-1}
|
invertUpperMatrix(context, &upper, &matrix); // U^{-1}
|
||||||
matrix.tickWriteDevice();
|
matrix.tickWriteDevice();
|
||||||
// matrix.printIndexedBuffer("Upper Inverted");
|
// matrix.printIndexedBuffer("Upper Inverted");
|
||||||
compound.assign(0);
|
compound.assign(0);
|
||||||
invertLowerMatrix(&lower, &compound); // L{-1}
|
invertLowerMatrix(context, &lower, &compound); // L{-1}
|
||||||
compound.tickWriteDevice();
|
compound.tickWriteDevice();
|
||||||
// compound.printIndexedBuffer("Lower Inverted");
|
// compound.printIndexedBuffer("Lower Inverted");
|
||||||
// matrix.tickWriteDevice();
|
// matrix.tickWriteDevice();
|
||||||
|
@ -737,15 +713,12 @@ namespace helpers {
|
||||||
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
||||||
upper.tickWriteDevice();
|
upper.tickWriteDevice();
|
||||||
// upper.printIndexedBuffer("Full inverted");
|
// upper.printIndexedBuffer("Full inverted");
|
||||||
returnMatrix<T> << < 1, n2, 1024, *stream >> >
|
returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
|
||||||
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
|
|
||||||
i * n2, n);
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
@ -788,7 +761,6 @@ namespace helpers {
|
||||||
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
if (!inplace)
|
if (!inplace)
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
defaultContext = context;
|
|
||||||
std::unique_ptr<NDArray> tempOutput(output->dup());
|
std::unique_ptr<NDArray> tempOutput(output->dup());
|
||||||
cusolverDnHandle_t handle = nullptr;
|
cusolverDnHandle_t handle = nullptr;
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
|
@ -868,7 +840,6 @@ namespace helpers {
|
||||||
|
|
||||||
// template <typename T>
|
// template <typename T>
|
||||||
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
if (input->dataType() == DataType::DOUBLE)
|
if (input->dataType() == DataType::DOUBLE)
|
||||||
cholesky__<double>(context, input, output, inplace);
|
cholesky__<double>(context, input, output, inplace);
|
||||||
|
@ -876,8 +847,7 @@ namespace helpers {
|
||||||
cholesky__<float>(context, input, output, inplace);
|
cholesky__<float>(context, input, output, inplace);
|
||||||
else {
|
else {
|
||||||
std::unique_ptr<NDArray> tempOutput(
|
std::unique_ptr<NDArray> tempOutput(
|
||||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
|
||||||
defaultContext));
|
|
||||||
tempOutput->assign(input);
|
tempOutput->assign(input);
|
||||||
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||||
output->assign(tempOutput.get());
|
output->assign(tempOutput.get());
|
||||||
|
@ -888,7 +858,6 @@ namespace helpers {
|
||||||
|
|
||||||
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||||
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||||
defaultContext = context;
|
|
||||||
return cholesky_(context, input, output, inplace);
|
return cholesky_(context, input, output, inplace);
|
||||||
}
|
}
|
||||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||||
|
@ -927,7 +896,6 @@ namespace helpers {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
@ -957,7 +925,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||||
defaultContext = context;
|
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Stash.h>
|
||||||
|
#include <FlatUtils.h>
|
||||||
|
|
||||||
|
using namespace nd4j;
|
||||||
|
|
||||||
|
class FlatUtilsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_float_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_int_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||||
|
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
|
@ -24,7 +24,6 @@
|
||||||
#include "testlayers.h"
|
#include "testlayers.h"
|
||||||
#include <graph/Stash.h>
|
#include <graph/Stash.h>
|
||||||
|
|
||||||
using namespace nd4j;
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
|
||||||
class StringTests : public testing::Test {
|
class StringTests : public testing::Test {
|
||||||
|
|
|
@ -31,10 +31,35 @@
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|
||||||
|
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
|
||||||
|
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
|
||||||
|
change will run into a compilation error. This can be removed after a few weeks.-->
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-antrun-plugin</artifactId>
|
||||||
|
<version>1.8</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>generate-sources</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>run</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<target>
|
||||||
|
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
|
||||||
|
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
|
||||||
|
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
|
||||||
|
</target>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>com.github.os72</groupId>
|
<groupId>com.github.os72</groupId>
|
||||||
<artifactId>protoc-jar-maven-plugin</artifactId>
|
<artifactId>protoc-jar-maven-plugin</artifactId>
|
||||||
<version>3.5.1.1</version>
|
<version>3.8.0</version>
|
||||||
<executions>
|
<executions>
|
||||||
<execution>
|
<execution>
|
||||||
<id>tensorflow</id>
|
<id>tensorflow</id>
|
||||||
|
@ -43,30 +68,14 @@
|
||||||
<goal>run</goal>
|
<goal>run</goal>
|
||||||
</goals>
|
</goals>
|
||||||
<configuration>
|
<configuration>
|
||||||
<type>java-shaded</type>
|
<protocVersion>3.8.0</protocVersion>
|
||||||
<protocVersion>3.5.1</protocVersion>
|
<extension>.proto</extension>
|
||||||
<includeDirectories>
|
<includeDirectories>
|
||||||
<include>src/main/protobuf/tf</include>
|
<include>src/main/protobuf/tf</include>
|
||||||
|
<include>src/main/protobuf/onnx</include>
|
||||||
</includeDirectories>
|
</includeDirectories>
|
||||||
<inputDirectories>
|
<inputDirectories>
|
||||||
<include>src/main/protobuf/tf/tensorflow</include>
|
<include>src/main/protobuf/tf/tensorflow</include>
|
||||||
</inputDirectories>
|
|
||||||
<addSources>main</addSources>
|
|
||||||
<cleanOutputFolder>false</cleanOutputFolder>
|
|
||||||
<outputDirectory>src/main/java/</outputDirectory>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
<execution>
|
|
||||||
<id>onnx</id>
|
|
||||||
<phase>generate-sources</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>run</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<type>java-shaded</type>
|
|
||||||
<extension>.proto3</extension>
|
|
||||||
<protocVersion>3.5.1</protocVersion>
|
|
||||||
<inputDirectories>
|
|
||||||
<include>src/main/protobuf/onnx</include>
|
<include>src/main/protobuf/onnx</include>
|
||||||
</inputDirectories>
|
</inputDirectories>
|
||||||
<addSources>main</addSources>
|
<addSources>main</addSources>
|
||||||
|
@ -76,6 +85,32 @@
|
||||||
</execution>
|
</execution>
|
||||||
</executions>
|
</executions>
|
||||||
</plugin>
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>com.google.code.maven-replacer-plugin</groupId>
|
||||||
|
<artifactId>replacer</artifactId>
|
||||||
|
<version>1.5.3</version>
|
||||||
|
<configuration>
|
||||||
|
<includes>
|
||||||
|
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
|
||||||
|
<include>${project.build.sourceDirectory}/tensorflow/**</include>
|
||||||
|
<include>${project.build.sourceDirectory}/onnx/**</include>
|
||||||
|
</includes>
|
||||||
|
<token>com.google.protobuf.</token>
|
||||||
|
<value>org.nd4j.shade.protobuf.</value>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>replace-imports</id>
|
||||||
|
<phase>generate-sources</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>replace</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
<artifactId>maven-compiler-plugin</artifactId>
|
||||||
|
@ -148,20 +183,15 @@
|
||||||
<version>${flatbuffers.version}</version>
|
<version>${flatbuffers.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<!-- Note that this is shaded flatbuffers, see the protoc declaration above
|
<!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
|
||||||
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging
|
their own older (incompatible) protobuf versions-->
|
||||||
their own older protobuf versions-->
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.os72</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>protobuf-java-shaded-351</artifactId>
|
<artifactId>protobuf</artifactId>
|
||||||
<version>0.9</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.github.os72</groupId>
|
|
||||||
<artifactId>protobuf-java-util-shaded-351</artifactId>
|
|
||||||
<version>0.9</version>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.objenesis</groupId>
|
<groupId>org.objenesis</groupId>
|
||||||
<artifactId>objenesis</artifactId>
|
<artifactId>objenesis</artifactId>
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
|
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize the function from the given
|
* Initialize the function from the given
|
||||||
* {@link onnx.OnnxProto3.NodeProto}
|
* {@link onnx.Onnx.NodeProto}
|
||||||
* @param node
|
* @param node
|
||||||
*/
|
*/
|
||||||
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
this.sameDiff = sameDiff;
|
this.sameDiff = sameDiff;
|
||||||
setInstanceId();
|
setInstanceId();
|
||||||
initFromOnnx(node, sameDiff, attributesForNode, graph);
|
initFromOnnx(node, sameDiff, attributesForNode, graph);
|
||||||
|
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Iniitialize the function from the given
|
* Iniitialize the function from the given
|
||||||
* {@link onnx.OnnxProto3.NodeProto}
|
* {@link onnx.Onnx.NodeProto}
|
||||||
* @param node
|
* @param node
|
||||||
* @param initWith
|
* @param initWith
|
||||||
* @param attributesForNode
|
* @param attributesForNode
|
||||||
* @param graph
|
* @param graph
|
||||||
*/
|
*/
|
||||||
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph);
|
public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.imports.descriptors.tensorflow;
|
package org.nd4j.imports.descriptors.tensorflow;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.TextFormat;
|
import org.nd4j.shade.protobuf.TextFormat;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.tensorflow.framework.OpDef;
|
import org.tensorflow.framework.OpDef;
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
package org.nd4j.imports.graphmapper;
|
package org.nd4j.imports.graphmapper;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.Message;
|
import org.nd4j.shade.protobuf.Message;
|
||||||
import com.github.os72.protobuf351.TextFormat;
|
import org.nd4j.shade.protobuf.TextFormat;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.imports.graphmapper;
|
package org.nd4j.imports.graphmapper;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.Message;
|
import org.nd4j.shade.protobuf.Message;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
|
|
|
@ -16,13 +16,13 @@
|
||||||
|
|
||||||
package org.nd4j.imports.graphmapper.onnx;
|
package org.nd4j.imports.graphmapper.onnx;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.ByteString;
|
import org.nd4j.shade.protobuf.ByteString;
|
||||||
import com.github.os72.protobuf351.Message;
|
import org.nd4j.shade.protobuf.Message;
|
||||||
import com.google.common.primitives.Floats;
|
import com.google.common.primitives.Floats;
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
import com.google.common.primitives.Longs;
|
import com.google.common.primitives.Longs;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -52,7 +52,7 @@ import java.util.*;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> {
|
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
|
||||||
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
|
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
@Override
|
@Override
|
||||||
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
|
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
|
||||||
try {
|
try {
|
||||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile);
|
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
|
||||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||||
bufferedWriter.write(node.toString() + "\n");
|
bufferedWriter.write(node.toString() + "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
* @param node
|
* @param node
|
||||||
* @param graph
|
* @param graph
|
||||||
*/
|
*/
|
||||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
|
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
|
||||||
val properties = on.mappingsForFunction();
|
val properties = on.mappingsForFunction();
|
||||||
val tfProperties = properties.get(mappedTfName);
|
val tfProperties = properties.get(mappedTfName);
|
||||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||||
|
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
|
public boolean isOpIgnoreException(Onnx.NodeProto node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
|
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
|
||||||
return function.opName();
|
return function.opName();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||||
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
|
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
|
||||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||||
/**
|
/**
|
||||||
|
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
|
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
|
||||||
for(int i = 0; i < graph.getNodeCount(); i++) {
|
for(int i = 0; i < graph.getNodeCount(); i++) {
|
||||||
val node = graph.getNode(i);
|
val node = graph.getNode(i);
|
||||||
if(node.getName().equals(name))
|
if(node.getName().equals(name))
|
||||||
|
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
|
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> getControlDependencies(OnnxProto3.NodeProto node) {
|
public List<String> getControlDependencies(Onnx.NodeProto node) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
|
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
|
||||||
try {
|
try {
|
||||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||||
bufferedWriter.write(node.toString());
|
bufferedWriter.write(node.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
|
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
|
||||||
/**
|
/**
|
||||||
* Need to figure out why
|
* Need to figure out why
|
||||||
* gpu_0/conv1_1 isn't present in VGG
|
* gpu_0/conv1_1 isn't present in VGG
|
||||||
*/
|
*/
|
||||||
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
|
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
|
||||||
for(int i = 0; i < graphProto.getInputCount(); i++) {
|
for(int i = 0; i < graphProto.getInputCount(); i++) {
|
||||||
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
|
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
|
||||||
}
|
}
|
||||||
|
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
|
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
|
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
|
||||||
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.
|
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
|
||||||
newBuilder()
|
newBuilder()
|
||||||
.setDimValue(-1)
|
.setDimValue(-1)
|
||||||
.build();
|
.build();
|
||||||
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder()
|
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
|
||||||
.setShape(
|
.setShape(
|
||||||
OnnxProto3.TensorShapeProto.newBuilder()
|
Onnx.TensorShapeProto.newBuilder()
|
||||||
.addDim(dim)
|
.addDim(dim)
|
||||||
.addDim(dim).build())
|
.addDim(dim).build())
|
||||||
.build();
|
.build();
|
||||||
|
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Message.Builder getNewGraphBuilder() {
|
public Message.Builder getNewGraphBuilder() {
|
||||||
return OnnxProto3.GraphProto.newBuilder();
|
return Onnx.GraphProto.newBuilder();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState,
|
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
|
||||||
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride,
|
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
|
||||||
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) {
|
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
|
||||||
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
|
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
|
||||||
if(differentialFunction == null) {
|
if(differentialFunction == null) {
|
||||||
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
|
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
|
||||||
|
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) {
|
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
|
||||||
return nd4jTypeFromOnnxType(tensorProto.getElemType());
|
return nd4jTypeFromOnnxType(tensorProto.getElemType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) {
|
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
|
||||||
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING;
|
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
* @param dataType the data type to convert
|
* @param dataType the data type to convert
|
||||||
* @return the nd4j type for the onnx type
|
* @return the nd4j type for the onnx type
|
||||||
*/
|
*/
|
||||||
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
|
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
|
||||||
switch (dataType) {
|
switch (dataType) {
|
||||||
case DOUBLE: return DataType.DOUBLE;
|
case DOUBLE: return DataType.DOUBLE;
|
||||||
case FLOAT: return DataType.FLOAT;
|
case FLOAT: return DataType.FLOAT;
|
||||||
|
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
|
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
|
||||||
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||||
if(attributeProto.getName().equals(key)) {
|
if(attributeProto.getName().equals(key)) {
|
||||||
return attributeProto.getS().toString();
|
return attributeProto.getS().toString();
|
||||||
}
|
}
|
||||||
|
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
|
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
|
||||||
return Longs.toArray(attributeProto.getT().getDimsList());
|
return Longs.toArray(attributeProto.getT().getDimsList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
|
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
|
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
|
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
|
||||||
DataType type = dataTypeForTensor(tensorProto, 0);
|
DataType type = dataTypeForTensor(tensorProto, 0);
|
||||||
if(!tensorProto.isInitialized()) {
|
if(!tensorProto.isInitialized()) {
|
||||||
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
|
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
|
||||||
}
|
}
|
||||||
|
|
||||||
OnnxProto3.TensorProto tensor = null;
|
Onnx.TensorProto tensor = null;
|
||||||
for(int i = 0; i < graph.getInitializerCount(); i++) {
|
for(int i = 0; i < graph.getInitializerCount(); i++) {
|
||||||
val initializer = graph.getInitializer(i);
|
val initializer = graph.getInitializer(i);
|
||||||
if(initializer.getName().equals(tensorName)) {
|
if(initializer.getName().equals(tensorName)) {
|
||||||
|
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
|
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
|
||||||
if(tensor == null)
|
if(tensor == null)
|
||||||
return null;
|
return null;
|
||||||
|
|
||||||
|
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) {
|
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
|
||||||
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
|
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
|
||||||
int dimCount = tensorProto.getShape().getDimCount();
|
int dimCount = tensorProto.getShape().getDimCount();
|
||||||
if(dimCount >= 2)
|
if(dimCount >= 2)
|
||||||
|
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the shape from a tensor proto.
|
* Get the shape from a tensor proto.
|
||||||
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)}
|
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
|
||||||
* @param tensorProto the tensor to get the shape from
|
* @param tensorProto the tensor to get the shape from
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
|
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
|
||||||
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
|
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
|
||||||
int dimCount = tensorProto.getDimsCount();
|
int dimCount = tensorProto.getDimsCount();
|
||||||
if(dimCount >= 2)
|
if(dimCount >= 2)
|
||||||
|
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
|
public String getInputFromNode(Onnx.NodeProto node, int index) {
|
||||||
return node.getInput(index);
|
return node.getInput(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
|
public int numInputsFor(Onnx.NodeProto nodeProto) {
|
||||||
return nodeProto.getInputCount();
|
return nodeProto.getInputCount();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
|
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
|
||||||
return Longs.toArray(attr.getT().getDimsList());
|
return Longs.toArray(attr.getT().getDimsList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
|
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
|
||||||
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>();
|
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
|
||||||
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
|
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
|
||||||
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||||
proto.put(attributeProto.getName(),attributeProto);
|
proto.put(attributeProto.getName(),attributeProto);
|
||||||
}
|
}
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getName(OnnxProto3.NodeProto nodeProto) {
|
public String getName(Onnx.NodeProto nodeProto) {
|
||||||
return nodeProto.getName();
|
return nodeProto.getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
|
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
|
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
|
||||||
return nodeProto.getOpType().contains("Var");
|
return nodeProto.getOpType().contains("Var");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean shouldSkip(OnnxProto3.NodeProto opType) {
|
public boolean shouldSkip(Onnx.NodeProto opType) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
|
public boolean hasShape(Onnx.NodeProto nodeProto) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] getShape(OnnxProto3.NodeProto nodeProto) {
|
public long[] getShape(Onnx.NodeProto nodeProto) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
|
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getOpType(OnnxProto3.NodeProto nodeProto) {
|
public String getOpType(Onnx.NodeProto nodeProto) {
|
||||||
return nodeProto.getOpType();
|
return nodeProto.getOpType();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
|
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
|
||||||
return graphProto.getNodeList();
|
return graphProto.getNodeList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.imports.graphmapper.tf;
|
package org.nd4j.imports.graphmapper.tf;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.Message;
|
import org.nd4j.shade.protobuf.Message;
|
||||||
import com.google.common.primitives.Floats;
|
import com.google.common.primitives.Floats;
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package org.nd4j.imports.graphmapper.tf.tensors;
|
package org.nd4j.imports.graphmapper.tf.tensors;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.Descriptors;
|
import org.nd4j.shade.protobuf.Descriptors;
|
||||||
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
|
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
|
||||||
import org.bytedeco.javacpp.indexer.HalfIndexer;
|
import org.bytedeco.javacpp.indexer.HalfIndexer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
|
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
|
||||||
|
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
if (!attributesForNode.containsKey("axes")) {
|
if (!attributesForNode.containsKey("axes")) {
|
||||||
this.dimensions = new int[] { Integer.MAX_VALUE };
|
this.dimensions = new int[] { Integer.MAX_VALUE };
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
|
||||||
import com.google.common.primitives.Longs;
|
import com.google.common.primitives.Longs;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers;
|
package org.nd4j.linalg.api.ops.impl.layers;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||||
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();
|
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||||
val dilations = attributesForNode.get("dilations");
|
val dilations = attributesForNode.get("dilations");
|
||||||
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();
|
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();
|
||||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val aAlpha = attributesForNode.get("alpha");
|
val aAlpha = attributesForNode.get("alpha");
|
||||||
val aBeta = attributesForNode.get("beta");
|
val aBeta = attributesForNode.get("beta");
|
||||||
val aBias = attributesForNode.get("bias");
|
val aBias = attributesForNode.get("bias");
|
||||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||||
val isSameNode = paddingVal.equals("SAME");
|
val isSameNode = paddingVal.equals("SAME");
|
||||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
|
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
|
||||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||||
val padding = attributesForNode.get("pads").getIntsList();
|
val padding = attributesForNode.get("pads").getIntsList();
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
|
||||||
|
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||||
|
|
|
@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
|
||||||
import com.google.common.primitives.Longs;
|
import com.google.common.primitives.Longs;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
|
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
|
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
val shape = new OnnxGraphMapper().getShape(node);
|
val shape = new OnnxGraphMapper().getShape(node);
|
||||||
this.shape = shape;
|
this.shape = shape;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxMlProto3;
|
import onnx.OnnxMl;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
|
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.VariableType;
|
import org.nd4j.autodiff.samediff.VariableType;
|
||||||
|
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
if (!attributesForNode.containsKey("perm")) {
|
if (!attributesForNode.containsKey("perm")) {
|
||||||
|
|
||||||
} else
|
} else
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
//No op
|
//No op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
throw new UnsupportedOperationException("Not yet implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||||
|
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.random.impl;
|
package org.nd4j.linalg.api.ops.random.impl;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.nd4j.linalg.api.ops.random.impl;
|
package org.nd4j.linalg.api.ops.random.impl;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxProto3;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package onnx;
|
package onnx;
|
||||||
import "onnx.proto3";
|
import "onnx.proto";
|
||||||
|
|
||||||
//
|
//
|
||||||
// This file contains the proto definitions for OperatorSetProto and
|
// This file contains the proto definitions for OperatorSetProto and
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
|
@ -732,4 +732,20 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
fail("Failed datatypes: " + failed.toString());
|
fail("Failed datatypes: " + failed.toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMaxPool2Dbp_1() {
|
||||||
|
val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN);
|
||||||
|
val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN);
|
||||||
|
val z = Nd4j.create(DataType.HALF, 2,3,16,16);
|
||||||
|
|
||||||
|
val op = DynamicCustomOp.builder("maxpool2d_bp")
|
||||||
|
.addInputs(x, y)
|
||||||
|
.addOutputs(z)
|
||||||
|
.addIntegerArguments(2, 2, 2, 2, 8,8, 1,1,1, 0,0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.exec(op);
|
||||||
|
Nd4j.getExecutioner().commit();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<modules>
|
<modules>
|
||||||
<module>jackson</module>
|
<module>jackson</module>
|
||||||
|
<module>protobuf</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
|
|
|
@ -0,0 +1,228 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<parent>
|
||||||
|
<artifactId>nd4j-shade</artifactId>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>protobuf</artifactId>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.protobuf</groupId>
|
||||||
|
<artifactId>protobuf-java</artifactId>
|
||||||
|
<version>3.8.0</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.google.protobuf</groupId>
|
||||||
|
<artifactId>protobuf-java-util</artifactId>
|
||||||
|
<version>3.8.0</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
|
||||||
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>custom-lifecycle</id>
|
||||||
|
|
||||||
|
<activation>
|
||||||
|
<property><name>!skip.custom.lifecycle</name></property>
|
||||||
|
</activation>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.portals.jetspeed-2</groupId>
|
||||||
|
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
|
||||||
|
<version>2.3.1</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>compile-and-pack</id>
|
||||||
|
<phase>compile</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>mvn</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.maven.shared</groupId>
|
||||||
|
<artifactId>maven-invoker</artifactId>
|
||||||
|
<version>2.2</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<configuration>
|
||||||
|
<targets combine.children="merge">
|
||||||
|
|
||||||
|
<target>
|
||||||
|
<id>create-shaded-jars</id>
|
||||||
|
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
|
||||||
|
<goals>clean,compile,package</goals>
|
||||||
|
<properties>
|
||||||
|
<skip.custom.lifecycle>true</skip.custom.lifecycle>
|
||||||
|
</properties>
|
||||||
|
</target>
|
||||||
|
|
||||||
|
</targets>
|
||||||
|
<defaultTarget>create-shaded-jars</defaultTarget>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</profile>
|
||||||
|
</profiles>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
|
||||||
|
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
|
||||||
|
<plugin>
|
||||||
|
<groupId>com.lewisd</groupId>
|
||||||
|
<artifactId>lint-maven-plugin</artifactId>
|
||||||
|
<version>0.0.11</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>pom-lint</id>
|
||||||
|
<phase>none</phase>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
|
||||||
|
<!--
|
||||||
|
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
|
||||||
|
including this module (org.nd4j.protobuf) as a dependency.
|
||||||
|
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
|
||||||
|
-->
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<version>${maven-shade-plugin.version}</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>shade</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<transformers>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||||
|
<resource>reference.conf</resource>
|
||||||
|
</transformer>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||||
|
</transformer>
|
||||||
|
</transformers>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
|
||||||
|
<configuration>
|
||||||
|
<!--
|
||||||
|
Important configuration options here:
|
||||||
|
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
|
||||||
|
original dependencies will be shaded, AND still included as transitive deps
|
||||||
|
in the final POM. This is not what we want.
|
||||||
|
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
|
||||||
|
the original un-shaded JAR being separate. With this being set to false,
|
||||||
|
the original JAR will be modified, and no extra jar will be produced.
|
||||||
|
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
|
||||||
|
to direct dependencies. Without this, we need to manually manage the transitive
|
||||||
|
dependencies of the shaded artifacts.
|
||||||
|
|
||||||
|
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
|
||||||
|
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
|
||||||
|
-->
|
||||||
|
<shadedArtifactAttached>false</shadedArtifactAttached>
|
||||||
|
<createDependencyReducedPom>true</createDependencyReducedPom>
|
||||||
|
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
|
||||||
|
|
||||||
|
<artifactSet>
|
||||||
|
<includes>
|
||||||
|
<include>com.google.protobuf:*</include>
|
||||||
|
<include>com.google.protobuf.*:*</include>
|
||||||
|
</includes>
|
||||||
|
</artifactSet>
|
||||||
|
|
||||||
|
<relocations>
|
||||||
|
<!-- Protobuf dependencies -->
|
||||||
|
<relocation>
|
||||||
|
<pattern>com.google.protobuf</pattern>
|
||||||
|
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
|
||||||
|
</relocation>
|
||||||
|
</relocations>
|
||||||
|
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-jar-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<forceCreation>true</forceCreation>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>empty-javadoc-jar</id>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>jar</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<classifier>javadoc</classifier>
|
||||||
|
<classesDirectory>${basedir}/javadoc</classesDirectory>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>empty-sources-jar</id>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>jar</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<classifier>sources</classifier>
|
||||||
|
<classesDirectory>${basedir}/src</classesDirectory>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-dependency-plugin</artifactId>
|
||||||
|
<version>3.0.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>unpack</id>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>unpack</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<artifactItems>
|
||||||
|
<artifactItem>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>protobuf</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<type>jar</type>
|
||||||
|
<overWrite>false</overWrite>
|
||||||
|
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
|
||||||
|
<includes>**/*.class,**/*.xml</includes>
|
||||||
|
</artifactItem>
|
||||||
|
</artifactItems>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
</project>
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||||
import org.bytedeco.javacpp.*;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.indexer.*;
|
import org.bytedeco.javacpp.indexer.*;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
|
|
@ -16,9 +16,9 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion.graphrunner;
|
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||||
|
|
||||||
import com.github.os72.protobuf351.ByteString;
|
import org.nd4j.shade.protobuf.ByteString;
|
||||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||||
import com.github.os72.protobuf351.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert a json string written out
|
* Convert a json string written out
|
||||||
* by {@link com.github.os72.protobuf351.util.JsonFormat}
|
* by {@link org.nd4j.shade.protobuf.util.JsonFormat}
|
||||||
* to a {@link org.bytedeco.tensorflow.ConfigProto}
|
* to a {@link org.bytedeco.tensorflow.ConfigProto}
|
||||||
* @param json the json to read
|
* @param json the json to read
|
||||||
* @return the config proto to use
|
* @return the config proto to use
|
||||||
|
|
Loading…
Reference in New Issue