commit
bfaa20e46c
|
@ -21,7 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.exceptions;
|
||||||
* Indicates that the user is attempting to import a Keras model configuration that
|
* Indicates that the user is attempting to import a Keras model configuration that
|
||||||
* is malformed or invalid in some other way.
|
* is malformed or invalid in some other way.
|
||||||
*
|
*
|
||||||
* See <a href="https://deeplearning4j.org/docs/latest/keras-import-overview">https://deeplearning4j.org/docs/latest/keras-import-overview</a> for more information.
|
* See <a href="https://deeplearning4j.konduit.ai/keras-import/overview">https://deeplearning4j.konduit.ai/keras-import/overview</a> for more information.
|
||||||
*
|
*
|
||||||
* @author dave@skymind.io
|
* @author dave@skymind.io
|
||||||
*/
|
*/
|
||||||
|
@ -40,6 +40,6 @@ public class InvalidKerasConfigurationException extends Exception {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String appendDocumentationURL(String message) {
|
private static String appendDocumentationURL(String message) {
|
||||||
return message + ". For more information, see http://deeplearning4j.org/docs/latest/keras-import-overview";
|
return message + ". For more information, see https://deeplearning4j.konduit.ai/keras-import/overview";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.exceptions;
|
||||||
* Indicates that the user is attempting to import a Keras model configuration that
|
* Indicates that the user is attempting to import a Keras model configuration that
|
||||||
* is not currently supported.
|
* is not currently supported.
|
||||||
*
|
*
|
||||||
* See <a href="https://deeplearning4j.org/docs/latest/keras-import-overview">https://deeplearning4j.org/docs/latest/keras-import-overview</a>
|
* See <a href="https://deeplearning4j.konduit.ai/keras-import/overview">https://deeplearning4j.konduit.ai/keras-import/overview</a>
|
||||||
* for more information and file an issue at <a href="https://github.com/eclipse/deeplearning4j/issues">https://github.com/eclipse/deeplearning4j/issues</a>.
|
* for more information and file an issue at <a href="https://github.com/eclipse/deeplearning4j/issues">https://github.com/eclipse/deeplearning4j/issues</a>.
|
||||||
*
|
*
|
||||||
* @author dave@skymind.io
|
* @author dave@skymind.io
|
||||||
|
|
|
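Both exception types above are thrown by the Keras import entry points. A minimal usage sketch, assuming a Keras HDF5 file saved from the Python side; the file name "my_model.h5" is purely illustrative:

    import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
    import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
    import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import java.io.IOException;

    try {
        // Import a Keras Sequential model saved via model.save("my_model.h5"); path is illustrative
        MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights("my_model.h5");
        System.out.println(net.summary());
    } catch (InvalidKerasConfigurationException e) {
        // Malformed or otherwise invalid configuration; the message carries the documentation URL above
        e.printStackTrace();
    } catch (UnsupportedKerasConfigurationException e) {
        // Valid Keras configuration that DL4J does not support yet; consider filing an issue (see link above)
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    }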
@ -103,7 +103,7 @@ public class KerasEmbedding extends KerasLayer {
|
||||||
"on Embedding layers. Zero Masking for the Embedding layer only works with unidirectional LSTM for now."
|
"on Embedding layers. Zero Masking for the Embedding layer only works with unidirectional LSTM for now."
|
||||||
+ " If you want to have this behaviour for your imported model " +
|
+ " If you want to have this behaviour for your imported model " +
|
||||||
"in DL4J, apply masking as a pre-processing step to your input." +
|
"in DL4J, apply masking as a pre-processing step to your input." +
|
||||||
"See http://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent#masking for more on this.");
|
"See https://deeplearning4j.konduit.ai/models/recurrent#masking-one-to-many-many-to-one-and-sequence-classification for more on this.");
|
||||||
|
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_EMBEDDING_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
|
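The message above points users at masking as a pre-processing step. A hedged sketch of what that can look like on the DL4J side; the array shapes and the convention of 1 = real timestep, 0 = padding are illustrative assumptions, not part of this commit:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    // Features: [miniBatch, featureSize, timeSeriesLength]; mask: [miniBatch, timeSeriesLength]
    INDArray features = Nd4j.zeros(32, 128, 50);
    INDArray labels = Nd4j.zeros(32, 10, 50);
    INDArray featuresMask = Nd4j.ones(32, 50);   // set entries to 0 where the input is padding

    DataSet ds = new DataSet(features, labels);
    ds.setFeaturesMaskArray(featuresMask);       // per-example, per-timestep input mask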
@ -17,10 +17,10 @@
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Workspace mode to use. See <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces">https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces</a><br>
|
* Workspace mode to use. See <a href="https://deeplearning4j.konduit.ai/config/config-memory/config-workspaces">https://deeplearning4j.konduit.ai/config/config-memory/config-workspaces</a><br>
|
||||||
* <br>
|
* <br>
|
||||||
* NONE: No workspaces will be used for the network. Highest memory use, least performance.<br>
|
* NONE: No workspaces will be used for the network. Highest memory use, least performance.<br>
|
||||||
* ENABLED: Use workspaces.<br>
|
* ENABLED: Use workspaces. This is the default and should almost always be used<br>
|
||||||
* SINGLE: Deprecated. Now equivalent to ENABLED, which should be used instead.<br>
|
* SINGLE: Deprecated. Now equivalent to ENABLED, which should be used instead.<br>
|
||||||
* SEPARATE: Deprecated. Now equivalent to ENABLED, which should be used instead.<br>
|
* SEPARATE: Deprecated. Now equivalent to ENABLED, which should be used instead.<br>
|
||||||
*
|
*
|
||||||
|
|
|
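For reference, workspace mode is normally set on the network configuration builder; a minimal sketch using the standard trainingWorkspaceMode/inferenceWorkspaceMode methods (layer sizes are arbitrary):

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.WorkspaceMode;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            // ENABLED is the default and should almost always be used, per the Javadoc above
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .list()
            .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(1).build())
            .build();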
@ -38,7 +38,7 @@ import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LSTM recurrent neural network layer without peephole connections. Supports CuDNN acceleration - see <a
|
* LSTM recurrent neural network layer without peephole connections. Supports CuDNN acceleration - see <a
|
||||||
* href="https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn">https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn</a> for details
|
* href="https://deeplearning4j.konduit.ai/config/backends/config-cudnn">https://deeplearning4j.konduit.ai/config/backends/config-cudnn</a> for details
|
||||||
*
|
*
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
* @see GravesLSTM GravesLSTM class for an alternative LSTM (with peephole connections)
|
* @see GravesLSTM GravesLSTM class for an alternative LSTM (with peephole connections)
|
||||||
|
|
|
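For context, a minimal configuration sketch using the LSTM layer documented above; layer sizes are arbitrary, and the cuDNN helper is only picked up when the deeplearning4j-cuda module and a CUDA backend are present, as the log messages elsewhere in this commit note:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new LSTM.Builder().nIn(40).nOut(100).activation(Activation.TANH).build())
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
            .build();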
@ -1540,8 +1540,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
* (not) clearing the layer input arrays.<br>
|
* (not) clearing the layer input arrays.<br>
|
||||||
* Note: this method should NOT be used with clearInputs = false, unless you know what you are doing. Specifically:
|
* Note: this method should NOT be used with clearInputs = false, unless you know what you are doing. Specifically:
|
||||||
* when using clearInputs=false, in combination with workspaces, the layer input fields may leak outside of the
|
* when using clearInputs=false, in combination with workspaces, the layer input fields may leak outside of the
|
||||||
* workspaces in which they were defined - potentially causing a crash. See <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces">
|
* workspaces in which they were defined - potentially causing a crash. See <a href="https://deeplearning4j.konduit.ai/config/config-memory/config-workspaces">
|
||||||
* https://deeplearning4j.org/docs/latest/deeplearning4j-config-workspaces</a>
|
* https://deeplearning4j.konduit.ai/config/config-memory/config-workspaces</a>
|
||||||
* for more details
|
* for more details
|
||||||
*
|
*
|
||||||
* @param input An array of ComputationGraph inputs
|
* @param input An array of ComputationGraph inputs
|
||||||
|
|
|
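A usage sketch for the overload this Javadoc belongs to, assuming the signature output(boolean train, boolean clearInputs, INDArray... input); it is wrapped in a small helper only to keep the fragment self-contained:

    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.nd4j.linalg.api.ndarray.INDArray;

    static INDArray[] predict(ComputationGraph graph, INDArray... inputs) {
        // clearInputs = true is the safe default: layer input arrays are cleared after the
        // forward pass, so arrays defined inside a workspace cannot leak out of it
        return graph.output(false, true, inputs);
    }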
@ -86,7 +86,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
} else {
|
} else {
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
OneTimeLogger.info(log, "cuDNN not found: "
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||||
+ "For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
|
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if("CPU".equalsIgnoreCase(backend)){
|
} else if("CPU".equalsIgnoreCase(backend)){
|
||||||
|
|
|
@ -78,7 +78,7 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
|
||||||
} else {
|
} else {
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
OneTimeLogger.info(log, "cuDNN not found: "
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||||
+ "For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
|
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if("CPU".equalsIgnoreCase(backend) ){
|
} else if("CPU".equalsIgnoreCase(backend) ){
|
||||||
|
|
|
@ -86,7 +86,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
} else {
|
} else {
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
OneTimeLogger.info(log, "cuDNN not found: "
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||||
+ "For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
|
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if("CPU".equalsIgnoreCase(backend)){
|
} else if("CPU".equalsIgnoreCase(backend)){
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class LocalResponseNormalization
|
||||||
} else {
|
} else {
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
OneTimeLogger.info(log, "cuDNN not found: "
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||||
+ "For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
|
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* RNN tutorial: https://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent
|
* RNN tutorial: https://deeplearning4j.konduit.ai/models/recurrent
|
||||||
* READ THIS FIRST
|
* READ THIS FIRST
|
||||||
*
|
*
|
||||||
* Bidirectional LSTM layer implementation.
|
* Bidirectional LSTM layer implementation.
|
||||||
|
|
|
@ -71,7 +71,7 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
|
||||||
} else {
|
} else {
|
||||||
OneTimeLogger.info(log, "cuDNN not found: "
|
OneTimeLogger.info(log, "cuDNN not found: "
|
||||||
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
+ "use cuDNN for better GPU performance by including the deeplearning4j-cuda module. "
|
||||||
+ "For more information, please refer to: https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn", t);
|
+ "For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,8 +52,8 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* RNN tutorial: <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent">https://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent</a>
|
* RNN tutorial: <a href="https://deeplearning4j.konduit.ai/models/recurrent">https://deeplearning4j.konduit.ai/models/recurrent</a>
|
||||||
* READ THIS FIRST if you want to understand what the heck is happening here.
|
* READ THIS FIRST if you want to understand this code.
|
||||||
*
|
*
|
||||||
* Shared code for the standard "forwards" LSTM RNN and the bidirectional LSTM RNN
|
* Shared code for the standard "forwards" LSTM RNN and the bidirectional LSTM RNN
|
||||||
* This was extracted from GravesLSTM and refactored into static helper functions. The general reasoning for this was
|
* This was extracted from GravesLSTM and refactored into static helper functions. The general reasoning for this was
|
||||||
|
|
|
@ -826,7 +826,7 @@ public class ParallelWrapper implements AutoCloseable {
|
||||||
/**
|
/**
|
||||||
* This method allows you to specify training mode for this instance of PW.<br>
|
* This method allows you to specify training mode for this instance of PW.<br>
|
||||||
* 1) AVERAGING - stands for parameters averaging. Every X epochs, weights and updater state will be averaged across all models<br>
|
* 1) AVERAGING - stands for parameters averaging. Every X epochs, weights and updater state will be averaged across all models<br>
|
||||||
* 2) SHARED_GRADIENTS - stands for gradients sharing - more details available here: <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-intro">https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-intro</a><br>
|
* 2) SHARED_GRADIENTS - stands for gradients sharing - more details available here: <a href="https://deeplearning4j.konduit.ai/distributed-deep-learning/intro">https://deeplearning4j.konduit.ai/distributed-deep-learning/intro</a><br>
|
||||||
* 3) CUSTOM - this method allows you to specify a custom gradients accumulator, giving you better control of configuration params for training.<br>
|
* 3) CUSTOM - this method allows you to specify a custom gradients accumulator, giving you better control of configuration params for training.<br>
|
||||||
*
|
*
|
||||||
* @param mode
|
* @param mode
|
||||||
|
|
|
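A sketch of selecting one of the three modes listed above. The builder method name trainingMode(...) and the nested TrainingMode enum are assumptions inferred from this Javadoc; worker count and averaging frequency are arbitrary:

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;

    static ParallelWrapper buildWrapper(MultiLayerNetwork model) {
        return new ParallelWrapper.Builder<>(model)
                .workers(4)
                .averagingFrequency(5)   // only relevant for AVERAGING mode
                .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)   // gradient sharing, see link above
                .build();
    }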
@ -71,7 +71,7 @@ public class SparkUtils {
|
||||||
+ "for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid"
|
+ "for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid"
|
||||||
+ " serialization issues (NullPointerException) with off-heap data in INDArrays.\n"
|
+ " serialization issues (NullPointerException) with off-heap data in INDArrays.\n"
|
||||||
+ "Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.kryo.Nd4jRegistrator\");\n"
|
+ "Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.kryo.Nd4jRegistrator\");\n"
|
||||||
+ "See https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto#kryo for more details";
|
+ "See https://deeplearning4j.konduit.ai/distributed-deep-learning/howto#how-to-use-kryo-serialization-with-dl-4-j-and-nd-4-j for more details";
|
||||||
|
|
||||||
private static String sparkExecutorId;
|
private static String sparkExecutorId;
|
||||||
|
|
||||||
|
|
|
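The fix this message recommends, spelled out as Spark configuration; the registrator class name is quoted directly from the message, and the choice between the _2.10 and _2.11 nd4j-kryo artifacts depends on the Scala version, as noted:

    import org.apache.spark.SparkConf;

    SparkConf sparkConf = new SparkConf()
            .setAppName("dl4j-spark")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            // ND4J registrator, so off-heap INDArray data survives Kryo serialization
            .set("spark.kryo.registrator", "org.nd4j.kryo.Nd4jRegistrator");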
@ -108,7 +108,7 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
|
//Configuration for Spark training: see https://deeplearning4j.konduit.ai/distributed-deep-learning/howto for explanation of these configuration options
|
||||||
|
|
||||||
TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
|
TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
|
||||||
.averagingFrequency(2)
|
.averagingFrequency(2)
|
||||||
|
|
|
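For context, the configuration comment above refers to choices such as the TrainingMaster settings sketched below; sc, conf and batchSizePerWorker are assumed to be supplied by the caller, and the values mirror the test code:

    import org.apache.spark.api.java.JavaSparkContext;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.spark.api.TrainingMaster;
    import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
    import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;

    static SparkDl4jMultiLayer buildSparkNet(JavaSparkContext sc, MultiLayerConfiguration conf, int batchSizePerWorker) {
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
                .averagingFrequency(2)                  // average parameters every 2 minibatches, as in the test above
                .batchSizePerWorker(batchSizePerWorker)
                .build();
        return new SparkDl4jMultiLayer(sc, conf, tm);
    }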
@ -195,7 +195,7 @@ public class CharacterIterator implements DataSetIterator {
|
||||||
// dimension 0 = number of examples in minibatch
|
// dimension 0 = number of examples in minibatch
|
||||||
// dimension 1 = size of each vector (i.e., number of characters)
|
// dimension 1 = size of each vector (i.e., number of characters)
|
||||||
// dimension 2 = length of each time series/example
|
// dimension 2 = length of each time series/example
|
||||||
//Why 'f' order here? See https://deeplearning4j.org/docs/latest/deeplearning4j-nn-recurrent data section "Alternative: Implementing a custom DataSetIterator"
|
//Why 'f' order here? See https://deeplearning4j.konduit.ai/models/recurrent, section "Alternative: Implementing a custom DataSetIterator"
|
||||||
INDArray input = Nd4j.create(new int[]{currMinibatchSize, validCharacters.length, exampleLength}, 'f');
|
INDArray input = Nd4j.create(new int[]{currMinibatchSize, validCharacters.length, exampleLength}, 'f');
|
||||||
INDArray labels = Nd4j.create(new int[]{currMinibatchSize, validCharacters.length, exampleLength}, 'f');
|
INDArray labels = Nd4j.create(new int[]{currMinibatchSize, validCharacters.length, exampleLength}, 'f');
|
||||||
|
|
||||||
|
|
|
@ -231,30 +231,34 @@ if(SD_CUDA)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
|
file(GLOB_RECURSE CUSTOMOPS_CUDNN_SOURCES false ../include/ops/declarable/platform/cudnn/*.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_library(nd4jobj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
add_library(samediff_obj OBJECT ${LOOPS_SOURCES_CUDA} ${LEGACY_SOURCES}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES})
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES})
|
||||||
|
|
||||||
# Don't output dynamic linked lib when a static lib build is specified unless the tests are built
|
|
||||||
if(NOT SD_STATIC_LIB OR SD_BUILD_TESTS)
|
|
||||||
add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
message("MSVC runtime for library: ${MSVC_RT_LIB}")
|
message("MSVC runtime for library: ${MSVC_RT_LIB}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# static library is built only if we're going to build tests, skip otherwise
|
# build shared library by default or when it's explicitly requested
|
||||||
if (SD_BUILD_TESTS OR SD_STATIC_LIB)
|
if(NOT SD_STATIC_LIB OR SD_SHARED_LIB)
|
||||||
add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:samediff_obj>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (SD_STATIC_LIB AND SD_SHARED_LIB)
|
||||||
|
# if both static and shared library are going to be built - static library will have special suffix
|
||||||
|
add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
|
install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
|
||||||
|
elseif(SD_STATIC_LIB)
|
||||||
|
# if we only build static library - use this name
|
||||||
|
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
|
set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts
|
# on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts
|
||||||
set_property(TARGET nd4jobj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
set_property(TARGET samediff_obj PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
|
@ -324,20 +328,28 @@ elseif(SD_CPU)
|
||||||
|
|
||||||
message("CPU BLAS")
|
message("CPU BLAS")
|
||||||
add_definitions(-D__CPUBLAS__=true)
|
add_definitions(-D__CPUBLAS__=true)
|
||||||
add_library(nd4jobj OBJECT ${LEGACY_SOURCES}
|
add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
|
||||||
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
|
||||||
${OPS_SOURCES} ${PERF_SOURCES})
|
${OPS_SOURCES} ${PERF_SOURCES})
|
||||||
if(IOS)
|
if(IOS)
|
||||||
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
else()
|
else()
|
||||||
# static library is built only if we're going to build tests, skip otherwise
|
# build shared library by default or when it's explicitly requested
|
||||||
if (SD_BUILD_TESTS OR SD_STATIC_LIB)
|
if(NOT SD_STATIC_LIB OR SD_SHARED_LIB)
|
||||||
add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:samediff_obj>)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(SD_BUILD_TESTS OR NOT SD_STATIC_LIB)
|
if (SD_STATIC_LIB AND SD_SHARED_LIB)
|
||||||
add_library(${SD_LIBRARY_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
# if both static and shared library are going to be built - static library will have special suffix
|
||||||
|
add_library(${SD_LIBRARY_NAME}static STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
|
set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .)
|
||||||
|
elseif(SD_STATIC_LIB)
|
||||||
|
# if we only build static library - use this name
|
||||||
|
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
|
||||||
|
set_property(TARGET ${SD_LIBRARY_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
install(TARGETS ${SD_LIBRARY_NAME} DESTINATION .)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -350,7 +362,7 @@ elseif(SD_CPU)
|
||||||
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
|
||||||
message(STATUS "Building minifier...")
|
message(STATUS "Building minifier...")
|
||||||
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
|
||||||
target_link_libraries(minifier ${SD_LIBRARY_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||||
|
|
|
@ -981,12 +981,12 @@ namespace sd {
|
||||||
* these methods are suited for FlatBuffers use
|
* these methods are suited for FlatBuffers use
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> getBufferAsVector();
|
std::vector<T> getBufferAsVector() const;
|
||||||
std::vector<Nd4jLong> getShapeAsVector() const;
|
std::vector<Nd4jLong> getShapeAsVector() const;
|
||||||
std::vector<int> getShapeAsVectorInt() const;
|
std::vector<int> getShapeAsVectorInt() const;
|
||||||
std::vector<Nd4jLong> getShapeInfoAsVector();
|
std::vector<Nd4jLong> getShapeInfoAsVector() const;
|
||||||
std::vector<int64_t> getShapeInfoAsFlatVector();
|
std::vector<int64_t> getShapeInfoAsFlatVector() const;
|
||||||
std::vector<int64_t> getShapeAsFlatVector();
|
std::vector<int64_t> getShapeAsFlatVector() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set new order and shape in case of suitable array length (in-place operation)
|
* set new order and shape in case of suitable array length (in-place operation)
|
||||||
|
|
|
@ -982,16 +982,16 @@ std::string NDArray::asString(Nd4jLong limit) {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
std::vector<T> NDArray::getBufferAsVector() {
|
std::vector<T> NDArray::getBufferAsVector() const {
|
||||||
std::vector<T> vector(lengthOf());
|
std::vector<T> vector(lengthOf());
|
||||||
for (Nd4jLong e = 0; e < lengthOf(); e++)
|
for (Nd4jLong e = 0; e < lengthOf(); e++)
|
||||||
vector[e] = this->e<T>(e);
|
vector[e] = this->e<T>(e);
|
||||||
return vector;
|
return vector;
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::vector<int64_t> NDArray::getShapeAsFlatVector() {
|
std::vector<int64_t> NDArray::getShapeAsFlatVector() const {
|
||||||
std::vector<int64_t> vector(this->rankOf());
|
std::vector<int64_t> vector(this->rankOf());
|
||||||
for (int e = 0; e < this->rankOf(); e++)
|
for (int e = 0; e < this->rankOf(); e++)
|
||||||
vector[e] = static_cast<int64_t>(this->sizeAt(e));
|
vector[e] = static_cast<int64_t>(this->sizeAt(e));
|
||||||
|
@ -1019,7 +1019,7 @@ std::vector<int> NDArray::getShapeAsVectorInt() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
|
std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() const {
|
||||||
int magicNumber = shape::shapeInfoLength(this->rankOf());
|
int magicNumber = shape::shapeInfoLength(this->rankOf());
|
||||||
std::vector<int64_t> vector(magicNumber);
|
std::vector<int64_t> vector(magicNumber);
|
||||||
|
|
||||||
|
@ -1030,7 +1030,7 @@ std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
|
std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() const {
|
||||||
int magicNumber = shape::shapeInfoLength(this->rankOf());
|
int magicNumber = shape::shapeInfoLength(this->rankOf());
|
||||||
std::vector<Nd4jLong> vector(magicNumber);
|
std::vector<Nd4jLong> vector(magicNumber);
|
||||||
for (int e = 0; e < magicNumber; e++)
|
for (int e = 0; e < magicNumber; e++)
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
|
@ -27,24 +28,58 @@
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) {
|
CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const bool isInplace = block.isInplace();
|
const bool isInplace = block.isInplace();
|
||||||
auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext());
|
auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext());
|
||||||
|
|
||||||
helpers::clipByAveraged(block.launchContext(), *input, *output, *block.getIArguments(), ts, isInplace);
|
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, true);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(clipbyavgnorm) {
|
DECLARE_TYPES(clipbyavgnorm) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto gradO = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());
|
||||||
|
|
||||||
|
helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(clipbyavgnorm_bp) {
|
||||||
|
|
||||||
|
Nd4jLong *newShape = nullptr;
|
||||||
|
COPY_SHAPE(inputShape->at(1), newShape);
|
||||||
|
|
||||||
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_TYPES(clipbyavgnorm_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,10 +31,10 @@ namespace ops {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext());
|
const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), block.launchContext());
|
||||||
const bool isInplace = block.isInplace();
|
const bool isInplace = block.isInplace();
|
||||||
|
|
||||||
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace);
|
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, false);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -45,15 +45,15 @@ namespace ops {
|
||||||
auto gradO = INPUT_VARIABLE(1);
|
auto gradO = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0);
|
auto gradI = OUTPUT_VARIABLE(0);
|
||||||
const auto clipNorm = NDArrayFactory::create(T_ARG(0));
|
const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());
|
||||||
|
|
||||||
helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm);
|
helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, false);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(clipbynorm_bp) {
|
DECLARE_SHAPE_FN(clipbynorm_bp) {
|
||||||
auto inShapeInfo = inputShape->at(0);
|
auto inShapeInfo = inputShape->at(1);
|
||||||
|
|
||||||
Nd4jLong *newShape = nullptr;
|
Nd4jLong *newShape = nullptr;
|
||||||
COPY_SHAPE(inShapeInfo, newShape);
|
COPY_SHAPE(inShapeInfo, newShape);
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
#include<ops/declarable/helpers/transforms.h>
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
#include<array>
|
#include<array>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -85,6 +85,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
|
||||||
|
|
||||||
// ******** input validation ******** //
|
// ******** input validation ******** //
|
||||||
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all input arrays must have the same type !");
|
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all input arrays must have the same type !");
|
||||||
|
REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as input arrays !");
|
||||||
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
||||||
|
|
||||||
for(int i = 1; i < numOfNonEmptyArrs; ++i)
|
for(int i = 1; i < numOfNonEmptyArrs; ++i)
|
||||||
|
|
|
@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
std::vector<const NDArray*> inArrs(block.width());
|
std::vector<const NDArray*> inArrs(block.width());
|
||||||
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
for(int i = 0; i < block.width(); ++i)
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
inArrs[i] = INPUT_VARIABLE(i);
|
||||||
|
|
||||||
|
@ -46,7 +46,8 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex);
|
||||||
|
|
||||||
DECLARE_TYPES(mergemaxindex) {
|
DECLARE_TYPES(mergemaxindex) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
|
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes({ALL_INDICES});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DECLARE_SHAPE_FN(mergemaxindex) {
|
DECLARE_SHAPE_FN(mergemaxindex) {
|
||||||
|
|
|
@ -52,7 +52,7 @@ namespace ops {
|
||||||
else {
|
else {
|
||||||
// check the consistency of input dimensions to reverse along
|
// check the consistency of input dimensions to reverse along
|
||||||
shape::checkDimensions(input->rankOf(), axis);
|
shape::checkDimensions(input->rankOf(), axis);
|
||||||
helpers::reverse(block.launchContext(), input, output, &axis, false);
|
helpers::reverse(block.launchContext(), input, output, &axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -85,7 +85,7 @@ namespace ops {
|
||||||
// check the consistency of input dimensions to reverse along
|
// check the consistency of input dimensions to reverse along
|
||||||
shape::checkDimensions(input->rankOf(), axis);
|
shape::checkDimensions(input->rankOf(), axis);
|
||||||
// we just reverse back original array
|
// we just reverse back original array
|
||||||
helpers::reverse(block.launchContext(), eps, output, &axis, false);
|
helpers::reverse(block.launchContext(), eps, output, &axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -36,6 +36,7 @@ namespace sd {
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_clipbyavgnorm)
|
#if NOT_EXCLUDED(OP_clipbyavgnorm)
|
||||||
DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0);
|
DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0);
|
||||||
|
DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_cumsum)
|
#if NOT_EXCLUDED(OP_cumsum)
|
||||||
|
|
|
@ -15,83 +15,134 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
// @author raver119@gmail.com
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
#include <ops/declarable/helpers/transforms.h>
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
#include <helpers/Loops.h>
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {
|
||||||
static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
|
||||||
|
|
||||||
const int rank = input.rankOf();
|
NDArray* z = nullptr;
|
||||||
const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
|
||||||
|
|
||||||
const T normActual = norm2.e<T>(0);
|
if(isInplace) {
|
||||||
const T normClip = clipNorm.e<T>(0);
|
z = &input;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output.assign(input);
|
||||||
|
z = &output;
|
||||||
|
}
|
||||||
|
|
||||||
if (isInplace) {
|
if(dimensions.empty()) {
|
||||||
|
|
||||||
if(norm2.lengthOf() == 1) {
|
const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});
|
||||||
|
|
||||||
if(normActual > normClip)
|
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
|
||||||
input *= (normClip / normActual);
|
*z *= clipNorm / actualNorm;
|
||||||
}
|
|
||||||
else {
|
|
||||||
|
|
||||||
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
const T iNormActual = norm2.e<T>(i);
|
|
||||||
if (iNormActual > normClip)
|
|
||||||
*listOfInSubArrs.at(i) *= normClip / iNormActual;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
if(norm2.lengthOf() == 1) {
|
auto listOfSubArrs = z->allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
if(normActual > normClip)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
output.assign(input * (normClip / normActual));
|
for (auto i = start; i < stop; i++) {
|
||||||
else
|
const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {});
|
||||||
output.assign(input);
|
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
|
||||||
}
|
*listOfSubArrs.at(i) *= clipNorm / actualNorm;
|
||||||
else {
|
}
|
||||||
|
};
|
||||||
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size());
|
||||||
auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto inputSubArr = listOfInSubArrs.at(i);
|
|
||||||
auto outputSubArr = listOfOutSubArrs.at(i);
|
|
||||||
outputSubArr->assign(inputSubArr);
|
|
||||||
|
|
||||||
const T iNormActual = norm2.e<T>(i);
|
|
||||||
|
|
||||||
if (iNormActual > clipNorm.e<T>(0))
|
|
||||||
*outputSubArr *= clipNorm / iNormActual;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
template<typename T>
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
|
||||||
|
|
||||||
|
const int rank = input.rankOf();
|
||||||
|
|
||||||
|
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||||
|
auto sums = input.reduceAlongDimension(reduce::Sum, dimensions);
|
||||||
|
|
||||||
|
if(norm2.lengthOf() == 1) {
|
||||||
|
|
||||||
|
const T norm = useAverage ? norm2.e<T>(0) / input.lengthOf() : norm2.e<T>(0);
|
||||||
|
|
||||||
|
auto clipVal = clipNorm.e<T>(0);
|
||||||
|
|
||||||
|
if(norm > clipVal) {
|
||||||
|
|
||||||
|
const T sum = sums.e<T>(0); // reduce to scalar
|
||||||
|
const T factor1 = clipVal / norm;
|
||||||
|
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
|
||||||
|
|
||||||
|
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
|
||||||
|
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
|
||||||
|
};
|
||||||
|
|
||||||
|
const_cast<NDArray&>(input).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
gradI.assign(gradO);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
|
||||||
|
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
|
||||||
|
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
|
||||||
|
|
||||||
|
auto clipVal = clipNorm.e<T>(0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
|
||||||
|
auto gradOSubArr = gradOSubArrs.at(i);
|
||||||
|
auto gradISubArr = gradISubArrs.at(i);
|
||||||
|
|
||||||
|
const T norm = useAverage ? norm2.e<T>(i) / gradISubArr->lengthOf() : norm2.e<T>(i);
|
||||||
|
|
||||||
|
if (norm > clipVal) {
|
||||||
|
|
||||||
|
auto inputSubArr = inputSubArrs.at(i);
|
||||||
|
|
||||||
|
const T sum = sums.e<T>(i); // reduce to scalar
|
||||||
|
const T factor1 = clipVal / norm;
|
||||||
|
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
|
||||||
|
|
||||||
|
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
|
||||||
|
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
|
||||||
|
};
|
||||||
|
|
||||||
|
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
gradISubArr->assign(gradOSubArr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
|
||||||
|
|
||||||
|
const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -132,125 +183,6 @@ void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, co
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
|
||||||
|
|
||||||
const int rank = input.rankOf();
|
|
||||||
|
|
||||||
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
|
||||||
|
|
||||||
if(norm2.lengthOf() == 1) {
|
|
||||||
|
|
||||||
const T N = norm2.e<T>(0);
|
|
||||||
|
|
||||||
auto cn = clipNorm.e<T>(0);
|
|
||||||
|
|
||||||
if(N > cn) {
|
|
||||||
|
|
||||||
const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
|
||||||
const T factor1 = static_cast<T>(1.f) / N;
|
|
||||||
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
|
||||||
|
|
||||||
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
|
||||||
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
|
||||||
};
|
|
||||||
|
|
||||||
(const_cast<NDArray&>(input)).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
gradI.assign(gradO);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
|
|
||||||
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
|
|
||||||
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
|
|
||||||
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
|
|
||||||
|
|
||||||
auto cn = clipNorm.e<T>(0);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
T N = norm2.e<T>(i);
|
|
||||||
|
|
||||||
auto gradOSubArr = gradOSubArrs.at(i);
|
|
||||||
auto gradISubArr = gradISubArrs.at(i);
|
|
||||||
|
|
||||||
if (N > cn) {
|
|
||||||
auto inputSubArr = inputSubArrs.at(i);
|
|
||||||
const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
|
||||||
const T factor1 = static_cast<T>(1.f) / N;
|
|
||||||
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
|
||||||
|
|
||||||
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
|
||||||
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
|
||||||
};
|
|
||||||
|
|
||||||
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
|
|
||||||
} else
|
|
||||||
gradISubArr->assign(gradOSubArr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
|
||||||
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm), FLOAT_TYPES);
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
|
||||||
|
|
||||||
auto cn = clipNorm.e<T>(0);
|
|
||||||
if (dimensions.size() == 0) {
|
|
||||||
// all-reduce
|
|
||||||
T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / input.lengthOf();
|
|
||||||
if (n2 <= cn) {
|
|
||||||
if (!isInplace)
|
|
||||||
output.assign(input);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
const T factor = cn / n2;
|
|
||||||
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
|
|
||||||
input.applyLambda<T>(lambda, output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// along dimension
|
|
||||||
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
|
|
||||||
if (!isInplace)
|
|
||||||
output.assign(input);
|
|
||||||
auto tads = output.allTensorsAlongDimension(dimensions);
|
|
||||||
// TODO: make this CUDA-compliant somehow
|
|
||||||
for (int e = 0; e < tads.size(); e++) {
|
|
||||||
T n2 = norm2.e<T>(e) / tads.at(e)->lengthOf();
|
|
||||||
const T factor = cn / n2;
|
|
||||||
if (n2 > cn) {
|
|
||||||
auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
|
|
||||||
tads.at(e)->applyLambda<T>(lambda, output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
|
|
||||||
|
|
||||||
/*
|
|
||||||
if (d1 > params[1])
|
|
||||||
return params[1];
|
|
||||||
else if (d1 < params[0])
|
|
||||||
return params[0];
|
|
||||||
else return d1;
|
|
||||||
*/
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename X, typename Z>
|
||||||
static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
|
||||||
const Nd4jLong numArgs = inArrs.size();
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
|
@ -37,17 +37,18 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& o
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
T max = -DataTypeUtils::max<T>();
|
X max = -DataTypeUtils::max<X>();
|
||||||
Nd4jLong idx = 0;
|
Z idx = static_cast<Z>(0);
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < numArgs; i++) {
|
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||||
T v = inArrs[i]->e<T>(e);
|
X v = inArrs[i]->t<X>(e);
|
||||||
if (v > max) {
|
if (v > max) {
|
||||||
max = v;
|
max = v;
|
||||||
idx = i;
|
idx = static_cast<Z>(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output.p(e, idx);
|
// FIXME, use .r<Z>(e)
|
||||||
|
output.t<Z>(e) = static_cast<Z>(idx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -55,14 +56,14 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& o
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mergeMax_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
static void mergeMax_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
|
||||||
|
|
||||||
const Nd4jLong numArgs = inArrs.size();
|
const Nd4jLong numArgs = inArrs.size();
|
||||||
auto x = inArrs[0];
|
auto x = inArrs[0];
|
||||||
|
|
||||||
|
@ -89,15 +90,15 @@ void mergeMax(sd::LaunchContext * context, const std::vector<const NDArray*>& in
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
||||||
|
|
||||||
// outArrs.size() == inArrs.size() - 1
|
// outArrs.size() == inArrs.size() - 1
|
||||||
const Nd4jLong numArgs = outArrs.size();
|
const Nd4jLong numArgs = outArrs.size();
|
||||||
// last array is gradient
|
// last array is gradient
|
||||||
const auto gradient = inArrs[numArgs]->bufferAsT<T>();
|
const auto gradient = inArrs[numArgs]->bufferAsT<T>();
|
||||||
auto length = inArrs[numArgs]->lengthOf();
|
auto length = inArrs[numArgs]->lengthOf();
|
||||||
|
|
||||||
bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews());
|
bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews());
|
||||||
|
|
||||||
if (bSameOrderAndEws1) {
|
if (bSameOrderAndEws1) {
|
||||||
auto gradOrdering = inArrs[numArgs]->ordering();
|
auto gradOrdering = inArrs[numArgs]->ordering();
|
||||||
|
|
||||||
|
@ -108,8 +109,8 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
||||||
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
|
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if(bSameOrderAndEws1){
|
if(bSameOrderAndEws1){
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
for (auto e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
|
@ -130,7 +131,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
||||||
samediff::Threads::parallel_for(func, 0, length);
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gradShape = inArrs[numArgs]->shapeInfo();
|
auto gradShape = inArrs[numArgs]->shapeInfo();
|
||||||
std::vector<bool> vbSameShaepeAndStrides(numArgs);
|
std::vector<bool> vbSameShaepeAndStrides(numArgs);
|
||||||
for (int i = 0; i < numArgs; ++i) {
|
for (int i = 0; i < numArgs; ++i) {
|
||||||
|
@ -145,12 +146,12 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
            shape::index2coordsCPU(start, e, gradShape, coords);

            const auto gradOffset = shape::getOffset(gradShape, coords);

            T max = -DataTypeUtils::max<T>();
            Nd4jLong nMaxIndex = 0;

            for (Nd4jLong i = 0; i < numArgs; i++) {

                const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords);
                const T* v = inArrs[i]->bufferAsT<T>();
                if (v[xOffset] > max) {

@@ -160,7 +161,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
                }

                const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords);

                T* z = outArrs[nMaxIndex]->bufferAsT<T>();
                z[zOffset] = gradient[gradOffset];
            }

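For orientation while reading this hunk: mergeMaxBp_ routes every incoming gradient element to whichever input holds the element-wise maximum at that position, and all other inputs receive zero there. A minimal standalone sketch of that rule over plain std::vector containers (the function name is illustrative and not part of libnd4j):

    #include <cstddef>
    #include <limits>
    #include <vector>

    // Toy mergeMax backprop: grads[k][e] = gradient[e] where inputs[k][e] is the per-element max, else 0.
    static std::vector<std::vector<float>> mergeMaxBpSketch(const std::vector<std::vector<float>>& inputs,
                                                            const std::vector<float>& gradient) {
        std::vector<std::vector<float>> grads(inputs.size(), std::vector<float>(gradient.size(), 0.f));
        if (inputs.empty())
            return grads;
        for (std::size_t e = 0; e < gradient.size(); ++e) {
            std::size_t maxIdx = 0;
            float maxVal = -std::numeric_limits<float>::max();
            for (std::size_t k = 0; k < inputs.size(); ++k) {
                if (inputs[k][e] > maxVal) {
                    maxVal = inputs[k][e];
                    maxIdx = k;
                }
            }
            grads[maxIdx][e] = gradient[e];   // only the argmax input receives the gradient
        }
        return grads;
    }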
@@ -193,13 +193,10 @@ static void reverseSequence_(sd::LaunchContext * context, const NDArray* input,
}

//////////////////////////////////////////////////////////////////////////
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {

    // we need to reverse axis only if that's new op
    auto listOut = output->allTensorsAlongDimension(*intArgs);
    std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
    auto listIn = input->allTensorsAlongDimension(*intArgs);

    auto listOut = output->allTensorsAlongDimension(dimensions);
    auto listIn = input->allTensorsAlongDimension(dimensions);

    NDArray *subArrIn, *subArrOut;

@@ -0,0 +1,334 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
// @author sgazeos@gmail.com
// @author raver119@gmail.com
//

#include <ops/declarable/helpers/transforms.h>
#include <helpers/ShapeUtils.h>
#include <helpers/PointersManager.h>
#include <helpers/ConstantTadHelper.h>

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void clipByNormCuda(const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int* dimensions, const int dimsLen, const bool useAverage) {

    const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
    const T* norm = reinterpret_cast<const T*>(vNorm);
    T* z = reinterpret_cast<T*>(vz);

    __shared__ Nd4jLong zLen, tadLen, totalThreads;

    if (threadIdx.x == 0) {

        zLen = shape::length(zShapeInfo);
        tadLen = zLen / shape::length(normShapeInfo);
        totalThreads = gridDim.x * blockDim.x;
    }

    __syncthreads();

    int zCoords[MAX_RANK], normCoords[MAX_RANK];

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {

        shape::index2coords(i, zShapeInfo, zCoords);

        // deduce norm coords
        for (int j = 0; j < dimsLen; ++j)
            normCoords[j] = zCoords[dimensions[j]];

        const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];

        if(actualNorm > clipNorm)
            z[shape::getOffset(zShapeInfo, zCoords)] *= clipNorm / actualNorm;
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void clipByNormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
                                            const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
                                            const int* dimensions, const int dimsLen, const bool useAverage) {

    clipByNormCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *stream>>>(vClipNorm, vNorm, normShapeInfo, vz, zShapeInfo, dimensions, dimsLen, useAverage);
}

//////////////////////////////////////////////////////////////////////////
void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dims, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {

    NDArray* z = nullptr;

    if(isInplace) {
        z = &input;
    }
    else {
        output.assign(input);
        z = &output;
    }

    if(dims.empty()) {

        const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});

        if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
            *z *= clipNorm / actualNorm;
    }
    else {

        const NDArray actualNorms = z->reduceAlongDimension(reduce::Norm2, dims);

        std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(z->rankOf(), dims);

        const int threadsPerBlock = MAX_NUM_THREADS / 2;
        const int blocksPerGrid = (z->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

        PointersManager manager(context, "clipByNorm");

        const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));

        NDArray::prepareSpecialUse({z}, {z, &actualNorms, &clipNorm});
        BUILD_SINGLE_SELECTOR(z->dataType(), clipByNormCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), clipNorm.specialBuffer(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage), FLOAT_TYPES);
        NDArray::registerSpecialUse({z}, {z, &actualNorms, &clipNorm});

        manager.synchronize();
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void clipByNormBpCuda(const void* vClipNorm,
                                        const void* vx, const Nd4jLong* xShapeInfo,         // input
                                        const void* vy, const Nd4jLong* yShapeInfo,         // gradO
                                        const void* vNorm, const Nd4jLong* normShapeInfo,
                                        const void* vSum, const Nd4jLong* sumShapeInfo,
                                        void* vz, const Nd4jLong* zShapeInfo,               // gradI
                                        const int* dimensions, const int dimsLen, const bool useAverage) {

    const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
    const T* norm = reinterpret_cast<const T*>(vNorm);
    const T* sum = reinterpret_cast<const T*>(vSum);
    const T* x = reinterpret_cast<const T*>(vx);
    const T* y = reinterpret_cast<const T*>(vy);
    T* z = reinterpret_cast<T*>(vz);

    __shared__ Nd4jLong zLen, tadLen, totalThreads;
    __shared__ bool sameOffsets;

    if (threadIdx.x == 0) {

        zLen = shape::length(zShapeInfo);
        tadLen = zLen / shape::length(normShapeInfo);
        totalThreads = gridDim.x * blockDim.x;

        sameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo);
    }

    __syncthreads();

    int zCoords[MAX_RANK], normCoords[MAX_RANK];

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {

        shape::index2coords(i, zShapeInfo, zCoords);

        const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
        const auto yOffset = sameOffsets ? zOffset : shape::getOffset(yShapeInfo, zCoords);

        // deduce norm coords
        for (int j = 0; j < dimsLen; ++j)
            normCoords[j] = zCoords[dimensions[j]];

        const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];

        if(actualNorm > clipNorm) {

            const T sumVal = sum[shape::getOffset(sumShapeInfo, normCoords)];
            const auto xOffset = sameOffsets ? zOffset : shape::getOffset(xShapeInfo, zCoords);

            z[zOffset] = (clipNorm / actualNorm) * y[yOffset] * (static_cast<T>(1.f) - (x[xOffset] * sumVal) / (actualNorm * actualNorm));
        }
        else
            z[zOffset] = y[yOffset];
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
void clipByNormBp_(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dims, const NDArray& clipNorm, const bool useAverage) {

    const int rank = input.rankOf();

    auto actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);

    if(actualNorms.lengthOf() == 1) {

        const T norm = useAverage ? actualNorms.e<T>(0) / static_cast<T>(input.lengthOf()) : actualNorms.e<T>(0);

        auto clipVal = clipNorm.e<T>(0);

        if(norm > clipVal) {

            const T sum = input.reduceNumber(reduce::Sum).e<T>(0);    // reduce to scalar
            const T factor1 = clipVal / norm;
            const T factor2 = static_cast<T>(1.f) / (norm * norm);    // 1 / (norm*norm*norm)

            auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
                return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
            };

            const_cast<NDArray&>(input).applyPairwiseLambda(const_cast<NDArray&>(gradO), lambda, gradI);
        }
        else
            gradI.assign(gradO);
    }
    else {

        const NDArray actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);
        const NDArray sums = input.reduceAlongDimension(reduce::Sum, dims);

        std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(gradI.rankOf(), dims);

        const int threadsPerBlock = MAX_NUM_THREADS / 2;
        const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

        PointersManager manager(context, "clipByNormBp");

        const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));

        NDArray::prepareSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});
        clipByNormBpCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream()>>>(clipNorm.specialBuffer(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), sums.specialBuffer(), sums.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage);
        NDArray::registerSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});

        manager.synchronize();
    }
}
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);

//////////////////////////////////////////////////////////////////////////
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {

    const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
    BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (context, castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
}


template <typename T>
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))

    for (auto i = 0; i < inputs.size(); i++) {
        auto input = inputs[i];
        auto l2norm = input->reduceNumber(reduce::Norm2);
        globalNorm += l2norm * l2norm;
    }

    globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
    outputs[inputs.size()]->p(0, globalNorm);
    globalNorm.syncToHost();
    const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);

    for (size_t e = 0; e < inputs.size(); e++) {
        // all-reduce
        auto input = inputs[e];
        auto output = outputs[e];

        if (globalNorm.e<double>(0) <= clipNorm) {
            output->assign(input);
        }
        else {

            auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
            input->applyLambda(lambda, *output);
        }
    }
}

void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);


template <typename T>
static void __global__ clipByValueKernel(void* input, const Nd4jLong* inputShape, void* output, const Nd4jLong* outputShape, double leftBound, double rightBound) {
    __shared__ T* outputBuf;
    __shared__ T* inputBuf;
    __shared__ Nd4jLong length;
    __shared__ bool linearBuffers;
    if (threadIdx.x == 0) {
        outputBuf = reinterpret_cast<T *>(output);
        inputBuf = reinterpret_cast<T *>(input);
        length = shape::length(inputShape);
        linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
    }
    __syncthreads();
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto step = gridDim.x * blockDim.x;

    for (Nd4jLong e = tid; e < length; e += step) {
        if (linearBuffers) {
            if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
            else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
            else outputBuf[e] = inputBuf[e];
        }
        else {
            auto inputOffset = shape::getIndexOffset(e, inputShape);
            auto outputOffset = shape::getIndexOffset(e, outputShape);
            if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
            else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
            else outputBuf[outputOffset] = inputBuf[outputOffset];
        }
    }
}

template <typename T>
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    auto stream = context->getCudaStream();
    if (!input.isActualOnDeviceSide())
        input.syncToDevice();
    NDArray::prepareSpecialUse({&output}, {&input});
    clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
    NDArray::registerSpecialUse({&output}, {&input});
}

void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);

}
}
}

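As a rough guide to what the clipByNorm helper in the new file computes per tensor (or per TAD when reduction dimensions are supplied): take the L2 norm, optionally divide it by the element count when useAverage is set, and rescale by clipNorm / norm only when that norm exceeds the threshold. A minimal host-side sketch over a plain std::vector, deliberately not using the NDArray API:

    #include <cmath>
    #include <vector>

    // Toy clip-by-norm: rescale the values in place when their L2 norm (or averaged norm) exceeds clipNorm.
    static void clipByNormSketch(std::vector<float>& values, float clipNorm, bool useAverage) {
        double sumSq = 0.0;
        for (float v : values)
            sumSq += static_cast<double>(v) * v;
        double norm = std::sqrt(sumSq);
        if (useAverage && !values.empty())
            norm /= static_cast<double>(values.size());   // averaged norm, as the useAverage flag requests
        if (norm > clipNorm) {
            const float factor = static_cast<float>(clipNorm / norm);
            for (float& v : values)
                v *= factor;                              // the clipped vector now sits exactly at the threshold
        }
    }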
@@ -210,14 +210,10 @@ namespace helpers {
}

//////////////////////////////////////////////////////////////////////////
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {
    // we need to reverse axis only if that's new op
    std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
    std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
    auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions);
    auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions);

    auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), *intArgs);
    auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), *intArgs);

    NDArray::prepareSpecialUse({output}, {input});

@@ -300,269 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
    manager.synchronize();
}

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    if(tid >= shape::length(zShapeInfo))
        return;

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
    auto z = reinterpret_cast<Z*>(vz);

    auto reducBuff = reinterpret_cast<Z*>(vreducBuff);
    uint* count = reinterpret_cast<uint*>(vreducBuff) + 16384;

    __shared__ Z* shMem;
    __shared__ Nd4jLong len;
    __shared__ bool amIinLastBlock;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        shMem = reinterpret_cast<Z*>(shmem);

        len = shape::length(zShapeInfo);   // xLen = yLen = zLen
    }
    __syncthreads();

    // fill shared memory with array elements
    const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)];
    const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)];

    shMem[2*threadIdx.x] = static_cast<Z>(xVal * xVal);        // for norm
    shMem[2*threadIdx.x + 1] = static_cast<Z>(xVal * yVal);    // for input * gradO

    __syncthreads();

    // accumulate sum per block
    for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {

        if (threadIdx.x < activeThreads && tid + activeThreads < len) {

            shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
            shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
        }
        __syncthreads();
    }

    // store accumulated sums in reduction buffer (reducBuff)
    if (threadIdx.x == 0) {

        reducBuff[2*blockIdx.x] = shMem[0];
        reducBuff[2*blockIdx.x + 1] = shMem[1];

        __threadfence();

        amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1);
    }
    __syncthreads();

    // shared memory of last block is used for final summation of values stored in reduction buffer
    if (amIinLastBlock) {

        for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) {

            shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x];
            shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1];
        }
        __syncthreads();

        // accumulate sum
        for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {

            if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) {
                shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
                shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {

            reducBuff[0] = math::nd4j_sqrt<Z,Z>(shMem[0]);
            reducBuff[1] = shMem[1];
            count = 0;
        }
    }
}

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    const Nd4jLong len = shape::length(zShapeInfo);   // xLen = yLen = zLen

    if(tid >= len)
        return;

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
    auto z = reinterpret_cast<Z*>(vz);

    __shared__ Z norm, sumOfProd;

    if (threadIdx.x == 0) {

        norm = reinterpret_cast<Z*>(vreducBuff)[0];
        sumOfProd = reinterpret_cast<Z*>(vreducBuff)[1];
    }
    __syncthreads();

    const auto yOffset = shape::getIndexOffset(tid, yShapeInfo);
    const auto zOffset = shape::getIndexOffset(tid, zShapeInfo);

    if(norm > clipNormVal) {

        const auto xOffset = shape::getIndexOffset(tid, xShapeInfo);

        const Z factor1 = static_cast<Z>(1) / norm;        // 1 / norm
        const Z factor2 = factor1 / (norm * norm);         // 1 / (norm * norm * norm)

        z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]);
    }
    else {
        z[zOffset] = y[yOffset];
    }
}

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) {

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
    auto z = reinterpret_cast<Z*>(vz);

    __shared__ Z* shMem;
    __shared__ Nd4jLong tadLen;

    if (threadIdx.x == 0) {

        extern __shared__ unsigned char shmem[];
        shMem = reinterpret_cast<Z*>(shmem);
        tadLen = shape::length(zTadShapeInfo);   // xTadLen = yTadLen = zTadLen
    }
    __syncthreads();

    const auto* xTad = x + xTadOffsets[blockIdx.x];
    const auto* yTad = y + yTadOffsets[blockIdx.x];
    auto* zTad = z + zTadOffsets[blockIdx.x];

    // *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** //

    Z norm = 0;
    Z sumOfProd = 0;

    for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {

        const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);
        const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);

        shMem[2*threadIdx.x] = static_cast<Z>(xTad[xOffset] * xTad[xOffset]);       // for norm
        shMem[2*threadIdx.x + 1] = static_cast<Z>(xTad[xOffset] * yTad[yOffset]);   // for input * gradO

        __syncthreads();

        // accumulate sum per block
        for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {

            if (threadIdx.x < activeThreads && i + activeThreads < tadLen) {

                shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
                shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
            }
            __syncthreads();
        }

        norm += shMem[0];
        sumOfProd += shMem[1];
    }

    // *** SECOND STAGE - GRADIENT CALCULATION *** //

    norm = math::nd4j_sqrt<Z,Z>(norm);

    for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {

        const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);
        const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo);

        if(norm > clipNormVal) {

            const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);

            const Z factor1 = static_cast<Z>(1) / norm;        // 1 / norm
            const Z factor2 = factor1 / (norm * norm);         // 1 / (norm * norm * norm)

            zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]);
        }
        else {
            zTad[zOffset] = yTad[yOffset];
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                                     const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
                                     const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets,
                                     void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
                                     void* vreducBuff, const double clipNormVal) {

    if(xTadOffsets == nullptr) {   // means whole array
        clipByNormBPWholeArrCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
        clipByNormBPCalcGradCuda<X,Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
    }
    else   // means tads using
        clipByNormBPTadsCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast<Z>(clipNormVal));
}
BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES);

//////////////////////////////////////////////////////////////////////////
void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {

    PointersManager manager(context, "clipByNormBP");

    const double clipNormVal = clipNorm.e<double>(0);

    const auto xType = input.dataType();
    const auto zType = gradI.dataType();

    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128;

    NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});

    if(dimensions.empty() || dimensions.size() == input.rankOf()) {   // means whole array

        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), nullptr, gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
    }
    else {   // means tads using

        auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
        auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.shapeInfo(), dimensions);
        auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), dimensions);

        const int blocksPerGrid = packX.numberOfTads();
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
    }

    NDArray::registerSpecialUse({&gradI}, {&input, &gradO});

    manager.synchronize();
}

template <typename T>
static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
    auto tid = blockIdx.x * blockDim.x;
@@ -692,252 +429,6 @@ void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArra
    output.setIdentity();
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {
    for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
        __shared__ T* z;
        __shared__ Nd4jLong len;
        if (threadIdx.x == 0) {
            len = shape::length(shape);
            z = inputBuffer + inputOffsets[arr];
        }
        __syncthreads();
        for (int j = threadIdx.x; j < len; j+= blockDim.x) {
            auto xIndex = shape::getIndexOffset(j, shape);

            if(norm2Buf[arr] > clipNorm)
                z[xIndex] *= clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c'
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {

    for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
        __shared__ T* x, *z;
        __shared__ Nd4jLong lenZ;
        __shared__ T norm2;

        if (threadIdx.x == 0) {
            x = inputBuffer + inputOffsets[arr];
            z = outputBuffer + outputOffsets[arr];
            lenZ = shape::length(outputShape);
            norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)];
        }
        __syncthreads();
        for (Nd4jLong j = threadIdx.x; j < lenZ; j+= blockDim.x) {
            auto xIndex = shape::getIndexOffset(j, shape);
            auto zIndex = shape::getIndexOffset(j, outputShape);
            if(norm2 > clipNorm) {
                z[zIndex] = x[xIndex] * clipNorm / norm2; // case with ews = 1 and ordering is 'c'
            } else {
                z[zIndex] = x[xIndex];
            }
            //printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]);
        }
        __syncthreads();
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByNorm_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, NDArray const& clipNormA, const bool isInplace) {
    const int rank = input.rankOf();
    auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
    clipNormA.syncToHost();
    //norm2.printBuffer("Norm2");
    T const clipNorm = clipNormA.e<T>(0);
    //clipNormA.printBuffer("ClipNorm");
    auto stream = context->getCudaStream();
    if (isInplace) {
        if(norm2.lengthOf() == 1) {
            norm2.syncToHost();
            T norm2Val = norm2.e<T>(0);
            if(norm2Val > clipNorm)
                input *= clipNorm / norm2Val;
        }
        else {

            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
            auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
            //auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimsToExclude);
            T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
            T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());

            clipByNormInplaceKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
        }
    }
    else {

        if(norm2.lengthOf() == 1) {
            norm2.syncToHost();
            T norm2Val = norm2.e<T>(0);

            if(norm2Val > clipNorm)
                output.assign( input * (clipNorm / norm2Val));
            else
                output.assign( input );
        }
        else {

            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
            auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
            auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimensions);
            T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
            T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());
            T* outputBuffer = reinterpret_cast<T*>(output.specialBuffer());

            clipByNormKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
        }
    }
}

void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);

template <typename T>
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))

    for (auto i = 0; i < inputs.size(); i++) {
        auto input = inputs[i];
        auto l2norm = input->reduceNumber(reduce::Norm2);
        globalNorm += l2norm * l2norm;
    }

    globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
    outputs[inputs.size()]->p(0, globalNorm);
    globalNorm.syncToHost();
    const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);

    for (size_t e = 0; e < inputs.size(); e++) {
        // all-reduce
        auto input = inputs[e];
        auto output = outputs[e];

        if (globalNorm.e<double>(0) <= clipNorm) {
            output->assign(input);
        }
        else {

            auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
            input->applyLambda(lambda, *output);
        }
    }
}

void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);

//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByAveraged_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    auto cn = clipNorm.e<T>(0);
    if (dimensions.size() == 0) {
        // all-reduce
        T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / static_cast<T>(input.lengthOf());
        if (n2 <= cn) {
            if (!isInplace)
                output.assign(input);
        }
        else {
            const T factor = cn / n2;
            //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
            //input.applyLambda<T>(lambda, output);
            output.assign(input * factor);
        }
    }
    else {
        // along dimension
        auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
        if (!isInplace)
            output.assign(input);
        auto tads = output.allTensorsAlongDimension(dimensions);
        auto outTads = output.allTensorsAlongDimension(dimensions);
        // TODO: make this CUDA-compliant somehow
        for (int e = 0; e < tads.size(); e++) {
            T n2 = norm2.e<T>(e) / static_cast<T>(tads.at(e)->lengthOf());
            const T factor = cn / n2;
            if (n2 > cn) {
                //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
                tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e));//applyLambda<T>(lambda, &output);
            }
        }
    }
}

void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);

/*
    if (d1 > params[1])
        return params[1];
    else if (d1 < params[0])
        return params[0];
    else return d1;
*/
template <typename T>
static void __global__ clipByValueKernel(void* input, Nd4jLong const* inputShape, void* output, Nd4jLong const* outputShape, double leftBound, double rightBound) {
    __shared__ T* outputBuf;
    __shared__ T* inputBuf;
    __shared__ Nd4jLong length;
    __shared__ bool linearBuffers;
    if (threadIdx.x == 0) {
        outputBuf = reinterpret_cast<T *>(output);
        inputBuf = reinterpret_cast<T *>(input);
        length = shape::length(inputShape);
        linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
    }
    __syncthreads();
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto step = gridDim.x * blockDim.x;

    for (Nd4jLong e = tid; e < length; e += step) {
        if (linearBuffers) {
            if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
            else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
            else outputBuf[e] = inputBuf[e];
        }
        else {
            auto inputOffset = shape::getIndexOffset(e, inputShape);
            auto outputOffset = shape::getIndexOffset(e, outputShape);
            if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
            else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
            else outputBuf[outputOffset] = inputBuf[outputOffset];
        }
    }
}

template <typename T>
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    auto stream = context->getCudaStream();
    if (!input.isActualOnDeviceSide())
        input.syncToDevice();
    NDArray::prepareSpecialUse({&output}, {&input});
    clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
    NDArray::registerSpecialUse({&output}, {&input});
}

void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);

}
}
}

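Both the removed CUDA code above and its replacement follow the same global-norm rule: globalNorm = sqrt(sum of the squared L2 norms of all inputs), every tensor is scaled by clipNorm / globalNorm only when globalNorm exceeds clipNorm, and the global norm itself is written to the extra output. A standalone sketch over flat vectors (names are illustrative, not the library API):

    #include <cmath>
    #include <vector>

    // Toy clip-by-global-norm over a list of "tensors" represented as flat float vectors.
    static float clipByGlobalNormSketch(std::vector<std::vector<float>>& tensors, float clipNorm) {
        double sumSq = 0.0;
        for (const auto& t : tensors)
            for (float v : t)
                sumSq += static_cast<double>(v) * v;      // accumulates the sum of squared L2 norms
        const double globalNorm = std::sqrt(sumSq);
        if (globalNorm > clipNorm) {
            const float factor = static_cast<float>(clipNorm / globalNorm);
            for (auto& t : tensors)
                for (float& v : t)
                    v *= factor;                          // every tensor shares one scaling factor
        }
        return static_cast<float>(globalNorm);            // the op also reports the global norm itself
    }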
@@ -29,9 +29,9 @@ namespace helpers {

    void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim);

    void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp);
    void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs);

}
}

@@ -63,13 +63,13 @@ namespace helpers {
    void mergeAdd(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
    void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs);

    void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
    void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage);

    void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace);

    void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm);
    void clipByNormBp(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage);

    void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
    void clipByAveragedNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
    void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);

    void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);

@@ -1093,7 +1093,7 @@ namespace sd {
        return ND4J_STATUS_OK;

    NDArray *a0 = block.array(0);
    for (int e = 0; e < block.width(); e++) {
    for (int e = 1; e < block.width(); e++) {
        auto aV = block.array(e);
        if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo()))
            return ND4J_STATUS_BAD_DIMENSIONS;

@@ -90,13 +90,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(*x, x_user_md);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // z, output
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(*z, z_user_md);
    mkldnnUtils::setBlockStrides(z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -112,15 +111,10 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // provide memory and check whether reorder is required

    // x
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
    const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
    if (zReorder)
        dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
    args[DNNL_ARG_DST] = z_mkl_mem;

    // mean
    auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));

@@ -141,8 +135,8 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
    if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();

@@ -151,7 +145,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
                              NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

@@ -206,20 +200,17 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(*x, x_user_md);
    mkldnnUtils::setBlockStrides(x, x_user_md);

    // dLdO
    dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);
    mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md);

    // dLdI
    dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md);
    mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -239,10 +230,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     // provide memory and check whether reorder is required

     // x
-    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // dLdO
-    mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
+    mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

     // mean
     auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));
@@ -253,10 +244,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     args[DNNL_ARG_VARIANCE] = var_mkl_mem;

     // dLdI
-    auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer());
-    const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
-    auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
-    args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
+    auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

     // gamma and beta (and their gradients) if they are present
     if(weights != nullptr) {
@@ -272,8 +260,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);

     // reorder outputs if necessary
-    if (dLdIReorder)
-        dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
+    if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem);

     stream.wait();
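Note on the refactor above: dLdI is now routed through the same mkldnnUtils::loadDataToMklStream helper that already handled the inputs. The helper itself lives in mkldnnUtils and is not part of this commit; judging purely from the four removed lines, the new dLdI call presumably boils down to roughly the following (illustrative sketch only, not the real implementation):

    // wrap the NDArray buffer as "user" memory
    auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer());
    // hand the primitive a scratch memory when it prefers a different layout
    auto dLdI_mkl_mem = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc()
                      ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine)
                      : dLdI_user_mem;
    args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
    // after execute(), the caller reorders args[DNNL_ARG_DIFF_SRC] back into
    // dLdI_user_mem whenever the two descriptors differ (see the hunk above)

Because the user memory is returned to the caller, the cached dLdIReorder/zReorder booleans can be dropped in favour of comparing the descriptors directly at the reorder site.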
@@ -662,9 +650,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
     const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

     if (shape::strideDescendingCAscendingF(dLdO->shapeInfo()))
-        batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
+        batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
     else
-        batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
+        batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);

     *dLdM = 0;
     *dLdV = 0;
@@ -0,0 +1,186 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include <ops/declarable/PlatformHelper.h>
+#include <ops/declarable/OpRegistrator.h>
+#include <system/platform_boilerplate.h>
+
+#include <helpers/MKLDNNStream.h>
+#include "mkldnnUtils.h"
+#include <numeric>
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+//////////////////////////////////////////////////////////////////////////
+static void concatMKLDNN(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
+
+    // data type
+    dnnl::memory::data_type type;
+    if(output.dataType() == DataType::FLOAT32)
+        type = dnnl::memory::data_type::f32;
+    else if(output.dataType() == DataType::HALF)
+        type = dnnl::memory::data_type::f16;
+    else if(output.dataType() == DataType::BFLOAT16)
+        type = dnnl::memory::data_type::bf16;
+    else if(output.dataType() == DataType::UINT8)
+        type = dnnl::memory::data_type::u8;
+    else
+        type = dnnl::memory::data_type::s8;
+
+    std::vector<dnnl::memory::desc> x_user_md(inArrs.size()), x_mkl_md(inArrs.size());
+
+    // inputs
+    for (int i = 0; i < inArrs.size(); ++i) {
+
+        dnnl::memory::dims dims = inArrs[i]->getShapeAsFlatVector();
+        x_user_md[i] = x_mkl_md[i] = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(*inArrs[i]));
+        mkldnnUtils::setBlockStrides(*inArrs[i], x_user_md[i]);
+    }
+
+    // output
+    dnnl::memory::dims dims = output.getShapeAsFlatVector();
+    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(output));
+    mkldnnUtils::setBlockStrides(output, z_user_md);
+
+    std::unordered_map<int, dnnl::memory> args;
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+    dnnl::concat::primitive_desc op_prim_desc(axis, x_mkl_md, engine);
+
+    dnnl::stream stream(engine);
+
+    // inputs
+    for (int i = 0; i < inArrs.size(); ++i)
+        mkldnnUtils::loadDataToMklStream(*inArrs[i], engine, stream, x_user_md[i], op_prim_desc.src_desc(i), args[DNNL_ARG_MULTIPLE_SRC + i]);
+
+    // outputs
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
+
+    // primitive execution
+    dnnl::concat(op_prim_desc).execute(stream, args);
+
+    // reorder output if necessary
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
+
+    stream.wait();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(concat, ENGINE_CPU) {
+
+    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided");
+
+    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
+
+    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
+
+    // first of all take into account possible presence of empty arrays
+    // also if scalar is present -> copy its value to vector with length=1
+    std::vector<const NDArray*> nonEmptyArrs;
+    std::vector<int> arrsToDelete;
+    int index = 0;
+    bool allOfSameType = true;
+    auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
+    auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();
+
+    for(int i = 0; i < numOfInArrs; ++i) {
+        auto input = INPUT_VARIABLE(i);
+        auto currentRank = input->rankOf();
+
+        if(!input->isEmpty()) {
+
+            allOfSameType &= (typeOfFirstArr == input->dataType());
+
+            if(input->rankOf() == 0) {
+                auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
+                vec->assign(input);
+                nonEmptyArrs.push_back(vec);
+                arrsToDelete.push_back(index);
+            }
+            else{
+                nonEmptyArrs.push_back(input);
+            }
+            ++index;
+        }
+    }
+
+    const int numOfNonEmptyArrs = nonEmptyArrs.size();
+
+    if(numOfNonEmptyArrs == 0){
+        //All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
+        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT MKLDNN op: If all input variables are empty, output must be empty");
+        return Status::OK();
+    }
+
+    const int rank = nonEmptyArrs[0]->rankOf();     // look up to first non-empty array
+    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
+    if(axis < 0){
+        axis += rank;
+    }
+
+    // ******** input validation ******** //
+    REQUIRE_TRUE(allOfSameType, 0, "CONCAT MKLDNN op: all of input arrays must have same type !");
+    REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT MKLDNN op: output array should have the same type as inputs arrays !");
+    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
+
+    for(int i = 1; i < numOfNonEmptyArrs; ++i)
+        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT MKLDNN op: all input arrays must have the same rank !");
+
+    for(int i = 1; i < numOfNonEmptyArrs; ++i) {
+        for(int dim = 0; dim < rank; ++dim)
+            if(dim != axis)
+                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) !");
+    }
+    // ******** end of input validation ******** //
+
+    auto output = OUTPUT_VARIABLE(0);
+
+    if(numOfNonEmptyArrs == 1)
+        output->assign(nonEmptyArrs[0]);
+    else
+        concatMKLDNN(nonEmptyArrs, *output, axis);
+
+    // delete dynamically allocated vectors with length=1
+    for(int index : arrsToDelete)
+        delete nonEmptyArrs[index];
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(concat, ENGINE_CPU) {
+
+    auto z = OUTPUT_VARIABLE(0);
+
+    const auto zType = z->dataType();
+
+    return z->rankOf() < 7 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8);
+}
+
+}
+}
+}
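For readers who want the underlying oneDNN pattern without the sd::NDArray plumbing, here is a minimal standalone sketch of the same dnnl::concat usage. Shapes and values are made up for illustration; only dnnl calls that already appear in the file above are relied on.

    #include <dnnl.hpp>
    #include <vector>

    int main() {
        using tag = dnnl::memory::format_tag;
        using dt  = dnnl::memory::data_type;

        dnnl::engine engine(dnnl::engine::kind::cpu, 0);
        dnnl::stream stream(engine);

        // two 2x3 inputs concatenated along axis 0 -> 4x3 output
        std::vector<float> a(6, 1.0f), b(6, 2.0f), c(12, 0.0f);

        dnnl::memory::desc src_md({2, 3}, dt::f32, tag::ab);
        dnnl::memory src0(src_md, engine, a.data());
        dnnl::memory src1(src_md, engine, b.data());

        // let oneDNN derive the destination descriptor from the sources and the axis
        dnnl::concat::primitive_desc pd(0, {src_md, src_md}, engine);
        dnnl::memory dst(pd.dst_desc(), engine, c.data());

        dnnl::concat(pd).execute(stream, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                                          {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                                          {DNNL_ARG_DST, dst}});
        stream.wait();
        return 0;
    }

The platform op above does the same thing, except that the source descriptors come from NDArray shapes and strides and the destination may be reordered back into the user buffer afterwards.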
|
@@ -62,33 +62,23 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
     auto type = dnnl::memory::data_type::f32;

+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {3,2,0,1};     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    else if(2 == wFormat)
+        permut = {0,3,1,2};     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
+
     // memory descriptors for arrays

     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
-        w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3;
-        if(0 == wFormat) {
-            i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
-        }
-        else {
-            i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
-        }
-        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
-    }
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // bias
     dnnl::memory::desc b_mkl_md;
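The three-argument setBlockStrides overload used above is defined in mkldnnUtils and is not shown in this commit. Based purely on the per-op code it replaces, it presumably behaves roughly like this hypothetical sketch (name and exact signature below are illustrative, not the real helper):

    // hypothetical sketch only, mirroring the removed per-op stride code
    static void setBlockStridesSketch(const NDArray& arr, dnnl::memory::desc& md,
                                      const std::vector<int>& permut = {}) {
        if (arr.ews() != 1 || arr.ordering() != 'c' || !permut.empty()) {
            md.data.format_kind = dnnl_blocked;                  // override the declared format tag
            for (int i = 0; i < arr.rankOf(); ++i)
                md.data.format_desc.blocking.strides[i] =
                        arr.strideAt(permut.empty() ? i : permut[i]);   // logical dim i reads the array's dim permut[i]
        }
    }

The per-format index juggling (i0..i3) collapses into the permutation vector built once at the top of each kernel.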
@@ -98,7 +88,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
     // output
     dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(output, z_user_md);
+    mkldnnUtils::setBlockStrides(*output, z_user_md);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -114,10 +104,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // bias
     if(bias != nullptr) {
@@ -126,17 +116,14 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
     }

     // output
-    auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
-    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-    args[DNNL_ARG_DST] = z_mkl_mem;
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

     // run calculations
     dnnl::convolution_forward(op_prim_desc).execute(stream, args);

     // reorder outputs if necessary
-    if (zReorder)
-        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

     stream.wait();
     // shape::printArray(z_mkl_mem.map_data<float>(),8);
@@ -170,64 +157,38 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     auto type = dnnl::memory::data_type::f32;

+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {3,2,0,1};     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    else if(2 == wFormat)
+        permut = {0,3,1,2};     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
+
     // memory descriptors for arrays

     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
-        w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3;
-        if(0 == wFormat) {
-            i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
-        }
-        else {
-            i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
-        }
-        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
-    }
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // gradO
     dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
+    mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);

     // gradI
     dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
+    mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);

     // gradW
     dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
-        gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3;
-        if(0 == wFormat) {
-            i0 = 3; i1 = 2; i2 = 0; i3 = 1;     // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
-        }
-        else {
-            i0 = 0; i1 = 3; i2 = 1; i3 = 2;     // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
-        }
-        gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
-        gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
-        gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
-        gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
-    }
+    mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);

     // gradB
     dnnl::memory::desc gradB_mkl_md;
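As a quick sanity check on the permutation comments above, a tiny self-contained example (numbers chosen purely for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        // wFormat 0 weights, logical layout [kH, kW, iC, oC] = [3, 3, 8, 16], c-order strides:
        std::vector<long> stride = {3*8*16, 8*16, 16, 1};        // {384, 128, 16, 1}
        std::vector<int>  permut = {3, 2, 0, 1};                 // view the same buffer as [oC, iC, kH, kW]

        for (int i = 0; i < 4; ++i)
            std::printf("%ld ", stride[permut[i]]);              // prints: 1 16 384 128
        std::printf("\n");
        return 0;
    }

So oC stays the fastest-varying dimension in memory, which is exactly what the removed strideAt(i0..i3) assignments encoded by hand.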
@@ -256,10 +217,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@@ -274,16 +235,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
-    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
-    const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
-    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
-    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+    auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

     // gradW
-    auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
-    const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
-    auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
-    args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
+    auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);

     // gradB
     if(gradB != nullptr) {
@@ -301,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

     // reorder gradI if necessary
-    if (gradIReorder)
-        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
-    if (gradWReorder)
-        dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
+    if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
+    if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);

     stream.wait();
@@ -63,6 +63,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
     dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
     dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};

+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {4,3,0,1,2};     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
+    else if(2 == wFormat)
+        permut = {0,4,1,2,3};     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
+
     auto type = dnnl::memory::data_type::f32;

     // memory descriptors for arrays
@@ -70,29 +76,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
-        w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3, i4;
-        if(0 == wFormat) {
-            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
-        }
-        else {
-            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
-        }
-        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
-        w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
-    }
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // bias
     dnnl::memory::desc b_mkl_md;
@@ -102,7 +91,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
     // output
     dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(output, z_user_md);
+    mkldnnUtils::setBlockStrides(*output, z_user_md);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -118,10 +107,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // bias
     if(bias != nullptr) {
@@ -130,17 +119,14 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
     }

     // output
-    auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
-    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-    args[DNNL_ARG_DST] = z_mkl_mem;
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

     // run calculations
     dnnl::convolution_forward(op_prim_desc).execute(stream, args);

     // reorder outputs if necessary
-    if (zReorder)
-        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

     stream.wait();
 }
@@ -177,68 +163,40 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     auto type = dnnl::memory::data_type::f32;

+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {4,3,0,1,2};     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
+    else if(2 == wFormat)
+        permut = {0,4,1,2,3};     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
+
     // memory descriptors for arrays

     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
-        w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3, i4;
-        if(0 == wFormat) {
-            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
-        }
-        else {
-            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
-        }
-        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
-        w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
-    }
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // gradO
     dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);

-    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
+    mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);

     // gradI
     dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);

-    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
+    mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);

     // gradW
     dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
-    if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
-        gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        uint i0, i1, i2, i3, i4;
-        if(0 == wFormat) {
-            i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
-        }
-        else if(1 == wFormat) {
-            i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
-        }
-        else {
-            i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3;     // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
-        }
-        gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
-        gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
-        gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
-        gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
-        gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
-    }
+    mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);

     // gradB
     dnnl::memory::desc gradB_mkl_md;
@@ -267,10 +225,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@@ -285,16 +243,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
-    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
-    const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
-    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
-    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+    auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

     // gradW
-    auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
-    const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
-    auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
-    args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
+    auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);

     // gradB
     if(gradB != nullptr) {
@@ -312,10 +264,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
     dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

     // reorder gradI if necessary
-    if (gradIReorder)
-        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
-    if (gradWReorder)
-        dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
+    if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
+    if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);

     stream.wait();
@@ -47,16 +47,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
     dnnl::memory::dims dilation = { dH-1, dW-1 };

-    uint i0, i1, i2, i3;
-    if(0 == wFormat) {
-        i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-    }
-    else if(1 == wFormat) {
-        i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
-    }
-    else {
-        i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
-    }
+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {2,3,0,1};     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+    else if(1 == wFormat)
+        permut = {1,0,2,3};     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
+    else
+        permut = {3,0,1,2};     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]

     // input type
     dnnl::memory::data_type xType;
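One detail worth noting, visible only in the layout comments above: deconvolution weights start from a different source layout than convolution weights, so the permutation vectors differ even though the target order [oC, iC, kH, kW] is the same. For comparison (both mappings taken from the hunks in this commit):

    // conv2d   (wFormat 0): [kH, kW, iC, oC] -> [oC, iC, kH, kW]   => permut = {3,2,0,1}
    // deconv2d (wFormat 0): [kH, kW, oC, iC] -> [oC, iC, kH, kW]   => permut = {2,3,0,1}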
@@ -99,16 +96,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
-    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // bias
     dnnl::memory::desc b_mkl_md;
@@ -118,7 +111,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     // output
     dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(output, z_user_md);
+    mkldnnUtils::setBlockStrides(*output, z_user_md);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -135,10 +128,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // bias
     if(bias != nullptr) {
@@ -147,17 +140,14 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     }

     // output
-    auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
-    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-    args[DNNL_ARG_DST] = z_mkl_mem;
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

     // run calculations
     dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);

     // reorder outputs if necessary
-    if (zReorder)
-        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

     stream.wait();
@@ -180,16 +170,13 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
     dnnl::memory::dims dilation = { dH-1, dW-1 };

-    uint i0, i1, i2, i3;
-    if(0 == wFormat) {
-        i0 = 2; i1 = 3; i2 = 0; i3 = 1;     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-    }
-    else if(1 == wFormat) {
-        i0 = 1; i1 = 0; i2 = 2; i3 = 3;     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
-    }
-    else {
-        i0 = 3; i1 = 0; i2 = 1; i3 = 2;     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
-    }
+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {2,3,0,1};     // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+    else if(1 == wFormat)
+        permut = {1,0,2,3};     // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
+    else
+        permut = {3,0,1,2};     // [iC, kH, kW, oC] -> [oC, iC, kH, kW]

     // input type
     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
@@ -216,35 +203,27 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(input, x_user_md);
+    mkldnnUtils::setBlockStrides(*input, x_user_md);

     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
-    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
-    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
-    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
-    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

     // gradO
     dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
+    mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);

     // gradI
     dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
+    mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);

     // gradW
     dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
-    gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
-    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
-    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
-    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
+    mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);

     // gradB
     dnnl::memory::desc gradB_mkl_md;
@@ -273,10 +252,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // weights
-    mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+    mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@@ -291,16 +270,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
-    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
-    const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
-    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
-    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+    auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

     // gradW
-    auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
-    const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
-    auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
-    args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
+    auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);

     // gradB
     if(gradB != nullptr) {
@@ -318,10 +291,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

     // reorder gradI if necessary
-    if (gradIReorder)
-        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
-    if (gradWReorder)
-        dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
+    if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
+    if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);

     stream.wait();
@@ -31,7 +31,7 @@ namespace ops {
 namespace platforms {

 //////////////////////////////////////////////////////////////////////////
-static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
+static void deconv2TFdBpMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
                                const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
                                const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                                const bool isNCHW, const int wFormat) {
@@ -67,21 +67,17 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
     // weights
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
-    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);    // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
-    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
-    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+    mkldnnUtils::setBlockStrides(*weights, w_user_md, {3,2,0,1});    // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]

     // gradO
     dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
+    mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);

     // gradI
     dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
-    mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
+    mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -101,23 +97,20 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
|
||||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
|
||||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
|
||||||
|
|
||||||
// run backward data calculations
|
// run backward data calculations
|
||||||
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder gradI if necessary
|
// reorder gradI if necessary
|
||||||
if (gradIReorder)
|
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
|
|
||||||
|
@@ -189,7 +182,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
||||||
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
// }
|
// }
|
||||||
|
|
||||||
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
|
deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
|
||||||
|
|
||||||
// delete weights;
|
// delete weights;
|
||||||
|
|
||||||
|
|
|
@@ -48,16 +48,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
     dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
     dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
 
-    uint i0, i1, i2, i3, i4;
-    if(0 == wFormat) {
-        i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2;     // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-    }
-    else if(1 == wFormat) {
-        i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4;     // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
-    }
-    else {
-        i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3;     // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
-    }
+    std::vector<int> permut;
+    if(0 == wFormat)
+        permut = {3,4,0,1,2};                       // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+    else if(1 == wFormat)
+        permut = {1,0,2,3,4};                       // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
+    else
+        permut = {4,0,1,2,3};                       // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
 
     // input type
     dnnl::memory::data_type xType;
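Note (illustration, not from this commit): the new permut vector carries the same weight permutation the removed i0..i4 indices encoded; the setBlockStrides overload added later in this change consumes it as strides[i] = array.strideAt(permut[i]). A minimal stand-alone C++ sketch of that mapping, with made-up stride values:

// Hypothetical sketch: how permut = {3,4,0,1,2} re-reads the strides of a C-ordered
// [kD, kH, kW, oC, iC] weight array in [oC, iC, kD, kH, kW] order (illustrative numbers only).
#include <cstdio>
#include <vector>

int main() {
    std::vector<long> srcStride = {288, 96, 32, 4, 1};   // strides of shape [2, 3, 3, 8, 4], 'c' order
    std::vector<int>  permut    = {3, 4, 0, 1, 2};       // [kD,kH,kW,oC,iC] -> [oC,iC,kD,kH,kW]

    long mklStride[5];
    for (int i = 0; i < 5; ++i)
        mklStride[i] = srcStride[permut[i]];             // what setBlockStrides does per dimension

    for (int i = 0; i < 5; ++i)
        std::printf("mkl dim %d -> stride %ld\n", i, mklStride[i]);
    return 0;
}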
@@ -100,17 +97,12 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
dnnl::memory::desc b_mkl_md;
|
dnnl::memory::desc b_mkl_md;
|
||||||
|
@@ -120,7 +112,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@@ -137,10 +129,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@@ -149,17 +141,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
}
|
}
|
||||||
|
|
||||||
// output
|
// output
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
|
||||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
|
||||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
|
||||||
|
|
||||||
// run calculations
|
// run calculations
|
||||||
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder outputs if necessary
|
// reorder outputs if necessary
|
||||||
if (zReorder)
|
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
|
|
||||||
|
@@ -185,16 +174,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
||||||
|
|
||||||
uint i0, i1, i2, i3, i4;
|
std::vector<int> permut;
|
||||||
if(0 == wFormat) {
|
if(0 == wFormat)
|
||||||
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||||
}
|
else if(1 == wFormat)
|
||||||
else if(1 == wFormat) {
|
permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
||||||
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
else
|
||||||
}
|
permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
||||||
else {
|
|
||||||
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
|
||||||
}
|
|
||||||
|
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
@@ -221,37 +207,27 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
|
||||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
|
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
|
||||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
|
||||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
|
|
||||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
|
|
||||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
|
|
||||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
|
|
||||||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
|
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
dnnl::memory::desc gradB_mkl_md;
|
dnnl::memory::desc gradB_mkl_md;
|
||||||
|
@@ -281,10 +257,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||||
|
@@ -299,16 +275,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
|
||||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
|
||||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
|
||||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
|
||||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
if(gradB != nullptr) {
|
if(gradB != nullptr) {
|
||||||
|
@@ -326,10 +296,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder gradI if necessary
|
// reorder gradI if necessary
|
||||||
if (gradIReorder)
|
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||||
if (gradWReorder)
|
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
|
|
||||||
|
|
|
@@ -28,7 +28,7 @@
|
||||||
|
|
||||||
using namespace dnnl;
|
using namespace dnnl;
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
|
|
||||||
|
@@ -109,7 +109,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@@ -129,7 +129,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// output
|
// output
|
||||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
@@ -146,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||||
|
|
||||||
// bias
|
// bias
|
||||||
if(bias != nullptr) {
|
if(bias != nullptr) {
|
||||||
|
@@ -158,24 +158,21 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
}
|
}
|
||||||
|
|
||||||
// output
|
// output
|
||||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
|
||||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
|
||||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
|
||||||
|
|
||||||
// run calculations
|
// run calculations
|
||||||
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder outputs if necessary
|
// reorder outputs if necessary
|
||||||
if (zReorder)
|
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
static void depthwiseConv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
const int paddingMode, const bool isNCHW, const int wFormat) {
|
const int paddingMode, const bool isNCHW, const int wFormat) {
|
||||||
|
|
||||||
|
@@ -235,7 +232,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// input
|
// input
|
||||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||||
|
@@ -250,12 +247,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// gradO
|
// gradO
|
||||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
|
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
|
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
|
||||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||||
|
@@ -294,10 +291,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
// provide memory buffers and check whether reorder is required
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
// input
|
// input
|
||||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||||
|
|
||||||
// weights
|
// weights
|
||||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||||
|
|
||||||
// gradO
|
// gradO
|
||||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||||
|
@@ -312,16 +309,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||||
|
|
||||||
// gradI
|
// gradI
|
||||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
|
||||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
|
||||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
|
||||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
|
||||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
if(gradB != nullptr) {
|
if(gradB != nullptr) {
|
||||||
|
@@ -339,10 +330,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
||||||
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder gradI if necessary
|
// reorder gradI if necessary
|
||||||
if (gradIReorder)
|
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||||
if (gradWReorder)
|
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
|
|
||||||
|
@@ -458,7 +449,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
|
depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -169,71 +169,43 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
|
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
|
||||||
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
|
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
|
||||||
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
|
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
|
||||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
|
||||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
|
||||||
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
|
||||||
|
|
||||||
// wx
|
// wx
|
||||||
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
|
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||||
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||||
wx_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*Wx, wx_user_md);
|
||||||
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
|
|
||||||
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
|
|
||||||
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
|
|
||||||
wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
|
|
||||||
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
|
|
||||||
|
|
||||||
// wr
|
// wr
|
||||||
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
|
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||||
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||||
wr_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*Wr, wr_user_md);
|
||||||
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
|
|
||||||
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
|
|
||||||
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
|
|
||||||
wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
|
|
||||||
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
|
|
||||||
|
|
||||||
// h
|
// h
|
||||||
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
|
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
|
||||||
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
|
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
|
||||||
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
|
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
|
||||||
h_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*h, h_user_md);
|
||||||
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
|
|
||||||
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
|
|
||||||
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
|
|
||||||
|
|
||||||
// b
|
// b
|
||||||
if(b) {
|
if(b) {
|
||||||
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
|
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
|
||||||
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
|
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
|
||||||
b_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*b, b_user_md);
|
||||||
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
|
|
||||||
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
|
|
||||||
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
|
|
||||||
b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hI
|
// hI
|
||||||
if(hI) {
|
if(hI) {
|
||||||
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||||
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||||
hI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*hI, hI_user_md);
|
||||||
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
|
|
||||||
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
|
|
||||||
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
|
|
||||||
hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cI
|
// cI
|
||||||
if(cI) {
|
if(cI) {
|
||||||
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||||
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||||
cI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*cI, cI_user_md);
|
||||||
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
|
|
||||||
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
|
|
||||||
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
|
|
||||||
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hL
|
// hL
|
||||||
|
@@ -241,20 +213,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
|
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
|
||||||
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||||
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
|
mkldnnUtils::setBlockStrides(*hL, hL_user_md);
|
||||||
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
|
|
||||||
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
|
|
||||||
hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(cL) {
|
if(cL) {
|
||||||
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||||
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||||
cL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
mkldnnUtils::setBlockStrides(*cL, cL_user_md);
|
||||||
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
|
|
||||||
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
|
|
||||||
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
|
|
||||||
cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// lstm memory description
|
// lstm memory description
|
||||||
|
@@ -272,64 +237,49 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
|
|
||||||
// provide memory and check whether reorder is required
|
// provide memory and check whether reorder is required
|
||||||
// x
|
// x
|
||||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
|
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
|
||||||
|
|
||||||
// wx
|
// wx
|
||||||
mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
|
mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
|
||||||
|
|
||||||
// wr
|
// wr
|
||||||
mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
|
mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
|
||||||
|
|
||||||
// h
|
// h
|
||||||
auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer());
|
auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER]);
|
||||||
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
|
|
||||||
auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
|
|
||||||
args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
|
|
||||||
|
|
||||||
// b
|
// b
|
||||||
if(b) {
|
if(b)
|
||||||
mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
|
mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
|
||||||
}
|
|
||||||
|
|
||||||
// hI
|
// hI
|
||||||
if(hI) {
|
if(hI)
|
||||||
mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
|
mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
|
||||||
}
|
|
||||||
|
|
||||||
// cI
|
// cI
|
||||||
if(cI) {
|
if(cI)
|
||||||
mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
|
mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
|
||||||
}
|
|
||||||
|
|
||||||
bool hLReorder(false), cLReorder(false);
|
|
||||||
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
|
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
|
||||||
|
|
||||||
// hL
|
// hL
|
||||||
if(hL) {
|
if(hL)
|
||||||
hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer());
|
hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), args[DNNL_ARG_DST_ITER]);
|
||||||
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
|
|
||||||
hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
|
|
||||||
args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
|
|
||||||
}
|
|
||||||
|
|
||||||
// cL
|
// cL
|
||||||
if(cL) {
|
if(cL)
|
||||||
cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer());
|
cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), args[DNNL_ARG_DST_ITER_C]);
|
||||||
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
|
|
||||||
cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
|
|
||||||
args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
|
|
||||||
}
|
|
||||||
|
|
||||||
// run calculations
|
// run calculations
|
||||||
lstm_forward(lstm_prim_desc).execute(stream, args);
|
lstm_forward(lstm_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
// reorder outputs if necessary
|
// reorder outputs if necessary
|
||||||
if (hReorder)
|
if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc())
|
||||||
reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
|
reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem);
|
||||||
if(hLReorder)
|
if(lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc())
|
||||||
reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
|
reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem);
|
||||||
if(cLReorder)
|
if(lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc())
|
||||||
reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);
|
reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem);
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
}
|
}
|
||||||
|
@@ -377,9 +327,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
     auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr;    // cell state at last step
 
     // evaluate dimensions
-    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
-    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong sL   = x->sizeAt(dataFormat);
+    const Nd4jLong bS   = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0);
+    const Nd4jLong nIn  = x->sizeAt(2);
     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
 
     // inputs validations
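Note (illustration, not from this commit): the simplified reads assume dataFormat 0 (x laid out [sL, bS, nIn], i.e. TNC) or dataFormat 1 ([bS, sL, nIn], NTC), which is the same distinction the commented-out descriptor code above makes. A tiny stand-alone check of the new formulas:

// Hypothetical sketch: evaluate sL/bS/nIn the way the new code does, for both layouts.
#include <cassert>
#include <vector>

int main() {
    const std::vector<long> shapeTNC = {5, 2, 3};        // dataFormat == 0: [sL, bS, nIn]
    const std::vector<long> shapeNTC = {2, 5, 3};        // dataFormat == 1: [bS, sL, nIn]

    for (int dataFormat = 0; dataFormat <= 1; ++dataFormat) {
        const std::vector<long>& x = dataFormat == 0 ? shapeTNC : shapeNTC;
        const long sL  = x[dataFormat];                  // x->sizeAt(dataFormat)
        const long bS  = dataFormat == 0 ? x[1] : x[0];  // x->sizeAt(1) : x->sizeAt(0)
        const long nIn = x[2];                           // x->sizeAt(2)
        assert(sL == 5 && bS == 2 && nIn == 3);
    }
    return 0;
}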
@@ -435,14 +385,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
 
         WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut}));
         WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut}));
 
         if(b)
             bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
+        else
+            bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext());    // already nullified
 
         if(hI)
             hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut}));
 
         if(cI)
             cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
 
         if(hL)
             hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
 
         if(cL)
             cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));
 
@@ -31,20 +31,6 @@ namespace sd {
 namespace ops {
 namespace platforms {
 
-dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) {
-    switch (array.rankOf()) {
-        case 1:
-            return dnnl::memory::format_tag::ab;
-        case 2:
-            return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
-        case 3:
-            return array.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
-        default:
-            throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays");
-    }
-}
-
-
 //////////////////////////////////////////////////////////////////////////
 static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
 
@@ -123,11 +109,16 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
     else if(z->dataType() == DataType::INT8)
         zType = dnnl::memory::data_type::s8;
 
+
+    const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR);
+    const auto yFormat = yRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*yTR);
+    const auto zFormat = zRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*zR);
+
     // memory descriptors for arrays
+    dnnl::memory::desc x_mkl_md, x_user_md, y_mkl_md, y_user_md, z_mkl_md, z_user_md;
 
     // x
-    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
-    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
+    x_user_md = x_mkl_md = dnnl::memory::desc(xShape, xType, xFormat);
     if(xTR->ews() != 1) {
         x_user_md.data.format_kind = dnnl_blocked; // overrides format
         x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0);
@@ -137,8 +128,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
     }
 
     // y
-    dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
-    dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
+    y_user_md = y_mkl_md = dnnl::memory::desc(yShape, yType, yFormat);
     if(yTR->ews() != 1) {
         y_user_md.data.format_kind = dnnl_blocked; // overrides format
         y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0);
@@ -148,8 +138,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
     }
 
     // z
-    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
-    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
+    z_user_md = z_mkl_md = dnnl::memory::desc(zShape, zType, zFormat);
     if(zR->ews() != 1) {
         z_user_md.data.format_kind = dnnl_blocked; // overrides format
         z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0);
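Note (illustration, not from this commit): for x, y and z the new code builds one descriptor from mkldnnUtils::getFormat(...) (or format_tag::ab for rank-1 inputs), assigns it to both the mkl and user descriptors, and only spells out blocked strides when the array is not contiguous. A condensed stand-alone sketch of that flow against the oneDNN 1.x API, with illustrative shapes and strides:

// Assumes oneDNN 1.x headers; shape and stride values are made up for illustration.
#include <dnnl.hpp>

int main() {
    const dnnl::memory::dims shape = {4, 3};                       // M x N
    const auto format = dnnl::memory::format_tag::ab;              // what getFormat would return for rank 2, 'c' order

    dnnl::memory::desc x_mkl_md, x_user_md;
    x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, format);

    const bool nonContiguous = true;                               // stands in for xTR->ews() != 1
    if (nonContiguous) {
        x_user_md.data.format_kind = dnnl_blocked;                 // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = 6;        // xTR->strideAt(0), illustrative
        x_user_md.data.format_desc.blocking.strides[1] = 2;        // xTR->strideAt(1), illustrative
    }
    return 0;
}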
@@ -181,37 +170,20 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
     // provide memory buffers and check whether reorder is required
 
     // input
-    mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
-    /*
-    auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer());
-    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
-    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
-    if (xReorder)
-        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
-    args[DNNL_ARG_SRC] = x_mkl_mem;
-    */
+    mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
 
     // y
-    mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
-    /*
-    auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer());
-    const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
-    auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
-    if (yReorder)
-        dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
-    args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
-    */
+    mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
 
     // z
-    auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer());
-    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-    args[DNNL_ARG_DST] = z_mkl_mem;
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
 
     // run calculations
     dnnl::matmul(op_prim_desc).execute(stream, args);
 
     // reorder outputs if necessary
-    if (zReorder)
-        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
 
     stream.wait();
 
@@ -38,45 +38,65 @@ void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){
     mklDims = dnnl::memory::dims(vDims);
 }
 //////////////////////////////////////////////////////////////////////
-dnnl::memory::format_tag getFormat(const int rank){
-    if (2 == rank) {
-        return dnnl::memory::format_tag::ab;
-    }
-    else if (3 == rank) {
-        return dnnl::memory::format_tag::abc;
-    }
-    else if (4 == rank) {
-        return dnnl::memory::format_tag::abcd;
-    }
-    else if (5 == rank) {
-        return dnnl::memory::format_tag::abcde;
-    }
-    else if (6 == rank) {
-        return dnnl::memory::format_tag::abcdef;
-    }
-    return dnnl::memory::format_tag::a; // 1 == dataSetRank
+dnnl::memory::format_tag getFormat(const NDArray& arr) {
+
+    dnnl::memory::format_tag result;
+
+    switch (arr.rankOf()) {
+        case 1:
+            result = dnnl::memory::format_tag::a;
+            break;
+        case 2:
+            result = arr.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
+            break;
+        case 3:
+            result = arr.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
+            break;
+        case 4:
+            result = dnnl::memory::format_tag::abcd;
+            break;
+        case 5:
+            result = dnnl::memory::format_tag::abcde;
+            break;
+        case 6:
+            result = dnnl::memory::format_tag::abcdef;
+            break;
+        default:
+            throw std::invalid_argument("MKLDNN getFormat: do we really want to use arras with rank > 6 ?");
+    }
+
+    return result;
 }
 
 //////////////////////////////////////////////////////////////////////
-void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){
+void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut) {
 
-    if (array->ews() != 1 || array->ordering() != 'c') {
+    if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) {
         mklMd.data.format_kind = dnnl_blocked; // overrides format
-        for (auto i = 0; i < array->rankOf(); ++i) {
-            mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
+        if(permut.empty())
+            for (auto i = 0; i < array.rankOf(); ++i)
+                mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i);
+        else {
+            if(array.rankOf() != permut.size())
+                throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !");
+            for (auto i = 0; i < array.rankOf(); ++i)
+                mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]);
         }
     }
 }
////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
|
dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream,
|
||||||
dnnl::memory& arg) {
|
const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg) {
|
||||||
|
|
||||||
auto user_mem = dnnl::memory(user_md, engine,const_cast<void*>(array->buffer()));
|
auto user_mem = dnnl::memory(user_md, engine, const_cast<NDArray&>(array).buffer());
|
||||||
const bool bReorder = primitive_md != user_mem.get_desc();
|
const bool bReorder = primitive_md != user_mem.get_desc();
|
||||||
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
|
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
|
||||||
if (bReorder)
|
if (bReorder)
|
||||||
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
|
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
|
||||||
arg = mkl_mem;
|
arg = mkl_mem;
|
||||||
|
return user_mem;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
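The rewritten helper above now returns the user-side dnnl::memory, which is why the call sites later in this diff drop their hand-rolled zReorder/z_mkl_mem bookkeeping and simply compare descriptors again after execution. A minimal, self-contained sketch of that load / execute / reorder-back pattern on a tiny softmax, assuming oneDNN 1.x/2.x (the desc-based API used throughout this diff) is installed; the shapes and values are made up for illustration and are not part of this codebase:

    #include <dnnl.hpp>
    #include <iostream>
    #include <unordered_map>
    #include <vector>

    int main() {
        dnnl::engine engine(dnnl::engine::kind::cpu, 0);
        dnnl::stream stream(engine);

        std::vector<float> src{1.f, 2.f, 3.f, 1.f, 2.f, 3.f};
        std::vector<float> dst(6, 0.f);

        // user-side descriptor: plain row-major 2x3, f32
        dnnl::memory::desc user_md({2, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

        dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, user_md, 1 /*axis*/);
        dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, engine);

        std::unordered_map<int, dnnl::memory> args;

        // input: wrap the user buffer, reorder into the primitive's layout only if it differs
        auto src_user_mem = dnnl::memory(user_md, engine, src.data());
        if (op_prim_desc.src_desc() != src_user_mem.get_desc()) {
            auto src_mkl_mem = dnnl::memory(op_prim_desc.src_desc(), engine);
            dnnl::reorder(src_user_mem, src_mkl_mem).execute(stream, src_user_mem, src_mkl_mem);
            args[DNNL_ARG_SRC] = src_mkl_mem;
        } else {
            args[DNNL_ARG_SRC] = src_user_mem;
        }

        // output: keep the user memory around so the result can be copied back
        auto dst_user_mem = dnnl::memory(user_md, engine, dst.data());
        const bool dstReorder = op_prim_desc.dst_desc() != dst_user_mem.get_desc();
        args[DNNL_ARG_DST] = dstReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : dst_user_mem;

        dnnl::softmax_forward(op_prim_desc).execute(stream, args);

        // reorder outputs if necessary -- the same check the call sites in this diff now make
        if (dstReorder)
            dnnl::reorder(args[DNNL_ARG_DST], dst_user_mem).execute(stream, args[DNNL_ARG_DST], dst_user_mem);

        stream.wait();

        for (float v : dst) std::cout << v << ' ';
        std::cout << '\n';
    }

The same shape repeats in every platform implementation touched below: load the inputs, execute, and reorder DNNL_ARG_DST back only when a temporary was actually allocated.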
@@ -122,33 +142,21 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
         xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
     }

+    std::vector<int> permut;
+    if(!isNCHW)
+        permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+
     // memory descriptors for arrays

     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
-    if(input->ews() != 1 || input->ordering() != 'c') {
-        x_user_md.data.format_kind = dnnl_blocked; // overrides format
-        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
-        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
-        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
-        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
-        if(rank == 5)
-            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
-    }
+    mkldnnUtils::setBlockStrides(*input, x_user_md, permut);

     // output
     dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
-    if(output->ews() != 1 || output->ordering() != 'c') {
-        z_user_md.data.format_kind = dnnl_blocked; // overrides format
-        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
-        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1);
-        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
-        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
-        if(rank == 5)
-            z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
-    }
+    mkldnnUtils::setBlockStrides(*output, z_user_md, permut);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
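The new permut vector lets setBlockStrides describe an NHWC or NDHWC tensor through an NCHW/NCDHW-ordered descriptor purely via strides, without moving any data. A standalone sketch of that stride mapping, independent of oneDNN; the shape [N=2, H=4, W=5, C=3] is made up for illustration:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        // Contiguous NHWC tensor of shape [2, 4, 5, 3]: strides are {H*W*C, W*C, C, 1}.
        std::vector<int64_t> nhwcStrides{4 * 5 * 3, 5 * 3, 3, 1};  // {60, 15, 3, 1}

        // The pooling code above uses permut = {0,3,1,2} for rank-4 NHWC input:
        // descriptor dimension i takes the stride of source dimension permut[i].
        std::vector<int> permut{0, 3, 1, 2};

        std::vector<int64_t> nchwViewStrides(4);
        for (size_t i = 0; i < permut.size(); ++i)
            nchwViewStrides[i] = nhwcStrides[permut[i]];

        // Prints 60 1 15 3: an NCHW-shaped descriptor whose channel dimension is the
        // fastest-varying one, matching the NHWC layout in memory.
        for (int64_t s : nchwViewStrides) std::cout << s << ' ';
        std::cout << '\n';
    }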
@@ -164,20 +172,17 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
     // provide memory buffers and check whether reorder is required

     // input
-    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

     // output
-    auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
-    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-    args[DNNL_ARG_DST] = z_mkl_mem;
+    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

     // run calculations
     dnnl::pooling_forward(op_prim_desc).execute(stream, args);

     // reorder outputs if necessary
-    if (zReorder)
-        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

     stream.wait();
 }
@@ -226,46 +231,27 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
         xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
     }

+    std::vector<int> permut;
+    if(!isNCHW)
+        permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+
     // memory descriptors for arrays

     // input
     dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
-    if(input->ews() != 1 || input->ordering() != 'c') {
-        x_user_md.data.format_kind = dnnl_blocked; // overrides format
-        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
-        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
-        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
-        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
-        if(rank == 5)
-            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
-    }
+    mkldnnUtils::setBlockStrides(*input, x_user_md, permut);

     // gradO
     dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
-    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
-        gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
-        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
-        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1);
-        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
-        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
-        if(rank == 5)
-            gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
-    }
+    mkldnnUtils::setBlockStrides(*gradO, gradO_user_md, permut);

     // gradI
     dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
-    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
-        gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
-        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
-        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1);
-        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
-        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
-        if(rank == 5)
-            gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
-    }
+    mkldnnUtils::setBlockStrides(*gradI, gradI_user_md, permut);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
     dnnl::stream stream(engine);
@@ -282,18 +268,15 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
     std::unordered_map<int, dnnl::memory> args;

     // gradO
-    mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
+    mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

     // gradI
-    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
-    const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
-    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
-    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+    auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

     if(mode == algorithm::pooling_max) {

         // input
-        mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+        mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

         // z
         auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
@@ -310,10 +293,9 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
     // run backward calculations
     dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);

     // reorder gradI if necessary
-    if (gradIReorder)
-        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+    if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
+        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);

     stream.wait();
 }
@@ -100,6 +100,8 @@ namespace sd {

            DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);

+           DECLARE_PLATFORM(concat, ENGINE_CPU);
+
        }
    }

@@ -123,19 +125,13 @@ namespace sd {
     */
    void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
    /**
-    * This function generate memory format tag based on rank
-    * @param const array rank
+    * This function evaluate memory format tag based on array shapeInfo
+    * @param const array
     * @return memory format
     */
-   dnnl::memory::format_tag getFormat(const int rank);
-   /**
-    * This function generate memory format tag based on rank
-    * @param const pointer to dataset
-    * @param const dataset rank
-    * @param reference to memory descriptor
-    * @return memory format
-    */
-   void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd);
+   dnnl::memory::format_tag getFormat(const NDArray& arr);
+
+   void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut = {});
   //////////////////////////////////////////////////////////////////////
   /**
    * This function load and reorder user memory to mkl
@@ -147,7 +143,7 @@ namespace sd {
    * @param primitive memory descriptor
    * @param dnnl arg activation enumerator
    */
-   void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
+   dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
                             dnnl::memory& arg);

   /**
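Per the updated doc comment, the format tag is now derived from the array itself rather than from its rank alone; in the implementation earlier in this diff, 'f'-ordered rank-2 and rank-3 arrays map to the reversed tags (ba, cba). A tiny illustrative sketch of that mapping, detached from the NDArray class and using strings in place of dnnl::memory::format_tag:

    #include <iostream>
    #include <string>

    // Mirrors the rank/ordering -> format-tag choice made by getFormat() above.
    std::string formatFor(int rank, char ordering) {
        switch (rank) {
            case 1: return "a";
            case 2: return ordering == 'c' ? "ab" : "ba";
            case 3: return ordering == 'c' ? "abc" : "cba";
            case 4: return "abcd";
            case 5: return "abcde";
            case 6: return "abcdef";
            default: return "unsupported (rank > 6)";
        }
    }

    int main() {
        std::cout << formatFor(2, 'f') << '\n';  // ba
        std::cout << formatFor(3, 'c') << '\n';  // abc
    }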
@@ -35,32 +35,37 @@ namespace sd {
    //////////////////////////////////////////////////////////////////////
    static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {

-       const auto xRank = x->rankOf();
-       dnnl::memory::dims xShape, zShape;
-
-       mkldnnUtils::getDims(x, xRank, xShape);
-       mkldnnUtils::getDims(z, xRank, zShape);
-
-       dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
+       dnnl::memory::dims shape = x->getShapeAsFlatVector();
+
+       const int xRank = x->rankOf();
+
+       dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x);
+       dnnl::memory::format_tag zFormat = mkldnnUtils::getFormat(*z);
+
        // optimized cases
        if (2 == xRank && 0 == axis) {
-           format = dnnl::memory::format_tag::ba;
+           if(x->ews() == 1)
+               xFormat = dnnl::memory::format_tag::ba;
+           if(z->ews() == 1)
+               zFormat = dnnl::memory::format_tag::ba;
        }
        else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
-           format = dnnl::memory::format_tag::acdb;
+           if(x->ews() == 1)
+               xFormat = dnnl::memory::format_tag::acdb;
+           if(z->ews() == 1)
+               zFormat = dnnl::memory::format_tag::acdb;
        }

        dnnl::memory::data_type xType = dnnl::memory::data_type::f32;

-       dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;
+
+       x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat);
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // z
-       dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
-       dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
-       mkldnnUtils::setBlockStrides(z, z_user_md);
+       z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat);
+       mkldnnUtils::setBlockStrides(*z, z_user_md);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -80,20 +85,17 @@ namespace sd {
        // provide memory buffers and check whether reorder is required

        // input
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

        // z
-       auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
-       const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-       auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-       args[DNNL_ARG_DST] = z_mkl_mem;
+       auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

        // run calculations
        dnnl::softmax_forward(op_prim_desc).execute(stream, args);

        // reorder outputs if necessary
-       if (zReorder)
-           dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+       if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+           dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

        stream.wait();
    }
@@ -142,33 +144,19 @@ namespace sd {
    //////////////////////////////////////////////////////////////////////
    static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {

-       const auto xRank = x->rankOf();
-       const auto dLdzRank = dLdz->rankOf();
-
-       dnnl::memory::dims xShape, dLdxShape, dLdzShape;
-
-       mkldnnUtils::getDims(x, xRank, xShape);
-       mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
-       mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);
-
-       dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
+       dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;

        // x
-       dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       x_mkl_md = x_user_md = dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // dLdx
-       dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
-       // todo if mkl does not support broadcast we can remove this
-       format = mkldnnUtils::getFormat(dLdzRank);
+       dLdx_mkl_md = dLdx_user_md = dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
+       mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);

        // dLdz
-       dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
+       dLdz_mkl_md = dLdz_user_md = dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
+       mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -188,19 +176,18 @@ namespace sd {

        // provide memory buffers and check whether reorder is required for forward
        // input
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);

+       // dLdz
+       mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
+
        // dLdx
-       auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
-       const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
-       auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
-       argsff[DNNL_ARG_DST] = dLdx_mkl_mem;
+       auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_DST]);

        // check and arg set for backprob
-       argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
-       argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
-       // dLdz
-       mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
+       argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST];
+       argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST];

        // run calculations forward
        dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);
@@ -209,8 +196,8 @@ namespace sd {
        dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);

        // reorder outputs if necessary
-       if (dLdxReorder)
-           dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
+       if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc())
+           dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem).execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem);

        stream.wait();
    }
@@ -34,22 +34,16 @@ namespace sd {
    //////////////////////////////////////////////////////////////////////
    static void tanhMKLDNN(const NDArray* x, NDArray* z) {

-       const auto xRank = x->rankOf();
-       dnnl::memory::dims xShape, zShape;
-
-       mkldnnUtils::getDims(x, xRank, xShape);
-       mkldnnUtils::getDims(z, xRank, zShape);
-
-       dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
-
-       dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       dnnl::memory::dims shape = x->getShapeAsFlatVector();
+
+       dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;
+
+       x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // z
-       dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(z, z_user_md);
+       z_user_md = z_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z));
+       mkldnnUtils::setBlockStrides(*z, z_user_md);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -68,20 +62,17 @@ namespace sd {

        // provide memory buffers and check whether reorder is required
        // input
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

        // z
-       auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
-       const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-       auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-       args[DNNL_ARG_DST] = z_mkl_mem;
+       auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

        // run calculations
        dnnl::eltwise_forward(op_prim_desc).execute(stream, args);

        // reorder outputs if necessary
-       if (zReorder)
-           dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+       if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+           dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

        stream.wait();
    }
@@ -121,28 +112,21 @@ namespace sd {
    //////////////////////////////////////////////////////////////////////
    static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) {

-       const auto xRank = x->rankOf();
-       dnnl::memory::dims xShape, dLdzShape, dLdxShape;
-
-       mkldnnUtils::getDims(x, xRank, xShape);
-       mkldnnUtils::getDims(dLdz, xRank, dLdzShape);
-       mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
-
-       dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
-
-       dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       dnnl::memory::dims shape = x->getShapeAsFlatVector();
+
+       dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;
+
+       // x
+       x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // dLdz
-       dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
+       dLdz_user_md = dLdz_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
+       mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);

        // dLdx
-       dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
-       mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
+       dLdx_user_md = dLdx_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
+       mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@@ -162,23 +146,20 @@ namespace sd {

        // provide memory buffers and check whether reorder is required for forward
        // input
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

        // dLdz
-       mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
+       mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

        // dLdx
-       auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
-       const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
-       auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
-       args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
+       auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

        // run calculations backward
        dnnl::eltwise_backward(op_prim_desc).execute(stream, args);

        // reorder outputs if necessary
-       if (dLdxReorder)
-           dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
+       if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
+           dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem);

        stream.wait();
    }
@@ -82,33 +82,23 @@ namespace sd {
        // memory descriptors for arrays
        // x
        dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x));
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // weights
        dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format);
-       if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
-           weights_user_md.data.format_kind = dnnl_blocked; // overrides format
-           if (bShouldTransp) {
-               weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
-               weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
-           }
-           else {
-               weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
-               weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
-           }
-       }
+       dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights));
+       mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());

        // bias
-       dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
-       dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
-       mkldnnUtils::setBlockStrides(bias, bias_user_md);
+       dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
+       dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
+       mkldnnUtils::setBlockStrides(*bias, bias_user_md);

        // z
        dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
-       mkldnnUtils::setBlockStrides(z, z_user_md);
+       dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z));
+       mkldnnUtils::setBlockStrides(*z, z_user_md);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
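For the weights, bShouldTransp is handled by passing permut = {1,0}: the descriptor reads the strides in swapped order and sees the same buffer as its transpose, with no copy. A standalone sketch of that transpose-by-view idea with made-up values; no oneDNN types are involved:

    #include <cstdio>
    #include <vector>

    int main() {
        // Row-major weights of shape [N=2, K=3]; strides {K, 1} = {3, 1}.
        std::vector<float> w{1, 2, 3,
                             4, 5, 6};
        const int N = 2, K = 3;
        const long strides[2] = {K, 1};

        // permut = {1,0}: the descriptor sees the same buffer as a [K,N] matrix
        // with strides {1, K}, i.e. a transpose by view.
        const long tStrides[2] = {strides[1], strides[0]};

        for (int k = 0; k < K; ++k) {        // rows of the transposed view
            for (int n = 0; n < N; ++n)      // cols of the transposed view
                std::printf("%g ", w[k * tStrides[0] + n * tStrides[1]]);
            std::printf("\n");
        }
        // prints:
        // 1 4
        // 2 5
        // 3 6
    }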
@@ -125,27 +115,24 @@ namespace sd {
        // provide memory buffers and check whether reorder is required

        // input
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

        // weights
-       mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
+       mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);

        // bias
        auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast<void*>(bias->buffer()));
        args[DNNL_ARG_BIAS] = bias_mkl_mem;

        // z
-       auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
-       const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
-       auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
-       args[DNNL_ARG_DST] = z_mkl_mem;
+       auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

        // run calculations
        dnnl::inner_product_forward(op_prim_desc).execute(stream, args);

        // reorder outputs if necessary
-       if (zReorder)
-           dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+       if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
+           dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

        stream.wait();
    }
@@ -160,7 +147,7 @@ namespace sd {

        // [M,K] x [K,N] = [M,N]
        const int M = x->sizeAt(0);
        const int K = x->sizeAt(1); // K == wK
        const int N = dLdz->sizeAt(1);
        // input dims
        dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
@@ -168,71 +155,53 @@ namespace sd {
        dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N });

        dnnl::memory::dims bShape = dnnl::memory::dims({ N });

        // output dims
        dnnl::memory::dims dLdxShape = xShape;
        dnnl::memory::dims dLdwShape = wShape;

-       dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
        dnnl::memory::data_type dataType = dnnl::memory::data_type::f32;

        // memory descriptors for arrays
        // x
        dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format);
-       mkldnnUtils::setBlockStrides(x, x_user_md);
+       dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x));
+       mkldnnUtils::setBlockStrides(*x, x_user_md);

        // weights
        dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format);
-       if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
-           weights_user_md.data.format_kind = dnnl_blocked; // overrides format
-           if (bShouldTransp) {
-               weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
-               weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
-           }
-           else {
-               weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
-               weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
-           }
-       }
+       dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights));
+       mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());

        // bias
-       dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
-       dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
-       mkldnnUtils::setBlockStrides(bias, bias_user_md);
+       dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
+       dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias));
+       mkldnnUtils::setBlockStrides(*bias, bias_user_md);

        // dLdz
        dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format);
-       mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
+       dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz));
+       mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);

        // dLdw
-       dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format);
-       dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format);
-       if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) {
-           dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format
-           if (bShouldTransp) {
-               dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1);
-               dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0);
-           }
-           else {
-               dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0);
-               dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1);
-           }
-       }
+       dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
+       dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw));
+       mkldnnUtils::setBlockStrides(*dLdw, dLdw_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());

        // dLdb
-       dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
-       dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
-       mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md);
+       dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
+       dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb));
+       mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md);

        // dLdx
        dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
-       dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format);
-       mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
+       dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx));
+       mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);

+       // create engine
        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

        // forward
        // operation primitive description
        dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md);
@@ -254,34 +223,25 @@ namespace sd {
        dnnl::stream stream(engine);

        // dLdz dw
-       mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
+       mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);

        // dLdz - dx
-       mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
+       mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);

        // input x for dw
-       mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
+       mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);

        // weights - dx
-       mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
+       mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);

        // dLdw
-       auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer());
-       const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc();
-       auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem;
-       argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem;
+       auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, op_bpdw_prim_desc.diff_weights_desc(), argsDw[DNNL_ARG_DIFF_WEIGHTS]);

        // dLdx
-       auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
-       const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
-       auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
-       argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
+       auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_bpdx_prim_desc.diff_src_desc(), argsDx[DNNL_ARG_DIFF_SRC]);

        // dLdb
-       auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer());
-       const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc();
-       auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem;
-       argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem;
+       auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, op_bpdw_prim_desc.diff_bias_desc(), argsDw[DNNL_ARG_DIFF_BIAS]);

        // run calculations dw
        dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw);
@@ -289,14 +249,14 @@ namespace sd {
        dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx);

        // reorder outputs if necessary
-       if (dLdxReorder)
-           dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
+       if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
+           dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem);

-       if (dLdwReorder)
-           dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem);
+       if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc())
+           dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem);

-       if (dLdbReorder)
-           dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem);
+       if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc())
+           dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem);

        stream.wait();
    }
@@ -315,7 +275,7 @@ namespace sd {
        const int wRank = w->rankOf();
        const int zRank = z->rankOf();

        const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]

        REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank);
        REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank);
@@ -378,7 +338,7 @@ namespace sd {
        const int wRank = w->rankOf();
        const int dLdzRank = dLdz->rankOf();

        const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]

        REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
        REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
@@ -107,6 +107,25 @@ namespace sd {
    //         samediff::Threads::parallel_tad(func, 0, numOfArrs);
    // }

+   // static Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
+
+   //     Nd4jLong result = 9223372036854775807LL;
+
+   //     for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
+
+   //         const auto currentStride = shape::stride(inShapeInfo)[i];
+
+   //         if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
+   //             continue;
+
+   //         if(result > currentStride)
+   //             result = currentStride;
+   //     }
+
+   //     return result == 9223372036854775807LL ? 1 : result;
+   // }
+
+
    template <typename T>
    void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
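The commented-out strideOverContigAxis added above picks the smallest stride among all dimensions other than the concatenation axis, skipping size-1 dimensions and falling back to 1 when nothing qualifies. A standalone sketch of the same computation on plain vectors, with made-up shape and stride values:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Smallest stride over all dims except `axis` and size-1 dims; 1 if none qualify.
    int64_t strideOverContigAxis(int axis,
                                 const std::vector<int64_t>& shape,
                                 const std::vector<int64_t>& strides) {
        int64_t result = INT64_MAX;
        for (size_t i = 0; i < shape.size(); ++i) {
            if (static_cast<int>(i) == axis || shape[i] == 1)
                continue;
            if (strides[i] < result)
                result = strides[i];
        }
        return result == INT64_MAX ? 1 : result;
    }

    int main() {
        // c-ordered array of shape {2, 5, 3} has strides {15, 3, 1}; concat axis = 1.
        std::cout << strideOverContigAxis(1, {2, 5, 3}, {15, 3, 1}) << '\n';  // prints 1
    }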
@@ -150,7 +169,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
    //             if(!areInputsContin || !allSameOrder)
    //                 break;

-   //             strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo());
+   //             strideOfContigStride[i] = strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
    //         }
    //     }

@@ -158,7 +177,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {

    //     if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array

-   //         const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo());
+   //         const auto zStep = strideOverContigAxis(axis, output.getShapeInfo());

    //         for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
@@ -130,7 +130,7 @@ if (SD_CPU)
  endif()

  add_executable(runtests ${TEST_SOURCES})
- target_link_libraries(runtests ${SD_LIBRARY_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
+ target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
elseif(SD_CUDA)

  add_executable(runtests ${TEST_SOURCES})
@@ -148,5 +148,5 @@ elseif(SD_CUDA)
  message("CUDNN library: ${CUDNN}")
  endif()

- target_link_libraries(runtests ${SD_LIBRARY_NAME}static ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main)
+ target_link_libraries(runtests samediff_obj ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main)
endif()
|
|
@ -184,7 +184,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
double tArgs[] = { -1.0, 1.0, 0.01 };

auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
-shape::printShapeInfoLinear("Result", shapes->at(0));
+// shape::printShapeInfoLinear("Result", shapes->at(0));
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));

delete shapes;

@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
0.928968489f, 0.684074104f
});

//get subarray
//get subarray
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });

@ -627,7 +627,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
});

auto actual = NDArrayFactory::create<float>('c', { 3 });
//get subarray
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrHsvs.reshapei({ 3 });
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });

@ -635,7 +635,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrHsvs.printShapeInfo("subArrHsvs");
#endif

Context ctx(1);
ctx.setInputArray(0, &subArrHsvs);

@ -855,7 +855,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
-0.04447775f, -0.44518381f
});

//get subarray
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrRgbs.reshapei({ 3 });

@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
0.280231822f, 1.91936605f
});

//get subarray
NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrYiqs.reshapei({ 3 });

@ -1074,3 +1074,422 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_1) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {4.0}, {});

auto z = result.at(0);

ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests16, clipbynorm_2) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {6.0}, {});

auto z = result.at(0);

ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_3) {

auto x = NDArrayFactory::create<double>('c', {3, 5});
auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});

x.linspace(100.);

auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
x /= xNorm1;
xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true);

ASSERT_TRUE(unities.isSameShape(xNorm1));
ASSERT_TRUE(unities.equalsTo(xNorm1));

x *= scale;
xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {1.0}, {1});
auto z = result.at(0);

auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});

ASSERT_TRUE(exp.isSameShape(&zNorm1));
ASSERT_TRUE(exp.equalsTo(&zNorm1));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_4) {

auto x = NDArrayFactory::create<double>('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105});

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {1.f}, {});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_5) {

// auto x = NDArrayFactory::create<double>('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5});
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676});
// auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1});

x.linspace(1);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {0});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_6) {

auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484});

x.linspace(1);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {1});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_7) {

auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});

x.linspace(1);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {0,1});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_8) {

auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});

x.linspace(1);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.}, {});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_9) {

auto x = NDArrayFactory::create<double>('c', {2}, {3., 4.});
auto exp = NDArrayFactory::create<double>('c', {2}, {2.4, 3.2});

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {4.}, {});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_10) {

auto x = NDArrayFactory::create<double>(6.);
auto exp = NDArrayFactory::create<double>(5.);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {5.}, {});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_11) {

auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061,
13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123});

x.linspace(1);

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {35.}, {0, 2});
auto output = result.at(0);

ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_12) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9});
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155});

sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {0.54}, {});

ASSERT_EQ(e, *result.at(0));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_13) {

const int bS = 5;
const int nOut = 4;
const int axis = 0;
const double clip = 2.;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
auto expect = NDArrayFactory::create<double>('c', {bS, nOut});

auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]

auto y = ( (x / norm2) * clip) * colVect ;
auto temp = (x / norm2) * clip;

for (int j = 0; j < nOut; ++j) {
auto yCol = y({0,0, j,j+1});
const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
if (norm2Col <= clip)
expect({0,0, j,j+1}).assign(yCol);
else
expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
}

sd::ops::clipbynorm op;
auto result = op.evaluate({&y}, {clip}, {axis});
auto outFF = result.at(0);

ASSERT_TRUE(expect.isSameShape(outFF));
ASSERT_TRUE(expect.equalsTo(outFF));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_1) {

const int bS = 2;
const int nOut = 3;
const double clip = 0.7;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

const OpArgsHolder argsHolderFF({&x}, {clip}, {});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_2) {

const int bS = 2;
const int nOut = 3;
const int axis = 0;
const double clip = 0.7;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_3) {

const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});

sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.8}, {});

auto z = result.at(0);

ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_2) {
auto x= NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});

sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.9}, {});

auto z = result.at(0);

ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_1) {

const int bS = 2;
const int nOut = 3;
const double clip = 0.7;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

const OpArgsHolder argsHolderFF({&x}, {clip}, {});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_2) {

const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;

auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_3) {

NDArray x('c', {2, 3, 4}, {-0.14 ,0.96 ,0.47 ,-0.98 ,0.03 ,0.95 ,0.33 ,-0.97 ,0.59 ,-0.92 ,-0.12 ,-0.33 ,0.82 ,-0.76 ,-0.69 ,-0.95 ,-0.77 ,0.25 ,-0.35 ,0.94 ,0.50 ,0.04 ,0.61 ,0.99}, sd::DataType::DOUBLE);
NDArray gradO('c', {2, 3, 4}, sd::DataType::DOUBLE);

const OpArgsHolder argsHolderFF({&x}, {0.7}, {0,2});
const OpArgsHolder argsHolderBP({&x, &gradO}, {0.7}, {0,2});

sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;

const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

ASSERT_TRUE(isGradCorrect);
}

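Taken together, the clipbynorm and clipbyavgnorm cases added above all exercise the same arithmetic: the array (or each sub-array along the requested dimensions) is left untouched while its L2 norm stays within the clip value, and is rescaled by clip / norm once it exceeds it; the avgnorm variant compares clip against the norm divided by the number of elements. A minimal stand-alone sketch of that reference computation, using a hypothetical helper name rather than the sd::ops implementations, looks like this:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics assumed by the tests above: scale the whole vector by
// clip / ||x||_2 only when the L2 norm exceeds clip (clipbynorm), or compare
// clip against ||x||_2 / N instead (clipbyavgnorm).
std::vector<double> clipByNormRef(const std::vector<double>& x, double clip, bool average) {
    double sumSq = 0.0;
    for (double v : x) sumSq += v * v;
    const double norm2 = std::sqrt(sumSq);
    const double measured = average ? norm2 / static_cast<double>(x.size()) : norm2;

    std::vector<double> out(x);
    if (measured > clip) {
        const double factor = clip / measured;
        for (double& v : out) v *= factor;  // below the threshold the input passes through unchanged
    }
    return out;
}

With x = {-3, 0, 0, 4, 0, 0} and clip = 4 this reproduces the {-2.4, 0, 0, 3.2, 0, 0} expectation of clipbynorm_1, and with clip = 0.8 in average mode the {-2.88, 0, 0, 3.84, 0, 0} expectation of clipbyavgnorm_1.
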
@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests3, Test_Permute_1) {

@ -123,7 +123,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
ASSERT_TRUE(expI.isSameShape(i));
ASSERT_TRUE(expI.equalsTo(i));

}

TEST_F(DeclarableOpsTests3, Test_Unique_2) {

@ -171,7 +171,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) {

ASSERT_TRUE(exp.equalsTo(z));

}

@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp0.equalsTo(z0));

auto result1 = op.evaluate({&x, &axis}, {1}, {});

@ -244,94 +244,6 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {

}

-TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
-auto x = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
-auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});
-
-sd::ops::clipbyavgnorm op;
-auto result = op.evaluate({&x}, {0.8}, {});
-
-auto z = result.at(0);
-
-ASSERT_TRUE(exp.isSameShape(z));
-ASSERT_TRUE(exp.equalsTo(z));
-}
-
-TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
-auto x= NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
-auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});
-
-sd::ops::clipbyavgnorm op;
-auto result = op.evaluate({&x}, {0.9}, {});
-
-auto z = result.at(0);
-
-ASSERT_TRUE(exp.isSameShape(z));
-ASSERT_TRUE(exp.equalsTo(z));
-
-}
-
-TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
-auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
-auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});
-
-sd::ops::clipbynorm op;
-auto result = op.evaluate({&x}, {4.0}, {});
-
-auto z = result.at(0);
-
-ASSERT_TRUE(exp.isSameShape(z));
-ASSERT_TRUE(exp.equalsTo(z));
-
-}
-
-TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
-auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
-auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
-
-sd::ops::clipbynorm op;
-auto result = op.evaluate({&x}, {6.0}, {});
-
-auto z = result.at(0);
-
-ASSERT_TRUE(exp.isSameShape(z));
-ASSERT_TRUE(exp.equalsTo(z));
-
-}
-
-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
-
-auto x = NDArrayFactory::create<double>('c', {3, 5});
-auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
-auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
-
-x.linspace(100.);
-
-auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
-x /= xNorm1;
-xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true);
-
-ASSERT_TRUE(unities.isSameShape(xNorm1));
-ASSERT_TRUE(unities.equalsTo(xNorm1));
-
-x *= scale;
-xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
-
-sd::ops::clipbynorm op;
-auto result = op.evaluate({&x}, {1.0}, {1});
-auto z = result.at(0);
-
-auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
-auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
-
-ASSERT_TRUE(exp.isSameShape(&zNorm1));
-ASSERT_TRUE(exp.equalsTo(&zNorm1));
-
-}
-
TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
auto x= NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto y= NDArrayFactory::create<float>('c', {3}, {1.f, 3.f, 5.f});

@ -551,7 +463,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {

@ -579,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {

@ -607,7 +519,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {

@ -635,7 +547,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {

@ -663,7 +575,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
}

delete exp;

}

@ -692,7 +604,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {

@ -722,7 +634,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
}

delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {

@ -734,7 +646,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
sd::ops::batched_gemm op;
try {
auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});

ASSERT_TRUE(false);
} catch (std::invalid_argument &e) {
//

@ -875,7 +787,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
ASSERT_TRUE(expCt.isSameShape(ct));
ASSERT_TRUE(expCt.equalsTo(ct));

}

////////////////////////////////////////////////////////////////////

@ -946,7 +858,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
ASSERT_TRUE(expHt.isSameShape(ht));
ASSERT_TRUE(expHt.equalsTo(ht));

}

////////////////////////////////////////////////////////////////////

@ -1001,7 +913,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1021,7 +933,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1099,7 +1011,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {

ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

delete input;
}

@ -1120,7 +1032,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

delete input;
}
///////////////////////////////////////////////////////////////////

@ -1245,7 +1157,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1551,7 +1463,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output, 1e-6));

}

///////////////////////////////////////////////////////////////////

@ -1576,7 +1488,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {

ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1642,7 +1554,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {

ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1689,7 +1601,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1831,7 +1743,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -1856,7 +1768,7 @@ TEST_F(DeclarableOpsTests3, zeta_test9) {
ASSERT_TRUE(expected.isSameShape(z));
ASSERT_TRUE(expected.equalsTo(z));

//
}

///////////////////////////////////////////////////////////////////

@ -1881,7 +1793,7 @@ TEST_F(DeclarableOpsTests3, zeta_test10) {
ASSERT_TRUE(expected.isSameShape(z));
ASSERT_TRUE(expected.equalsTo(z));

//
}

@ -1908,7 +1820,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
x.assign(0.5);

auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});

sd::ops::polygamma op;
auto result = op.evaluate({&n, &x}, {}, {});

@ -1920,7 +1832,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////

@ -2263,7 +2175,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {

ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expS.isSameShape(s));

}

///////////////////////////////////////////////////////////////////

@ -2416,7 +2328,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
// ASSERT_NEAR(sd::math::nd4j_abs(expV.e<float>(i)), sd::math::nd4j_abs(v->e<float>(i)), 1e-5);
// }

//
// }

///////////////////////////////////////////////////////////////////

File diff suppressed because it is too large
@ -57,7 +57,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {

ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {

@ -78,7 +78,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {

ASSERT_EQ(exp, *z);

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {

@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
ASSERT_TRUE(z->isEmpty());
//ASSERT_EQ(exp, *z);

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {

@ -122,7 +122,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
ASSERT_TRUE(z->equalsTo(exp));
//ASSERT_EQ(exp, *z);

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {

@ -185,7 +185,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {

@ -205,7 +205,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {

@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {

@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {

@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {

@ -292,7 +292,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {

@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_Order_1) {

@ -326,7 +326,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_NE(x.ordering(), z->ordering());

}

TEST_F(DeclarableOpsTests6, cumSum_1) {

@ -342,7 +342,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_2) {

@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_3) {

@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_4) {

@ -391,7 +391,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {

ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_5) {

@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {

ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_6) {

@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {

ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_7) {

@ -436,7 +436,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {

ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, cumSum_8) {

@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -477,7 +477,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(expFF.equalsTo(z));

//************************************//
exclusive = 1; reverse = 0;

@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expTF.equalsTo(z));

//************************************//
exclusive = 0; reverse = 1;

@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expFT.equalsTo(z));

//************************************//
exclusive = 1; reverse = 1;

@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expTT.equalsTo(z));

}

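The four cumSum_9 hunks above cycle through the exclusive/reverse flag pairs checked against expFF, expTF, expFT and expTT. As a reference for what those two integer flags are assumed to mean (the usual exclusive/reverse cumulative-sum convention), here is a small stand-alone sketch with a hypothetical helper name, not the sd::ops implementation:

#include <algorithm>
#include <cstddef>
#include <vector>

// Cumulative sum with the two flags the cumsum tests toggle: exclusive shifts the
// running sum so element i excludes x[i]; reverse accumulates from the end instead.
std::vector<double> cumSumRef(std::vector<double> x, bool exclusive, bool reverse) {
    if (reverse) std::reverse(x.begin(), x.end());
    std::vector<double> out(x.size());
    double running = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (exclusive) { out[i] = running; running += x[i]; }
        else           { running += x[i]; out[i] = running; }
    }
    if (reverse) std::reverse(out.begin(), out.end());
    return out;
}
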
@ -517,7 +517,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());

}

////////////////////////////////////////////////////////////////////////////////

@ -536,7 +536,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -593,7 +593,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -612,7 +612,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -631,7 +631,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
ASSERT_TRUE(z->ews() == 1);
ASSERT_TRUE(x.ews() == 1);

}

////////////////////////////////////////////////////////////////////////////////

@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -697,7 +697,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -731,7 +731,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {

ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -779,30 +779,40 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
auto res = op.evaluate({&x, &y, &z}, {}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, res.status());
-// res.at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
-// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex");
-// x.printIndexedBuffer("Input is");
ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {

-auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
+auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
-auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
+auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f});
-auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
+auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2});
sd::ops::mergemaxindex op;

auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64});

ASSERT_EQ(ND4J_STATUS_OK, ress.status());
-// res.at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
-// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex2");
-// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress.at(0)->equalsTo(exp));

+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) {
+
+auto x1 = NDArrayFactory::create<double>('c', {3}, {1.f, 0.f, 0.f});
+auto x2 = NDArrayFactory::create<double>('c', {3}, {0.f, 1.f, 0.f});
+auto x3 = NDArrayFactory::create<double>('c', {3}, {0.f, 0.f, 1.f});
+NDArray z('c', {3}, sd::DataType::INT32);
+NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32);
+
+sd::ops::mergemaxindex op;
+auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {});
+
+ASSERT_EQ(Status::OK(), result);
+ASSERT_TRUE(z.equalsTo(expZ));
}

////////////////////////////////////////////////////////////////////////////////

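The corrected TestMergeMaxIndex_2 expectations and the new TestMergeMaxIndex_3 both encode an element-wise "which input holds the maximum" result. A minimal stand-alone sketch of that reference semantics, using a hypothetical helper name rather than the sd::ops::mergemaxindex implementation:

#include <cstddef>
#include <vector>

// For each position i, return the index of the input array holding the largest value at i
// (ties go to the first input, which matches the expectations in the tests above).
std::vector<int> mergeMaxIndexRef(const std::vector<std::vector<double>>& inputs) {
    std::vector<int> idx(inputs[0].size(), 0);
    for (std::size_t i = 0; i < inputs[0].size(); ++i)
        for (std::size_t a = 1; a < inputs.size(); ++a)
            if (inputs[a][i] > inputs[static_cast<std::size_t>(idx[i])][i])
                idx[i] = static_cast<int>(a);
    return idx;
}

// mergeMaxIndexRef({{1,2,3,4,5,60,7,8}, {10,2,30,4,50,6,70,8}, {1,20,3,40,5,6,7,80}})
// yields {1, 2, 1, 2, 1, 0, 1, 2}, matching the updated expectation in TestMergeMaxIndex_2.
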
@ -818,7 +828,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
    //res.at(0)->printIndexedBuffer("Result is ");
    //x.printIndexedBuffer("Input is");

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_1) {

@ -834,7 +844,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
    // res.at(0)->printIndexedBuffer("MOD Result is ");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

////////////////////////////////////////////////////////////////////////////////

@ -853,7 +863,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {

    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

///////////////////////////////////////////////////////////////////////////////

@ -870,7 +880,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
    ASSERT_EQ(ND4J_STATUS_OK, res.status());

    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
    // auto x0 = NDArrayFactory::create<double>('c', {10, 10});

@ -883,7 +893,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {

    ASSERT_EQ(ND4J_STATUS_OK, res.status());

}

TEST_F(DeclarableOpsTests6, TestDropout_3) {

@ -898,7 +908,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {

    ASSERT_EQ(ND4J_STATUS_OK, res.status());

}

////////////////////////////////////////////////////////////////////////////////

@ -922,7 +932,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {

    ASSERT_TRUE(expI.equalsTo(res.at(1)));

}

////////////////////////////////////////////////////////////////////////////////

@ -947,7 +957,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
    ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));

}

////////////////////////////////////////////////////////////////////////////////

@ -979,7 +989,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
    ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));

}

////////////////////////////////////////////////////////////////////////////////

@ -1270,7 +1280,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
    ASSERT_TRUE(exp.equalsTo(z));
    // ASSERT_TRUE(expNorm.equalsTo(norm));

}

////////////////////////////////////////////////////////////////////////////////

@ -1310,7 +1320,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(exp.equalsTo(y));

}

////////////////////////////////////////////////////////////////////////////////

@ -1344,7 +1354,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(exp.equalsTo(y));

}

////////////////////////////////////////////////////////////////////////////////
@ -1365,7 +1375,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1386,7 +1396,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1407,7 +1417,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1428,7 +1438,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1452,7 +1462,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1477,7 +1487,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1496,7 +1506,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1514,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1533,7 +1543,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1552,7 +1562,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1596,7 +1606,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1615,7 +1625,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1634,7 +1644,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1653,7 +1663,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1700,7 +1710,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}
*/
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {

@ -1733,7 +1743,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1767,7 +1777,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1801,7 +1811,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////

@ -1835,7 +1845,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////
@ -1864,7 +1874,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {

@ -1917,7 +1927,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -1960,7 +1970,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2003,7 +2013,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2045,7 +2055,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2087,7 +2097,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2141,7 +2151,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

///////////////////////////////////////////////////////////////////

@ -2194,7 +2204,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

@ -2247,7 +2257,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

///////////////////////////////////////////////////////////////////

@ -2290,7 +2300,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

@ -2335,7 +2345,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2377,7 +2387,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2418,7 +2428,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2459,7 +2469,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));

}

///////////////////////////////////////////////////////////////////

@ -2521,7 +2531,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

///////////////////////////////////////////////////////////////////

@ -2581,7 +2591,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

///////////////////////////////////////////////////////////////////

@ -2637,7 +2647,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

///////////////////////////////////////////////////////////////////

@ -2696,7 +2706,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {

@ -2749,7 +2759,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));

}

@ -2763,7 +2773,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {

    ASSERT_EQ(e, *result.at(0));

}

TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {

@ -2776,7 +2786,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {

    ASSERT_EQ(e, *result.at(0));

}

TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {

@ -2789,7 +2799,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {

    ASSERT_EQ(e, *result.at(0));

}

File diff suppressed because it is too large
@ -236,10 +236,10 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test1) {

    auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
    auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
    auto x1 = NDArrayFactory::create<double>('c', {2,2,4});
    auto x1 = NDArrayFactory::create<float>('c', {2,2,4});
    auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
    auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
    auto exp = NDArrayFactory::create<double>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
    auto exp = NDArrayFactory::create<float>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
                                             13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.});

    x0.linspace(1);

@ -261,10 +261,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test2) {

    auto x0 = NDArrayFactory::create<double>('c', {1,3,1});
    auto x0 = NDArrayFactory::create<float>('c', {1,3,1});
    auto x1 = NDArrayFactory::create<double>('c', {1,2,1});
    auto x1 = NDArrayFactory::create<float>('c', {1,2,1});
    auto x2 = NDArrayFactory::create<double>('c', {1,1,1});
    auto x2 = NDArrayFactory::create<float>('c', {1,1,1});
    auto exp = NDArrayFactory::create<double>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
    auto exp = NDArrayFactory::create<float>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});

    x0.linspace(1);
    x1.linspace(1);

@ -285,10 +285,10 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test3) {

    auto x0 = NDArrayFactory::create<double>('c', {3});
    auto x0 = NDArrayFactory::create<float>('c', {3});
    auto x1 = NDArrayFactory::create<double>('c', {2});
    auto x1 = NDArrayFactory::create<float>('c', {2});
    auto x2 = NDArrayFactory::create<double>('c', {1});
    auto x2 = NDArrayFactory::create<float>('c', {1});
    auto exp = NDArrayFactory::create<double>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
    auto exp = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});

    x0.linspace(1);
    x1.linspace(1);

@ -300,21 +300,17 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    auto output = result.at(0);

    output->printBuffer();

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test4) {

    auto x0 = NDArrayFactory::create<double>('c', {1,1,1}, {1.f});
    auto x0 = NDArrayFactory::create<float>('c', {1,1,1}, {1.f});
    auto x1 = NDArrayFactory::create<double>('c', {1,1,1}, {2.f});
    auto x1 = NDArrayFactory::create<float>('c', {1,1,1}, {2.f});
    auto x2 = NDArrayFactory::create<double>('c', {1,1,1}, {3.f});
    auto x2 = NDArrayFactory::create<float>('c', {1,1,1}, {3.f});
    auto exp = NDArrayFactory::create<double>('c', {1,3,1}, {1.f, 2.f, 3.f});
    auto exp = NDArrayFactory::create<float>('c', {1,3,1}, {1.f, 2.f, 3.f});

    sd::ops::concat op;

@ -331,10 +327,10 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test5) {

    auto x0 = NDArrayFactory::create<double>(1.f);
    auto x0 = NDArrayFactory::create<float>(1.f);
    auto x1 = NDArrayFactory::create<double>('c', {1}, {2.f});
    auto x1 = NDArrayFactory::create<float>('c', {1}, {2.f});
    auto x2 = NDArrayFactory::create<double>(3.f);
    auto x2 = NDArrayFactory::create<float>(3.f);
    auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
    auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

    sd::ops::concat op;

@ -351,10 +347,10 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test6) {

    auto x0 = NDArrayFactory::create<double>(1.f);
    auto x0 = NDArrayFactory::create<float>(1.f);
    auto x1 = NDArrayFactory::create<double>('c', {2}, {2.f, 20.f});
    auto x1 = NDArrayFactory::create<float>('c', {2}, {2.f, 20.f});
    auto x2 = NDArrayFactory::create<double>(3.f);
    auto x2 = NDArrayFactory::create<float>(3.f);
    auto exp = NDArrayFactory::create<double>('c', {4}, {1.f, 2.f, 20.f, 3.f});
    auto exp = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 20.f, 3.f});

    sd::ops::concat op;

@ -371,10 +367,10 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test7) {

    auto x0 = NDArrayFactory::create<double>(1.f);
    auto x0 = NDArrayFactory::create<float>(1.f);
    auto x1 = NDArrayFactory::create<double>(2.f);
    auto x1 = NDArrayFactory::create<float>(2.f);
    auto x2 = NDArrayFactory::create<double>(3.f);
    auto x2 = NDArrayFactory::create<float>(3.f);
    auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
    auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

    sd::ops::concat op;

@ -391,8 +387,8 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test8) {

    auto x0 = NDArrayFactory::create<double>(1.f);
    auto x0 = NDArrayFactory::create<float>(1.f);
    auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
    auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});

    sd::ops::concat op;

@ -409,8 +405,8 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test9) {

    auto x0 = NDArrayFactory::create<double>('c', {1}, {1.f});
    auto x0 = NDArrayFactory::create<float>('c', {1}, {1.f});
    auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
    auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});

    sd::ops::concat op;

@ -427,10 +423,10 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test10) {

    auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
    auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
    auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
    auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
    auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
    auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
    auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
    auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
                                             13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});

    x0.linspace(1);

@ -452,10 +448,10 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test11) {

    auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
    auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
    auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
    auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
    auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
    auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
    auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
    auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
                                             13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});

    x0.linspace(1);

@ -477,10 +473,10 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test12) {

    auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
    auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
    auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
    auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
    auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
    auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
    auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
    auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
                                             13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});

    x0.linspace(1);

@ -502,10 +498,10 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test13) {

    auto x0 = NDArrayFactory::create<double>('f', {2,3,4});
    auto x0 = NDArrayFactory::create<float>('f', {2,3,4});
    auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
    auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
    auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
    auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
    auto exp = NDArrayFactory::create<double>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
    auto exp = NDArrayFactory::create<float>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
                                             3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f});

    x0.linspace(1);
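These hunks change only the element type of the factory calls (double to float); the middle of each test is untouched and therefore not shown. As a reminder of the full pattern the fragments belong to, here is a minimal sketch assembled from the lines visible above, using concat_test2's shapes and values; the concatenation-axis iArg is an assumption, since that line falls outside the hunks:

    auto x0  = NDArrayFactory::create<float>('c', {1, 3, 1});
    auto x1  = NDArrayFactory::create<float>('c', {1, 2, 1});
    auto x2  = NDArrayFactory::create<float>('c', {1, 1, 1});
    auto exp = NDArrayFactory::create<float>('c', {1, 6, 1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});

    x0.linspace(1);                                        // {1, 2, 3}
    x1.linspace(1);                                        // {1, 2}
    x2.linspace(1);                                        // {1}

    sd::ops::concat op;
    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});   // assumed: concatenate along dimension 1

    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    auto output = result.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));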
@ -527,8 +523,8 @@ TEST_F(DeclarableOpsTests9, concat_test13) {

TEST_F(DeclarableOpsTests9, concat_test14) {

    NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE);
    NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32);
    NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE);
    NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32);

    x0 = 1.;
    x1 = 2.;

@ -544,7 +540,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {

    for (int e = 0; e < numOfTads; ++e) {
        NDArray tad = (*z)(e, {0});
        auto mean = tad.meanNumber().e<double>(0);
        auto mean = tad.meanNumber().e<float>(0);
        ASSERT_NEAR((e+1)*1., mean, 1e-5);
    }

@ -552,9 +548,9 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
}

TEST_F(DeclarableOpsTests9, concat_test15) {
    auto x = NDArrayFactory::create<double>('c', {2}, {1, 0});
    auto x = NDArrayFactory::create<float>('c', {2}, {1, 0});
    auto y = NDArrayFactory::create<double> (3.0f);
    auto y = NDArrayFactory::create<float> (3.0f);
    auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
    auto exp = NDArrayFactory::create<float>('c', {3}, {1, 0, 3});

    sd::ops::concat op;
    auto result = op.evaluate({&x, &y}, {}, {0});

@ -571,9 +567,9 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test16) {

    auto x = NDArrayFactory::create<double>('c', {0,2,3});
    auto x = NDArrayFactory::create<float>('c', {0,2,3});
    auto y = NDArrayFactory::create<double>('c', {0,2,3});
    auto y = NDArrayFactory::create<float>('c', {0,2,3});
    auto exp = NDArrayFactory::create<double>('c', {0,2,3});
    auto exp = NDArrayFactory::create<float>('c', {0,2,3});

    sd::ops::concat op;
    auto result = op.evaluate({&x, &y}, {}, {0});

@ -587,8 +583,8 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test17) {

    NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE);
    NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32);
    NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE);
    NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32);

    x0 = 1.;
    x1 = 2.;

@ -606,7 +602,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {

    for (int e = 0; e < numOfTads; ++e) {
        NDArray tad = (*z)(e, {0});
        auto mean = tad.meanNumber().e<double>(0);
        auto mean = tad.meanNumber().e<float>(0);
        ASSERT_NEAR((e+1)*1., mean, 1e-5);
    }
}

@ -664,10 +660,10 @@ TEST_F(DeclarableOpsTests9, concat_test19) {

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test20) {
    auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
    auto x0 = NDArrayFactory::create<float>('c', {1, 100, 150});
    auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
    auto x1 = NDArrayFactory::create<float>('c', {1, 100, 150});
    auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
    auto x2 = NDArrayFactory::create<float>('c', {1, 100, 150});
    auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
    auto x3 = NDArrayFactory::create<float>('c', {1, 100, 150});

    x0.assign(1.0);
    x1.assign(2.0);

@ -685,8 +681,8 @@ TEST_F(DeclarableOpsTests9, concat_test20) {

    for (int e = 0; e < numOfTads; e++) {
        NDArray tad = (*z)(e, {0});
        auto mean = tad.meanNumber().e<double>(0);
        auto mean = tad.meanNumber().e<float>(0);
        ASSERT_NEAR((double) e+1, mean, 1e-5);
        ASSERT_NEAR((float) e+1, mean, 1e-5);
    }

@ -710,10 +706,10 @@ TEST_F(DeclarableOpsTests9, concat_test21) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test22) {

    NDArray x0('c', {1,6}, {1,2,3,4,5,6});
    NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32);
    NDArray x1('c', {1,6}, {7,8,9,10,11,12});
    NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32);
    NDArray output('f', {2,6}, sd::DataType::DOUBLE);
    NDArray output('f', {2,6}, sd::DataType::FLOAT32);
    NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
    NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32);

    sd::ops::concat op;

@ -726,10 +722,10 @@ TEST_F(DeclarableOpsTests9, concat_test22) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test23) {

    NDArray x0('c', {1,4}, {1,2,3,4});
    NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32);
    NDArray x1('c', {1,4}, {5,6,7,8});
    NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32);
    NDArray output('c', {2,4}, sd::DataType::DOUBLE);
    NDArray output('c', {2,4}, sd::DataType::FLOAT32);
    NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
    NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32);

    sd::ops::concat op;

@ -741,10 +737,10 @@ TEST_F(DeclarableOpsTests9, concat_test23) {

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test24) {
    auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
    auto x = NDArrayFactory::create<float>('c', {2, 1}, {1, 1});
    auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
    auto y = NDArrayFactory::create<float>('c', {2, 1}, {0, 0});
    auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
    auto e = NDArrayFactory::create<float>('c', {2, 2}, {1, 0, 1, 0});
    auto z = NDArrayFactory::create<double>('c', {2, 2});
    auto z = NDArrayFactory::create<float>('c', {2, 2});

    sd::ops::concat op;
    auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});

@ -756,10 +752,10 @@ TEST_F(DeclarableOpsTests9, concat_test24) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test25) {

    auto x0 = NDArrayFactory::create<double>('c', {1,4}, {1,2,3,4});
    auto x0 = NDArrayFactory::create<float>('c', {1,4}, {1,2,3,4});
    auto x1 = NDArrayFactory::create<double>('c', {1,4}, {5,6,7,8});
    auto x1 = NDArrayFactory::create<float>('c', {1,4}, {5,6,7,8});
    auto axis = NDArrayFactory::create<double>('c', {1}, {0.});
    auto axis = NDArrayFactory::create<float>('c', {1}, {0.});
    auto exp = NDArrayFactory::create<double>('c', {2,4}, {1,2,3,4,5,6,7,8});
    auto exp = NDArrayFactory::create<float>('c', {2,4}, {1,2,3,4,5,6,7,8});

    sd::ops::concat op;

@ -793,7 +789,7 @@ TEST_F(DeclarableOpsTests9, concat_test26) {

    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    auto output = result.at(0);
    output->printLinearBuffer();
    // output->printLinearBuffer();

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

@ -802,10 +798,10 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test27) {

    auto x1 = NDArrayFactory::create<double>('c', {0,1});
    auto x1 = NDArrayFactory::create<float>('c', {0,1});
    auto x2 = NDArrayFactory::create<double>('c', {0,1});
    auto x2 = NDArrayFactory::create<float>('c', {0,1});
    auto x3 = NDArrayFactory::create<double>('c', {0,1});
    auto x3 = NDArrayFactory::create<float>('c', {0,1});
    auto x4 = NDArrayFactory::create<double>('c', {0,1});
    auto x4 = NDArrayFactory::create<float>('c', {0,1});

    std::vector<Nd4jLong> expShape = {0, 4};
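concat_test22 through concat_test24 exercise the other calling convention, where the caller pre-allocates the output array and hands it to execute(). The visible lines stop at the execute call, so the closing assertions in this sketch are an assumption about how such a test typically ends rather than a copy of the real code; the values are taken from concat_test24 above:

    auto x = NDArrayFactory::create<float>('c', {2, 1}, {1, 1});
    auto y = NDArrayFactory::create<float>('c', {2, 1}, {0, 0});
    auto e = NDArrayFactory::create<float>('c', {2, 2}, {1, 0, 1, 0});
    auto z = NDArrayFactory::create<float>('c', {2, 2});           // result is written in place

    sd::ops::concat op;
    auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});         // iArg 1 = concatenation axis

    ASSERT_EQ(Status::OK(), status);                               // assumed assertion
    ASSERT_TRUE(e.equalsTo(z));                                    // assumed assertion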
@ -1245,109 +1241,6 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_test12) {

    const int bS = 5;
    const int nOut = 4;
    const int axis = 0;
    const double clip = 2.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
    auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
    auto expect = NDArrayFactory::create<double>('c', {bS, nOut});

    auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]

    auto y = ( (x / norm2) * clip) * colVect ;
    auto temp = (x / norm2) * clip;

    for (int j = 0; j < nOut; ++j) {
        auto yCol = y({0,0, j,j+1});
        const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
        if (norm2Col <= clip)
            expect({0,0, j,j+1}).assign(yCol);
        else
            expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
    }

    sd::ops::clipbynorm op;
    auto result = op.evaluate({&y}, {clip}, {axis});
    auto outFF = result.at(0);

    ASSERT_TRUE(expect.isSameShape(outFF));
    ASSERT_TRUE(expect.equalsTo(outFF));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {

    const int bS = 2;
    const int nOut = 3;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {

    const int bS = 2;
    const int nOut = 3;
    const int axis = 0;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) {

    const int bS = 2;
    const int nOut = 3;
    const int axis = 1;
    const double clip = 1.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, cumprod_1) {
@ -61,7 +61,7 @@ import java.util.*;
@Slf4j
public abstract class DefaultOpExecutioner implements OpExecutioner {

    private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: deeplearning4j.org/docs/latest/nd4j-overview#workspaces-panic";
    private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: https://deeplearning4j.konduit.ai/nd4j/overview#workspaces-scope-panic";

    protected ProfilingMode profilingMode = ProfilingMode.SCOPE_PANIC;

@ -31,7 +31,7 @@ import java.util.*;
/**
 * An ND4j backend.
 *
 * A "backend" is also described here: http://nd4j.org/backend.html
 * A "backend" is also described here: https://deeplearning4j.konduit.ai/nd4j/backend
 *
 * A backend also has 2 variables to be aware of.
 * 1 is the environment variable, ND4J_DYNAMIC_LOAD_CLASSPATH

@ -219,7 +219,7 @@ public abstract class Nd4jBackend {

        else
            throw new NoAvailableBackendException(
                    "Please ensure that you have an nd4j backend on your classpath. Please see: http://nd4j.org/getstarted.html");
                    "Please ensure that you have an nd4j backend on your classpath. Please see: https://deeplearning4j.konduit.ai/nd4j/backend");

        triedDynamicLoad = true;
        //load all the discoverable uris and try to load the backend again

@ -110,7 +110,7 @@ public class NativeOpsHolder {
            }
        } catch (Exception | Error e) {
            throw new RuntimeException(
                    "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
                    "ND4J is probably missing dependencies. For more information, please refer to: https://deeplearning4j.konduit.ai/nd4j/backend",
                    e);
        }
    }
@ -69,6 +69,8 @@
            <version>${project.version}</version>
            <classifier>${javacpp.platform.android-x86_64}</classifier>
        </dependency>
        <!--
        iOS support removed for 1.0.0-beta7 release, to be restored at a later date
        <dependency>
            <groupId>${project.groupId}</groupId>
            <artifactId>${nd4j.backend}</artifactId>

@ -81,6 +83,7 @@
            <version>${project.version}</version>
            <classifier>${javacpp.platform.ios-x86_64}</classifier>
        </dependency>
        -->
        <dependency>
            <groupId>${project.groupId}</groupId>
            <artifactId>${nd4j.backend}</artifactId>