Merge pull request #8668 from KonduitAI/master

Update master - recent fixes
Alex Black 2020-02-04 18:19:26 +11:00 committed by GitHub
commit dcc1187e1d
222 changed files with 11966 additions and 5550 deletions


@@ -43,12 +43,6 @@ import java.util.Random;
 public class CapsnetGradientCheckTest extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
     @Test
     public void testCapsNet() {


@@ -43,10 +43,6 @@ import static org.junit.Assert.assertTrue;
 public class OutputLayerGradientChecks extends BaseDL4JTest {
     private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
     static {
         Nd4j.setDataType(DataType.DOUBLE);


@@ -47,10 +47,6 @@ import static org.junit.Assert.assertTrue;
 public class RnnGradientChecks extends BaseDL4JTest {
     private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
     static {
         Nd4j.setDataType(DataType.DOUBLE);


@@ -48,12 +48,6 @@ import static org.junit.Assert.assertTrue;
 public class UtilLayerGradientChecks extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-6;
     static {
         Nd4j.setDataType(DataType.DOUBLE);
     }
@@ -182,9 +176,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
-                .minAbsoluteError(1e-7)
-                .labels(label).inputMask(inMask));
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net)
+                .minAbsoluteError(1e-6)
+                .input(input).labels(label).inputMask(inMask));
         assertTrue(gradOK);
         TestUtils.testModelSerialization(net);
@@ -233,8 +227,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         //Test ComputationGraph equivalent:
         ComputationGraph g = net.toComputationGraph();
-        boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g).inputs(new INDArray[]{in})
-                .labels(new INDArray[]{labels}).excludeParams(excludeParams));
+        boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g)
+                .minAbsoluteError(1e-6)
+                .inputs(new INDArray[]{in}).labels(new INDArray[]{labels}).excludeParams(excludeParams));
         assertTrue(gradOKCG);
         TestUtils.testModelSerialization(g);


@@ -56,11 +56,6 @@ import static org.junit.Assert.assertTrue;
  * @author Alex Black
  */
 public class YoloGradientCheckTests extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
     static {
         Nd4j.setDataType(DataType.DOUBLE);


@@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
 import org.nd4j.linalg.factory.Nd4j;
@@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest {
             }
         }
     }
+    @Test
+    public void testMaskLayerDataTypes(){
+        for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE,
+                DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64,
+                DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){
+            INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt);
+            for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){
+                INDArray in = Nd4j.rand(networkDtype, 2, 5, 10);
+                INDArray label1 = Nd4j.rand(networkDtype, 2, 5);
+                INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10);
+                for(PoolingType pt : PoolingType.values()) {
+                    //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt);
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new GlobalPoolingLayer(pt))
+                            .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+                    net.output(in, false, mask, null);
+                    net.output(in, false, mask, null);
+                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                    net2.init();
+                    net2.output(in, false, mask, mask);
+                    net2.output(in, false, mask, mask);
+                    net.fit(in, label1, mask, null);
+                    net2.fit(in, label2, mask, mask);
+                }
+            }
+        }
+    }
 }


@@ -19,8 +19,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -19,8 +19,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -19,8 +19,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
         INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);
         applyDropOutIfNecessary(training, workspaceMgr);
-        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
+        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);
         INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
         if (maskArray != null) {


@@ -56,6 +56,7 @@ public class MaskedReductionUtil {
             throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
         }
+        toReduce = toReduce.castTo(dataType);
         mask = mask.castTo(dataType);
         //Sum pooling: easy. Multiply by mask, then sum as normal
@@ -64,13 +65,7 @@ public class MaskedReductionUtil {
         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(dataType);
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.castTo(dataType).rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
                 INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape());
@@ -121,18 +116,14 @@ public class MaskedReductionUtil {
         //Mask: [minibatch, tsLength]
         //Epsilon: [minibatch, vectorSize]
+        mask = mask.castTo(input.dataType());
         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType());
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
-                INDArray withInf = Nd4j.createUninitialized(input.shape());
+                INDArray withInf = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
                 //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op
@@ -145,7 +136,7 @@ public class MaskedReductionUtil {
                 //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
                 //With masking: N differs for different time series
-                INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
+                INDArray out = Nd4j.createUninitialized(input.dataType(), input.shape(), 'f');
                 //Broadcast copy op, then divide and mask to 0 as appropriate
                 Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1));
@@ -162,7 +153,7 @@ public class MaskedReductionUtil {
             case PNORM:
                 //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0
-                INDArray masked2 = Nd4j.createUninitialized(input.shape());
+                INDArray masked2 = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, 0, 2));
                 INDArray abs = Transforms.abs(masked2, true);


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -21,8 +21,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -19,8 +19,8 @@
     <appender name="FILE" class="ch.qos.logback.core.FileAppender">
         <file>logs/application.log</file>
         <encoder>
-            <pattern>%date - [%level] - from %logger in %thread
-%n%message%n%xException%n</pattern>
+            <pattern> %logger{15} - %message%n%xException{5}
+            </pattern>
         </encoder>
     </appender>


@@ -5,7 +5,7 @@ option(NATIVE "Optimize for build machine (might not work on others)" OFF)
 set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
 #ensure we create lib files
 set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
+option(CHECK_VECTORIZATION "checks for vectorization" OFF)
 option(BUILD_TESTS "Build tests" OFF)
 option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF)
 set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FORCE)


@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
     GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-    GIT_TAG v1.1.2
+    GIT_TAG v1.1.3
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
     CONFIGURE_COMMAND ""


@@ -17,8 +17,11 @@ There's few additional arguments for `buildnativeoperations.sh` script you could
 -b release OR -b debug // enables/desables debug builds. release is considered by default
 -j XX // this argument defines how many threads will be used to binaries on your box. i.e. -j 8
 -cc XX// CUDA-only argument, builds only binaries for target GPU architecture. use this for fast builds
+--check-vectorization auto-vectorization report for developers. (Currently, only GCC is supported)
 ```
+[More about AutoVectorization report](auto_vectorization/AutoVectorization.md)
 You can find the compute capability for your card [on the NVIDIA website here](https://developer.nvidia.com/cuda-gpus).
 For example, a GTX 1080 has compute capability 6.1, for which you would use ```-cc 61``` (note no decimal point).


@@ -0,0 +1,49 @@
# Auto-vectorization Report
This report tool produces a human-friendly rendering of the compiler's auto-vectorization output. It is intended to help developers investigate the obstacles the compiler faced during auto-vectorization.
## Usage
The ```--check-vectorization``` option should be added to a **release** build to get the auto-vectorization report:
```./buildnativeoperations.sh -a native -j 28 --check-vectorization```
It will output ```vecmiss.html``` inside the blasbuild/cpu folder.
## Report Format
Each file name row summarizes the optimization attempts for that file's source lines.
Each line number is expandable (⇲) and contains the distinct failure notes.
Clicking a line number opens the source code:

| file name | total successful attempts | total failed attempts | ⇲ |
|---|---|---|--|
| line number | successful attempts | failed attempts | ⇲ |
|- failure reasons |
| line number | successful attempts | failed attempts |⇲ |

##### Requirements
- GCC (Currently, only GCC is supported)
- python3
### Detailed report with the `-fsave-optimization-record` option
If you want more detailed information (for now, the report adds the functions in which failures occurred), you should use a newer toolchain (GCC > 9), as newer GCC versions have the `-fsave-optimization-record` option.
`buildnativeoperations.sh`, via CMake, will detect it and switch to the more detailed version.
Please note that this option is still experimental, so the compiler can fail with an error while writing some json.gz files.
In that case, try to exclude those files from the build.
The internal structure of the `-fsave-optimization-record` json.gz output may also change in the future.
It outputs two files, **vecmiss_fsave.html** and **vecmiss_fsave.html.js**, so to see the report details you need JavaScript enabled in your browser.
##### Requirements for the detailed report
- GCC version > 9
- python3
- Cython (python3)
- json (python3)
- gzip (python3)
- c++filt
Internally, we use Cython to speed up json.gz file processing (bigGzipJson.pyx), because json.gz files can occupy a lot of memory when loaded whole.
If you want to use bigGzipJson outside of `buildnativeoperations.sh` and CMake, compile it manually with this command in the auto_vectorization folder:
`python3 cython_setup.py build_ext --inplace`
json.gz files can also be processed outside of `buildnativeoperations.sh`: call `python3 auto_vect.py --fsave` in the base source folder where the json.gz files exist.
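The same records can also be consumed programmatically through `json_gzip_extract_objects` from bigGzipJson, the generator that `auto_vect.py` itself uses. A minimal sketch, assuming a json.gz file produced by such a build (the file name below is hypothetical; the `location`, `kind` and `message` fields are the ones `auto_vect.py` reads):

```python
# Stream GCC optimization records whose "message" mentions "vectorized"
# out of a -fsave-optimization-record dump, without loading the file whole.
from bigGzipJson import json_gzip_extract_objects  # built via cython_setup.py

for obj in json_gzip_extract_objects("reduce3.cpp.json.gz", "message", "vectorized"):
    loc = obj.get("location", {})
    print(loc.get("file"), loc.get("line"), obj["kind"], obj["message"][0])
```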


@@ -0,0 +1,546 @@
'''
@author : Abdelrauf rauf@konduit.ai
'''
import re
import sys
import os
import subprocess
import fnmatch
import json
import gzip
try:
from bigGzipJson import json_gzip_extract_objects
except ImportError:
pass
from pathlib import Path
from multiprocessing import Pool, Manager ,cpu_count
import traceback
import html
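# A GCC vectorization diagnostic line looks like
#   "path/file.cpp:123:45: note: ...";
# mtch below captures its file, line, column and message parts.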
mtch = re.compile(r"[^/]*([^:]+)\:(\d+)\:(\d+)\:(.*)")
replace_msg = re.compile(r"(\d+)?\.?(\d+)?_?\d+\.?(\d+)?")
progress_msg = re.compile(r"\s{0,4}\[\s{0,2}\d+\%\]")
file_dir_strip = str(Path(os.getcwd()))
pp_index = file_dir_strip.rfind("libnd4j")
if pp_index>=0:
file_dir_strip =file_dir_strip[:pp_index+len("libnd4j")]
BASE_URL = "https://github.com/eclipse/deeplearning4j/tree/master/libnd4j/"
if BASE_URL.endswith("/")==False:
BASE_URL = BASE_URL + "/"
#print(file_dir_strip)
class info:
def __repr__(self):
return str(self.__dict__)
FSAVE_IGNORE_EXTERNALS = True
def get_cxx_filt_result(strx):
if len(strx)<1:
return ""
res = subprocess.Popen(["c++filt","-i", strx], stdout=subprocess.PIPE).communicate()[0]
res =res.decode('utf-8')
#replace some long names to reduce size
res = res.replace("unsigned long long", "uLL")
res = res.replace("unsigned long int","uL")
res = res.replace("unsigned long", "uL")
res = res.replace("unsigned int", "ui")
res = res.replace("unsigned char", "uchar")
res = res.replace("unsigned short", "ushort")
res = res.replace("long long", "LL")
res = res.replace(", ",",")
return res.strip()
def internal_glob(dir, match):
listx = []
for root, dirnames, filenames in os.walk(dir):
for filename in fnmatch.filter(filenames, match):
listx.append(os.path.join(root, filename))
return listx
def get_obj_json_gz(filename):
with gzip.GzipFile(filename, 'r') as f:
return json.loads(f.read().decode('utf-8'))[-1]
def get_msg(msg):
msg = msg.lower().strip()
if "note: not vectorized:" in msg:
msg = replace_msg.sub("_numb",msg.replace("note: not vectorized:",""))
return( 0, 1, msg.strip())
elif "loop vectorized" in msg:
return (1, 0, None)
# elif msg.startswith("missed")==False:
# msg = replace_msg.sub("_numb",msg)
# return( 0, 0, msg.strip())
return None
class File_Info:
'''
    Holds information about vectorized and missed-vectorization lines for one file
'''
def __init__(self):
self.infos = {}
self.total_opted =0
self.total_missed = 0
self.external = False
def add_line(self, line_pos):
if line_pos not in self.infos:
v = info()
v.optimized = 0
v.missed = 0
v.miss_details = set()
self.infos[line_pos] = v
return v
else:
return self.infos[line_pos]
def add_line_fsave(self, line_pos):
if line_pos not in self.infos:
v = info()
v.optimized = 0
v.missed = 0
v.miss_details2 = dict()
self.infos[line_pos] = v
return v
else:
return self.infos[line_pos]
def add_fsave(self, line_pos,success, msg, function ,inline_fns=''):
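        # Record one optimization note for a line: count successes/misses and,
        # per distinct miss message, remember the set of function indices.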
v = self.add_line_fsave(line_pos)
if success and "loop vectorized" in msg:
v.optimized +=1
self.total_opted +=1
elif success==False and "not vectorized:" in msg:
#reduce this msg
msg = msg.replace("not vectorized:","")
v.missed +=1
self.total_missed +=1
msg = sys.intern(msg)
if msg in v.miss_details2:
ls = v.miss_details2.get(msg)
ls.add(function)
else:
ls =set()
v.miss_details2[msg]=ls
ls.add(function)
return self
def add(self, line_pos, msg_x):
v = self.add_line(line_pos)
if msg_x is not None:
v.optimized += msg_x[0]
v.missed += msg_x[1]
self.total_opted += msg_x[0]
self.total_missed += msg_x[1]
if msg_x[2] is not None:
v.miss_details.add(msg_x[2])
return self
def __repr__(self):
return str(self.__dict__)
def process_gzip_json_mp(args):
process_gzip_json_new(*args)
def process_gzip_json_new(json_gz_fname,list_Queue):
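    # Producer: stream the matching optimization records out of one json.gz
    # file and shard them to consumer queues by hash(file_name), so every
    # record for a given source file lands on the same consumer.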
gz_name = Path(json_gz_fname).stem
#print("::--open and process {0}".format(gz_name))
queue_count = len(list_Queue)
#print(queue_count)
q = list_Queue[0]
old_fname = ''
total_c = 0
for x in json_gzip_extract_objects(json_gz_fname,'message','vectorized'):
external_source = True
if len(x['message'])>0 and 'location' in x:
line = int(x['location']['line'])
file_name = x['location']['file'].strip()
if file_dir_strip in file_name:
file_name = file_name.replace(file_dir_strip,'./')
external_source = False
msg = x['message'][0]
success = x['kind'] == 'success'
func = '' if 'function' not in x else x['function']
if file_name!=old_fname:
#send our info to the right consumer
queue_ind = hash(file_name) % queue_count
#print("quen index {0}".format(queue_ind))
q =list_Queue[queue_ind]
old_fname = file_name
total_c +=1
#print("pp {0} {1}".format(q,(file_name,line,success, msg, func,external_source )))
if FSAVE_IGNORE_EXTERNALS==True and external_source == True:
continue
q.put((file_name,line,success, msg, func,external_source ))
print("::finished {0:60s} :{1:8d}".format(gz_name,total_c))
def consume_processed_mp(args):
return consume_processed_new(*args)
def consume_processed_new(list_Queue , c_index):
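    # Consumer: aggregate records into per-file File_Info entries, intern each
    # function name to a numeric index, then write its own vecmiss_fsave*.html.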
info_ = dict()
func_list = dict()
last_func_index = 0
q = list_Queue[c_index]
print("::consumer {0}".format(c_index))
total_c = 0
r_c = 0
while True:
#print("try to get new from {0}".format(index))
obj = q.get()
#print("cc {0} {1}".format(q,obj))
if obj==None:
break #we received the end
file_name,line,success, msg, func, external_source = obj
try:
#get function index
func_index = -1
if func in func_list:
func_index = func_list[func]
else:
func_list[func] = last_func_index
func_index = last_func_index
last_func_index +=1
if file_name in info_:
info_[file_name].add_fsave(line, success, msg, func_index)
else:
info_[file_name] = File_Info().add_fsave(line, success, msg, func_index)
info_[file_name].external = external_source
total_c +=1
if total_c - r_c >10000:
r_c = total_c
print("::consumer {0:2d} :{1:10d}".format(c_index,total_c))
except Exception as e:
print(traceback.format_exc())
break
print("::consumer {0:2d} :{1:10d}".format(c_index,total_c))
#write to temp file
wr_fname= "vecmiss_fsave{0}.html".format(str(c_index) if len(list_Queue)>1 else '')
print("generate report for consumer {0} {1}".format(c_index,len(info_)))
try:
uniq_ind = str(c_index)+'_' if len(list_Queue)>1 else ''
generate_report(wr_fname,info_ ,only_body = False, unique_id_prefix = uniq_ind,fsave_format = True, function_list= func_list)
print(" consumer {0} saved output into {1}".format(c_index,wr_fname))
except Exception as e:
print(traceback.format_exc())
def obtain_info_from(input_):
info_ = dict()
for line in input_:
x = mtch.match(line)
external_source = True
if x:
file_name =x.group(1).strip()
if file_dir_strip in file_name:
file_name = file_name.replace(file_dir_strip,'')
external_source = False
line_number = int(x.group(2))
msg = x.group(4).lower()
msg = msg.replace(file_dir_strip,'./')
msg_x = get_msg(msg)
if msg_x is None:
continue
if file_name in info_:
#ignore col_number
info_[file_name].add(line_number,msg_x)
else:
#print("{0} {1}".format(file_name,external_source))
info_[file_name] = File_Info().add(line_number,msg_x)
info_[file_name].external = external_source
elif progress_msg.match(line):
#actually we redirect only, stderr so this should not happen
print("__"+line.strip())
elif "error" in line or "Error" in line:
print("****"+line.strip())
return info_
def custom_style(fsave):
st = '''<style>a{color:blue;}
a:link{text-decoration:none}a:visited{text-decoration:none}a:hover{cursor:pointer;text-decoration:underline}
a:active{text-decoration:underline}
.f.ext{display:none}
.f{color:#000;display:flex;overflow:hidden;justify-content:space-between;flex-wrap:wrap;align-items:baseline;width:100%}
.f>div{min-width:10%}.f>div:first-child{min-width:70%;text-overflow:ellipsis}
.f:nth-of-type(even){background-color:#f5f5f5}
.f>div.g{flex:0 0 100%}.f>div:nth-child(2){font-weight:600;color:green}
.f>div:nth-child(3){font-weight:600;color:red}
.f>div:nth-child(2)::after{content:'';color:green}.f>div:nth-child(3)::after{content:' -';color:red}
.f>div.g>div>div:nth-child(2){font-weight:600;color:green}
.f>div.g>div>div:nth-child(3){font-weight:600;color:red}
.f>div.g>div>div:nth-child(2)::after{content:'';color:green}
.f>div.g>div>div:nth-child(3)::after{content:' -';color:red}
.f>div.g>div{display:flex;justify-content:space-between;flex-wrap:wrap;align-items:baseline}
.f>div.g>div>div{min-width:10%;text-align:left}
.g>div:nth-of-type(even){background-color:#ede6fa}
.f>div.g>div>ul{flex:0 0 100%}input[type=checkbox]{opacity:0;display:none}label{cursor:pointer}
.f>label{color:red}input[type=checkbox]~.g{display:none}input[type=checkbox]:checked~.g{display:block}
input[type=checkbox]~ul{display:none}
input[type=checkbox]:checked~ul{display:block}input[type=checkbox]+label::after{content:"";display:block}
input[type=checkbox]:checked+label::after{content:"";display:block}
'''
if fsave==True:
st+='''.modal{display:none;height:100%;background-color:#144F84;color:#fff;opacity:.93;left:0;position:fixed;top:0;width:100%}
.modal.open{display:flex;flex-direction:column}.modal__header{height:auto;font-size:large;padding:10px;background-color:#000;color:#fff}
.modal__footer{height:auto;font-size:medium;background-color:#000}
.modal__content{height:100%;display:flex;flex-direction:column;padding:20px;overflow-y:auto}
.modal_close{cursor:pointer;float:right}li{cursor:pointer}
'''
return st + '''</style>'''
def header(fsave=False):
strx ='<!DOCTYPE html>\n<html>\n<head>\n<meta charset="UTF-8">\n<title>Auto-Vectorization</title>\n'
strx +='<base id="base_id" href="{0}" target="_blank" >'.format(BASE_URL)
strx +=custom_style(fsave)
strx +='\n</head>\n<body>\n'
return strx
def footer():
return '\n</body></html>'
def get_compressed_indices(set_a):
a_len = len(set_a)
if a_len<=1:
if a_len<1:
return ''
return str(set_a)[1:-1]
#we sorted and only saved difference
# 1,14,15,19 --> 1,13,1,4 10bytes=>8bytes
list_sorted = sorted(list(set_a))
last = list_sorted[0]
str_x = str(list_sorted[0])
for i in range(1,a_len):
str_x += ','+str(list_sorted[i]-last)
last = list_sorted[i]
return str_x
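# Illustrative inverse of get_compressed_indices (an assumption for clarity,
# not used by the pipeline; the report's JavaScript performs the same
# accumulation with parseInt over the saved differences):
def decompress_indices(strx):
    out, last = [], 0
    for tok in strx.split(','):
        last += int(tok)
        out.append(last)
    return out  # decompress_indices("1,13,1,4") == [1, 14, 15, 19]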
def get_content(k, v, unique_id_prefix = '', fsave_format=False):
inner_str=''
content = ''
inc_id = 0
for fk,fv in sorted(v.infos.items()):
if fsave_format==True:
inner_str+='<div><div><a>{0}</a></div><div>{1}</div><div>{2}</div><input type="checkbox" id="c{3}{4}"><label for="c{3}{4}"></label><ul>'.format(
fk,fv.optimized,fv.missed,unique_id_prefix,inc_id)
else:
inner_str+='<div><div><a href=".{0}#L{1}">{1}</a></div><div>{2}</div><div>{3}</div><input type="checkbox" id="c{4}{5}"><label for="c{4}{5}"></label><ul>'.format(
k,fk,fv.optimized,fv.missed,unique_id_prefix,inc_id)
inc_id+=1
if fsave_format==True:
#
for dt,df in fv.miss_details2.items():
#inner_str +='<li data-fns="{0}">{1}</li>'.format(str(df).replace(", ",",")[1:-1],dt)
inner_str +='<li data-fns="{0}">{1}</li>'.format(get_compressed_indices(df),dt)
else:
for dt in fv.miss_details:
inner_str+="<li>"+str(dt)+ "</li>"
inner_str+="</ul></div>\n"
content += '<div class="f'
if v.external:
content += " ext"
content += '">\n<div>{0}</div><div>{1}</div><div>{2}</div><input type="checkbox" id="i{3}{4}"><label for="i{3}{4}"></label>'.format(
k,v.total_opted,v.total_missed,unique_id_prefix,inc_id)
content += "<div class='g'>"
content += inner_str
content += "</div> </div>\n"
return content
def jscript_head():
return '''
window.onload = function () {
var modal = document.getElementsByClassName("modal")[0];
var modal_close = document.getElementsByClassName("modal_close")[0];
var content = document.getElementsByClassName("modal__content")[0];
a_tags = document.getElementsByTagName("a");
base_href = document.getElementById("base_id").href;
for(i=0;i<a_tags.length;i++){
a_tags[i].addEventListener("click", function () {
var source = event.target || event.srcElement;
file_src = source.parentElement.parentElement.parentElement.parentElement.children[0].innerText ;
link = base_href + file_src+'#L'+ source.innerText;
window.open(link, '_blank');
});
}
modal_close.addEventListener("click", function () {
content.innerHTML = '';
modal.className = 'modal';
});
'''
def jscipt_end():
return '''
tags = document.getElementsByTagName("li");
function escapeHtml(unsafe) {
return unsafe
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
for (i = 0; i < tags.length; i++) {
tags[i].addEventListener("click", function () {
var source = event.target || event.srcElement;
funcs = source.dataset.fns.split(",")
strx = ''
//we saved differences,not real indices
last_ind = 0;
for (j = 0; j < funcs.length; j++) {
ind = last_ind + parseInt(funcs[j]);
strx += "<p>" + escapeHtml(func_list[ind]) + "</p>";
last_ind = ind;
}
if (strx.length > 0) {
content.innerHTML = strx;
modal.className = 'modal open';
}
});
}
};'''
def additional_tags(fsave):
if fsave==False:
return ''
#
return '''<script type='text/javascript'>
var script = document.createElement('script'); script.src = window.location.href+".js" ;
document.head.appendChild(script);
</script>
<div class="modal">
<div class="modal__header">Functions <span class="modal_close">X</span></div>
<div class="modal__content"></div>
<div class="modal__footer">========</div>
</div>
'''
def generate_report(output_name,info_ ,only_body = False, unique_id_prefix='',fsave_format = False , function_list = None):
'''
Generate Auto-Vectorization Report in html format
'''
temp_str =''
if fsave_format ==True:
# we gonna dump function_list as key list sorted by value
#and use it as jscript array
sorted_funcs_by_index = sorted(function_list.items(), key=lambda x: x[1])
del function_list
with open(output_name+ ".js","w") as f:
#temp_str =jscript_head() +'{ "fmaps":['
temp_str = jscript_head() + "\n var func_list = ["
for k,v in sorted_funcs_by_index:
#json.dumps using for escape
#print(str(v)+str(k))
temp_str+=json.dumps(get_cxx_filt_result(k))+","
#reduce write calls
if len(temp_str)>8192*2:
f.write(temp_str)
temp_str= ''
if len(temp_str)>0:
f.write(temp_str)
f.write('"-"];'+jscipt_end())
temp_str = ''
with open(output_name,"w") as f:
if only_body==False:
f.write(header(fsave_format))
f.write(additional_tags(fsave_format))
nm=0
for k,v in sorted(info_.items()): # sorted(info_.items(), key=lambda x: x[1].total_opted, reverse=True):
temp_str += get_content(k,v,unique_id_prefix+str(nm),fsave_format)
#reduce io write calls
if len(temp_str)>8192:
f.write(temp_str)
temp_str =''
nm+=1
if len(temp_str)>0:
f.write(temp_str)
if only_body==False:
f.write(footer())
def fsave_report_launch(json_gz_list):
cpus = cpu_count()
if cpus>32:
cpus = 24
    c_count = 1 # 2 is sufficient # if cpus<=1 else min(4,cpus)
p_count = 3 if cpus<=1 else max(8, cpus - c_count)
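    # A single consumer is enough to aggregate results; the remaining pool
    # workers parse json.gz files in parallel and feed the managed queues.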
m = Manager()
#consumer Queues
list_Queue = [m.Queue() for index in range(0,c_count)]
with Pool(processes=c_count) as consumers:
#start consumers
cs = consumers.map_async(consume_processed_mp,[(list_Queue, index,) for index in range(0,c_count)])
with Pool(processes=p_count) as processors:
processors.map(process_gzip_json_mp, [(fname, list_Queue,) for fname in json_gz_list])
#send ends to inform our consumers
#send ends
for q in list_Queue:
q.put(None)
#wait for consumers
cs.wait()
def main():
if "--fsave" in sys.argv:
json_gz_list = internal_glob(".","*.json.gz")
fsave_report_launch(json_gz_list)
return
file_info = obtain_info_from(sys.stdin)
if len(file_info)>0:
#print(file_info)
print("---generating vectorization html report--")
generate_report("vecmiss.html", file_info)
else:
# lets check if we got fsave files
json_gz_list = internal_glob(".","*.json.gz")
fsave_report_launch(json_gz_list)
if __name__ == '__main__':
main()


@@ -0,0 +1,354 @@
'''
@author : Abdelrauf rauf@konduit.ai
Simple object extractor for very big json files
'''
import sys
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
cdef char JSON_1 = b':'
cdef char JSON_2 = b','
cdef char JSON_3 = b'{'
cdef char JSON_4 = b'}'
cdef char JSON_5 = b'['
cdef char JSON_6 = b']'
cdef char QUOTE = b'"'
cdef char ESCAPE = b"\\"
cdef char SPACE = b' '
cdef char TAB = b'\t'
cdef char CR = b'\r'
cdef char NL = b'\n'
cdef char B = b'\b'
cdef char EMPTY = b'\0'
cdef struct Span:
int b
int e
cdef inline Span read_unquoted(char *text, int start,int end):
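    # Scan past whitespace, then return the span of the next bare (unquoted)
    # token, stopping at JSON punctuation or whitespace.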
cdef Span sp
cdef int j = start
while j < end:
#if text[j].isspace():
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
j += 1
continue
if text[j] != QUOTE and text[j] != JSON_1 and text[j] != JSON_2 and text[j] != JSON_3 and text[j] != JSON_4 and text[j] != JSON_5 and text[j] != JSON_6:
start = j
j += 1
while j < end:
# read till JSON or white space
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
sp.b = start
sp.e = j
return sp
elif text[j] == JSON_1 or text[j] == JSON_2 or text[j] == JSON_3 or text[j] == JSON_4 or text[j] == JSON_5 or text[j] == JSON_6:
sp.b = start
sp.e = j
return sp
j += 1
if j == end-1:
sp.b = start
sp.e = end
return sp
break
sp.b = j
sp.e = j
return sp
cdef inline Span read_seq_token(char *text,int start,int end):
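    # Return the span of the next token: a quoted string (honouring escapes)
    # if it starts with '"', otherwise delegate to read_unquoted.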
#read quoted
#skip white_space
cdef Span sp
cdef int j = start
cdef char last_char
cdef char char_x
while j < end:
if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B:
j += 1
continue
if text[j] == QUOTE:
last_char = EMPTY
#read till another quote
start = j
j += 1
while j < end:
char_x = text[j]
if char_x == QUOTE and last_char != ESCAPE:
# finished reading
sp.b =start
sp.e = j+1
return sp
last_char = char_x
j += 1
if j == end-1:
sp.b = start
sp.e = end
return sp
else:
break
return read_unquoted(text, j, end)
def tokenizer_spans(utext):
'''
we will just return tokenize spans
'''
token_spans = []
last_char = b''
end_i = len(utext)
cdef char *text = utext
i = 0
cdef Span sp
while i < end_i:
sp = read_seq_token(text, i, end_i)
i = sp.e
if sp.e > sp.b:
token_spans.append((sp.b, sp.e))
if i < end_i:
#if text[i] in JSON:
if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2:
token_spans.append((i, i+1))
i += 1
return token_spans
cdef class JsonObjXtractor:
'''
JsonObjXtractor that utilize cython better
'''
cdef Span* token_spans
cdef size_t size
def __cinit__(self, size_t count=4096):
self.token_spans = <Span*> PyMem_Malloc(count * sizeof(Span))
self.size = count
if not self.token_spans:
raise MemoryError()
def __tokenizer_spans(self,utext, length):
'''
we will just return token spans length
'''
last_char = b''
end_i = length
cdef char *text = utext
cdef int i = 0
cdef size_t j = 0
cdef Span sp
while i < end_i:
sp = read_seq_token(text, i, end_i)
i = sp.e
if sp.e > sp.b:
self.token_spans[j] = sp
j+=1
if j>self.size:
#we need to reallocate
self.__resize(self.size+self.size//2)
if i < end_i:
#if text[i] in JSON:
if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2:
sp.b=i
sp.e=i+1
self.token_spans[j] = sp
j+=1
if j>self.size:
#we need to reallocate
self.__resize(self.size+self.size//2)
i += 1
return j
def try_extract_parent_obj(self, json_bytes, property_name, next_contains_value=b'', debug=False):
'''
try_extract_parent_obj(json_text, property_name, next_contains_value='', debug=False):
make sure that passed variables encoded to bytes with encode('utf-8')
next_contains_value either direct content or followed by '['
tries to extract the parent object for given named object
if the left brace of the parent object is outside of the current buffer
it will be ignored
if the right brace is outside of the buffer it will be left to be handled by caller
'''
look_for_the_left = True
parent_left = []
parent_right = []
parent_objects = []
len_next = len(next_contains_value)
cdef int ind = 0
cdef int end
cdef int last_start = 0
property_name = b'"'+property_name+b'"'
cdef int lenx = self.__tokenizer_spans(json_bytes,len(json_bytes))
cdef char x
cdef int i = -1
cdef Span sp
while i < lenx-1:
i += 1
ind = self.token_spans[i].b
x = json_bytes[ind]
#print("-----{0} -- {1} -- {2} ".format(x,parent_left,parent_right))
if look_for_the_left == False:
if x == JSON_3:
parent_right.append(ind)
elif x == JSON_4:
if len(parent_right) == 0:
#we found parent closing brace
look_for_the_left = True
parent_objects.append((parent_left[-1], ind+1))
last_start = ind+1
#print("=============found {0}".format(parent_objects))
parent_left = []
parent_right = []
else:
parent_right.pop()
continue
#search obj
if look_for_the_left:
if x == JSON_3:
parent_left.append(ind)
last_start = ind
elif x == JSON_4:
if len(parent_left) >= 1:
#ignore
parent_left.pop()
if x == JSON_1: # ':'
#check to see if propertyname
old_property = EMPTY
if i > 1:
sp = self.token_spans[i-1]
old_property = json_bytes[sp.b:sp.e]
if old_property == property_name:
#we found
if len(parent_left) < 1:
#left brace is outside of the buffer
#we have to ignore it
#try to increase buffer
if debug:
print('''left brace of the parent is outside of the buffer and parent is big.
it will be ignored
try to choose unambiguous property names if you are looking for small objects''', file=sys.stderr)
last_start = ind+1
parent_left = []
parent_right = []
continue
else:
#print("++++++ look for the right brace")
if len_next>0 and i+1 < lenx:
i += 1
ind = self.token_spans[i].b
end = self.token_spans[i].e
m = json_bytes[ind]
if m == JSON_5:
#print ("----{0} {1}".format(m,JSON_5))
if i+1 < lenx:
i += 1
ind = self.token_spans[i].b
end = self.token_spans[i].e
#print ("----{0} == {1}".format(next_contains_value,json_bytes[ind:end]))
if len_next <= end-ind and next_contains_value in json_bytes[ind:end]:
look_for_the_left = False
continue
elif len_next <= end-ind and next_contains_value in json_bytes[ind:end]:
look_for_the_left = False
continue
#ignore as it does not have that value
parent_left = []
parent_right = []
last_start = ind + 1
else:
look_for_the_left = False
        # let's return the last successfully opened brace as the last position,
        # or, in the left-brace failure case, the safe closed position
if len(parent_left)>0:
return (parent_objects, parent_left[-1])
return (parent_objects, last_start)
def __resize(self, size_t new_count):
cdef Span* mem = <Span*> PyMem_Realloc(self.token_spans, new_count * sizeof(Span))
if not mem:
raise MemoryError()
self.token_spans = mem
self.size = new_count
def __dealloc__(self):
PyMem_Free(self.token_spans)
import json
import gzip
import sys
DEBUG_LOG = False
def json_gzip_extract_objects(filename, property_name, next_contains_value=''):
strx = b''
started= False
b_next_contains_value = next_contains_value.encode('utf-8')
b_property_name = property_name.encode('utf-8')
#print(b_property_name)
objXt = JsonObjXtractor()
with gzip.open(filename, 'rb') as f:
if DEBUG_LOG:
print("opened {0}".format(filename), file=sys.stderr)
        #instead of reading it line by line, I will read it as binary bytes
is_End = False
#total = 0
while is_End==False:
buffer = f.read(8192*2)
lenx= len(buffer)
#total +=lenx
if lenx<1:
is_End = True
else:
strx = strx + buffer
objects , last_index = objXt.try_extract_parent_obj(strx,b_property_name,b_next_contains_value)
# if b_property_name in strx and b_next_contains_value in strx:
# print(strx)
# print(objects)
# print(last_index)
# print("===================================================")
for start,end in objects:
yield json.loads(strx[start:end]) #.decode('utf-8'))
#remove processed
if last_index< len(strx):
strx = strx[last_index:]
else:
strx = b''
#print('----+++')
if(len(strx)>16384*3):
                #buffer too big
#try to avoid big parents
if DEBUG_LOG:
print("parent object is too big. please, look for better property name", file=sys.stderr)
break


@@ -0,0 +1,3 @@
from distutils.core import setup
from Cython.Build import cythonize
setup(ext_modules=cythonize("bigGzipJson.pyx", language_level="3"))


@@ -282,6 +282,32 @@ elseif(CPU_BLAS)
         set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
     endif()
+    if(CHECK_VECTORIZATION)
+        set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
+        if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
+            if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
+                set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record")
+                #to process fsave-optimization-record we will need our cython version code
+                message("Build Auto vectorization helpers")
+                execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret)
+                message("build='${ret}'")
+                #remove failure cases that gcc sometimes fails to produce output for
+                file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp)
+                #message("*****${FAILURE_CASES}")
+                foreach(FL_ITEM ${FAILURE_CASES})
+                    message("Removing failure cases ${FL_ITEM}")
+                    list(REMOVE_ITEM VECT_FILES ${FL_ITEM})
+                endforeach()
+            else()
+                set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed")
+            endif()
+            message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}")
+            set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" )
+        endif()
+    endif()
     message("CPU BLAS")
     add_definitions(-D__CPUBLAS__=true)
     add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp


@@ -195,6 +195,56 @@ namespace nd4j {
         NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create a scalar array containing a utf8 string
+         */
+        NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::string(str), dtype, context) {
+        }
+        NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create a scalar array containing a utf16 string
+         */
+        NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::u16string(u16string), dtype, context) {
+        }
+        NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create a scalar array containing a utf32 string
+         */
+        NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::u32string(u32string), dtype, context) {
+        }
+        NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create an array from a vector of utf8 strings
+         */
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create an array from a vector of utf16 strings
+         */
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        /**
+         * These constructors create an array from a vector of utf32 strings
+         */
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
 #endif
 /**
@@ -250,7 +300,6 @@ namespace nd4j {
  */
         NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
 /**
  * This method returns new array with the same shape & data type
  * @return
@@ -1148,6 +1197,9 @@ namespace nd4j {
         template <typename N>
         NDArray asT() const;
+        template <typename S>
+        NDArray asS() const;
         NDArray asT(DataType dtype) const;

File diff suppressed because it is too large.


@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by raver119 on 2018-09-16. // Created by raver119 on 2018-09-16.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#ifndef DEV_TESTS_NDARRAYFACTORY_H #ifndef DEV_TESTS_NDARRAYFACTORY_H
@ -106,25 +108,72 @@ namespace nd4j {
template <typename T> template <typename T>
static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

/**
 * This factory creates an array from a utf8 string
 * @return NDArray with default dataType UTF8
 */
static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());

/**
 * This factory creates an array from a utf16 string
 * @return NDArray with default dataType UTF16
 */
static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());

/**
 * This factory creates an array from a utf32 string
 * @return NDArray with default dataType UTF32
 */
static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());

/**
 * This factory creates an array from a vector of utf8 strings
 * @return NDArray with default dataType UTF8
 */
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

/**
 * This factory creates an array from a vector of utf16 strings
 * @return NDArray with default dataType UTF16
 */
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());

/**
 * This factory creates an array from a vector of utf32 strings
 * @return NDArray with default dataType UTF32
 */
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
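To ground the new overloads, a minimal usage sketch; names, shapes and contents are illustrative, and the dtype defaults follow the comment blocks above:

#include <NDArrayFactory.h>

void stringFactoryExamples() {
    auto u8  = nd4j::NDArrayFactory::string("alpha");    // UTF8 by default
    auto u16 = nd4j::NDArrayFactory::string(u"beta");    // UTF16 by default
    auto u32 = nd4j::NDArrayFactory::string(U"gamma");   // UTF32 by default

    // A 2x2 array of UTF-8 strings; the shape travels as std::vector<Nd4jLong>.
    auto arr = nd4j::NDArrayFactory::string({2, 2},
            std::vector<std::string>{"a", "b", "c", "d"});
}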

View File

@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi
typedef nd4j::ShapeList OpaqueShapeList; typedef nd4j::ShapeList OpaqueShapeList;
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs);
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
@ -1607,6 +1607,7 @@ ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
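For orientation, a hedged sketch of a call into the widened shape-calculation entry point; opHash, inputBuffers, inputShapes and numInputs are placeholders assumed to be prepared by the caller, and the data types travel as plain ints holding DType codes:

// Hypothetical call site; only the trailing dArgs pair is new.
int dArgs[] = { (int) nd4j::DataType::FLOAT32 };
OpaqueShapeList* shapes = calculateOutputShapes2(
        nullptr,                       // extraPointers
        opHash,                        // op hash, assumed prepared
        inputBuffers, inputShapes,     // assumed prepared
        numInputs,
        nullptr, 0,                    // tArgs
        nullptr, 0,                    // iArgs
        nullptr, 0,                    // bArgs
        dArgs, 1);                     // new: requested data types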

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by GS <sgazeos@gmail.com> on 2018-12-20. // Created by GS <sgazeos@gmail.com> on 2018-12-20.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
@ -25,6 +27,9 @@
#include <ShapeUtils.h> #include <ShapeUtils.h>
#include <type_traits> #include <type_traits>
#include <StringUtils.h>
namespace nd4j { namespace nd4j {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -85,45 +90,6 @@ namespace nd4j {
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context);
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context);
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
std::string s(str);
return string(s, context);
}
NDArray* NDArrayFactory::string_(const char *str, nd4j::LaunchContext * context) {
return string_(std::string(str), context);
}
NDArray NDArrayFactory::string(const std::string &str, nd4j::LaunchContext * context) {
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + str.length(), DataType::UTF8, context->getWorkspace(), true);
NDArray res(pBuffer, ShapeDescriptor::scalarDescriptor(DataType::UTF8), context);
int8_t* buffer = reinterpret_cast<int8_t*>(res.getBuffer());
auto offsets = reinterpret_cast<Nd4jLong *>(buffer);
offsets[0] = 0;
offsets[1] = str.length();
auto data = buffer + headerLength;
memcpy(data, str.c_str(), str.length());
res.tickWriteHost();
res.syncToDevice();
return res;
}
NDArray* NDArrayFactory::string_(const std::string &str, nd4j::LaunchContext * context) {
auto res = new NDArray();
*res = NDArrayFactory::string(str, context);
return res;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) { NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) {
@ -551,91 +517,175 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char
template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
    std::vector<const char*> vec(strings);
    return NDArrayFactory::string(order, shape, vec, context);
}

NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
    std::vector<std::string> vec(strings.size());
    int cnt = 0;
    for (auto s:strings)
        vec[cnt++] = std::string(s);
    return NDArrayFactory::string(order, shape, vec, context);
}

NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
    std::vector<std::string> vec(string);
    return NDArrayFactory::string(order, shape, vec, context);
}

NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
    std::vector<const char*> vec(strings);
    return NDArrayFactory::string_(order, shape, vec, context);
}

NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
    std::vector<std::string> vec(strings.size());
    int cnt = 0;
    for (auto s:strings)
        vec[cnt++] = std::string(s);
    return NDArrayFactory::string_(order, shape, vec, context);
}

NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
    std::vector<std::string> vec(string);
    return NDArrayFactory::string_(order, shape, vec, context);
}

NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
    if (context == nullptr)
        context = nd4j::LaunchContext ::defaultContext();

    auto headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size());

    std::vector<Nd4jLong> offsets(string.size() + 1);
    Nd4jLong dataLength = 0;
    for (int e = 0; e < string.size(); e++) {
        offsets[e] = dataLength;
        dataLength += string[e].length();
    }
    offsets[string.size()] = dataLength;

    std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + dataLength, DataType::UTF8, context->getWorkspace(), true);

    NDArray res(pBuffer, ShapeDescriptor(DataType::UTF8, order, shape), context);
    res.setAttached(context->getWorkspace() != nullptr);

    if (res.lengthOf() != string.size())
        throw std::invalid_argument("Number of strings should match length of array");

    memcpy(res.buffer(), offsets.data(), offsets.size() * sizeof(Nd4jLong));

    auto data = static_cast<int8_t*>(res.buffer()) + headerLength;
    int resLen = res.lengthOf();
    for (int e = 0; e < resLen; e++) {
        auto length = offsets[e+1] - offsets[e];
        auto cdata = data + offsets[e];
        memcpy(cdata, string[e].c_str(), string[e].length());
    }

    res.tickWriteHost();
    res.syncToDevice();

    return res;
}

NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
    auto res = new NDArray();
    *res = NDArrayFactory::string(order, shape, string, context);
    return res;
}

/////////////////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
    return NDArray(u16string, dtype, context);
}

/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
    return string_(std::u16string(u16string), dtype, context);
}

/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
    auto res = new NDArray();
    *res = NDArray(u16string, dtype, context);
    return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u16string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u32string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return string_(std::u32string(u32string), dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray(u32string, dtype, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u32string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(str, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return string_(std::string(str), dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray(str, dtype, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(str, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray(shape, std::vector<const char*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray( shape, strings, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray( shape, std::vector<std::string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArrayFactory::string_( shape, std::vector<const char*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
std::vector<std::string> vec(strings.size());
int cnt = 0;
for (auto s:strings)
vec[cnt++] = std::string(s);
return NDArrayFactory::string_( shape, vec, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArrayFactory::string_( shape, std::vector<std::string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray(shape, string, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, std::vector<const char16_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, strings, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, std::vector<std::u16string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<const char16_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
std::vector<std::u16string> vec(strings.size());
int cnt = 0;
for (auto s : strings)
vec[cnt++] = std::u16string(s);
return NDArrayFactory::string_( shape, vec, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<std::u16string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray( shape, string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, std::vector<const char32_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, strings, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray(shape, std::vector<std::u32string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<const char32_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
std::vector<std::u32string> vec(strings.size());
int cnt = 0;
for (auto s : strings)
vec[cnt++] = std::u32string(s);
return NDArrayFactory::string_( shape, vec, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<std::u32string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray( shape, string, dtype, context);
}
} }
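A pattern worth making explicit from the implementations above: the string overloads return an NDArray by value, while the underscore-suffixed string_ overloads return a heap-allocated NDArray* that the caller owns. A minimal sketch, with an illustrative literal:

auto byValue = nd4j::NDArrayFactory::string(u"example");   // value semantics
auto byPtr   = nd4j::NDArrayFactory::string_(u"example");  // heap; caller must delete
delete byPtr;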

View File

@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
delete list; delete list;
} }
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
nd4j::graph::VariableSpace varSpace; nd4j::graph::VariableSpace varSpace;
Context block(2, &varSpace); Context block(2, &varSpace);
nd4j::ShapeList inShapes; nd4j::ShapeList inShapes;
@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
for (int e = 0; e < numBArgs; e++) for (int e = 0; e < numBArgs; e++)
block.getBArguments()->push_back(bArgs[e]); block.getBArguments()->push_back(bArgs[e]);
for (int e = 0; e < numDArgs; e++)
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
for (int e = 0; e < numInputShapes; e++) { for (int e = 0; e < numInputShapes; e++) {
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]); auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
return shapeList; return shapeList;
} }
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
try { try {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
} catch (std::exception &e) { } catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@ -2130,7 +2133,7 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4
biArgs[e] = bArgs[e]; biArgs[e] = bArgs[e];
// hypothetically at this point we have everything filled // hypothetically at this point we have everything filled
auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, isInplace); auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<nd4j::DataType>(), isInplace);
//auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
@ -2788,6 +2791,15 @@ void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, i
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
ptr->setBArguments(arguments, numberOfArguments); ptr->setBArguments(arguments, numberOfArguments);
} }
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
std::vector<nd4j::DataType> dtypes(numberOfArguments);
for (int e = 0; e < numberOfArguments; e++)
dtypes[e] = (nd4j::DataType) arguments[e];
ptr->setDArguments(dtypes);
}
void deleteGraphContext(nd4j::graph::Context* ptr) { void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr; delete ptr;
} }
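A hedged sketch of the matching context-side setter; ctx is an OpaqueContext* assumed to come from the usual context-creation call, and the ints are reinterpreted as nd4j::DataType exactly as the loop above shows:

int dtypes[] = { (int) nd4j::DataType::FLOAT32, (int) nd4j::DataType::INT64 };
setGraphContextDArguments(ctx, dtypes, 2);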

View File

@ -2684,7 +2684,7 @@ const char* getAllCustomOps() {
} }
nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
nd4j::graph::VariableSpace varSpace; nd4j::graph::VariableSpace varSpace;
Context block(2, &varSpace); Context block(2, &varSpace);
nd4j::ShapeList inShapes; nd4j::ShapeList inShapes;
@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
for (int e = 0; e < numBArgs; e++) for (int e = 0; e < numBArgs; e++)
block.getBArguments()->push_back(bArgs[e]); block.getBArguments()->push_back(bArgs[e]);
for (int e = 0; e < numDArgs; e++)
block.getDArguments()->push_back((nd4j::DataType) dArgs[e]);
for (int e = 0; e < numInputShapes; e++) { for (int e = 0; e < numInputShapes; e++) {
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]); auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
return shapeList; return shapeList;
} }
nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
try { try {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash);
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
iArgs, numIArgs, bArgs, numBArgs); iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
} catch (std::exception &e) { } catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@ -2831,7 +2834,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*
// hypothetically at this point we have everything filled // hypothetically at this point we have everything filled
auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, isInplace); auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector<nd4j::DataType>(), isInplace);
//auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
@ -3596,6 +3599,14 @@ void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int n
ptr->setBArguments(arguments, numberOfArguments); ptr->setBArguments(arguments, numberOfArguments);
} }
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
std::vector<nd4j::DataType> dtypes(numberOfArguments);
for (int e = 0; e < numberOfArguments; e++)
dtypes[e] = (nd4j::DataType) arguments[e];
ptr->setDArguments(dtypes);
}
void deleteGraphContext(nd4j::graph::Context* ptr) { void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr; delete ptr;
} }

View File

@ -55,6 +55,7 @@ TESTS="false"
VERBOSE="false" VERBOSE="false"
VERBOSE_ARG="VERBOSE=1" VERBOSE_ARG="VERBOSE=1"
HELPER= HELPER=
CHECK_VECTORIZATION="OFF"
NAME= NAME=
while [[ $# > 0 ]] while [[ $# > 0 ]]
do do
@ -114,6 +115,9 @@ case $key in
NAME="$value" NAME="$value"
shift # past argument shift # past argument
;; ;;
--check-vectorization)
CHECK_VECTORIZATION="ON"
;;
-j) -j)
MAKEJ="$value" MAKEJ="$value"
shift # past argument shift # past argument
@ -528,14 +532,27 @@ echo MINIFIER = "${MINIFIER_ARG}"
echo TESTS = "${TESTS_ARG}" echo TESTS = "${TESTS_ARG}"
echo NAME = "${NAME_ARG}" echo NAME = "${NAME_ARG}"
echo OPENBLAS_PATH = "$OPENBLAS_PATH" echo OPENBLAS_PATH = "$OPENBLAS_PATH"
echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION"
echo HELPERS = "$HELPERS" echo HELPERS = "$HELPERS"
mkbuilddir mkbuilddir
pwd pwd
eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DCHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../..
if [ "$PARALLEL" == "true" ]; then if [ "$PARALLEL" == "true" ]; then
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
fi fi
if [ "$VERBOSE" == "true" ]; then if [ "$VERBOSE" == "true" ]; then
MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG" MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG"
fi fi
if [ "$CHECK_VECTORIZATION" == "ON" ]; then
if [ "$MAKE_COMMAND" == "make" ]; then
MAKE_ARGUMENTS="$MAKE_ARGUMENTS --output-sync=target"
fi
exec 3>&1
eval $MAKE_COMMAND $MAKE_ARGUMENTS 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../..
exec 3>&-
else
eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../.. eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../..
fi
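For context on how the new branch is meant to be used (inferred from the additions above; the exact invocation is an assumption): passing --check-vectorization to this build script flips CHECK_VECTORIZATION to ON, which both forwards -DCHECK_VECTORIZATION=ON to CMake and pipes the build output through auto_vectorization/auto_vect.py. The added --output-sync=target make argument (honored by GNU make) keeps each target's output contiguous so the Python script can parse it reliably; python3 must be on the PATH for this branch to work.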

View File

@ -95,6 +95,10 @@ namespace nd4j {
template<typename T> template<typename T>
// struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; }; // struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; }; struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };
template<typename T>
struct scalarTypesForExecution { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<Nd4jLong, T>::value || std::is_same<int, T>::value || std::is_same<bool, T>::value; };
}; };
@ -118,7 +122,7 @@ namespace nd4j {
} }
FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) { FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) {
return dataType == nd4j::DataType::UTF8; return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32;
} }
FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) { FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) {
@ -366,6 +370,10 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("UINT64"); return std::string("UINT64");
case UTF8: case UTF8:
return std::string("UTF8"); return std::string("UTF8");
case UTF16:
return std::string("UTF16");
case UTF32:
return std::string("UTF32");
default: default:
throw std::runtime_error("Unknown data type used"); throw std::runtime_error("Unknown data type used");
} }
@ -427,6 +435,8 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
case nd4j::DataType::UINT16: return (size_t) 2; case nd4j::DataType::UINT16: return (size_t) 2;
case nd4j::DataType::UTF8: case nd4j::DataType::UTF8:
case nd4j::DataType::UTF16:
case nd4j::DataType::UTF32:
case nd4j::DataType::INT32: case nd4j::DataType::INT32:
case nd4j::DataType::UINT32: case nd4j::DataType::UINT32:
case nd4j::DataType::HALF2: case nd4j::DataType::HALF2:
@ -451,6 +461,10 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
return nd4j::DataType::BOOL; return nd4j::DataType::BOOL;
} else if (std::is_same<T, std::string>::value) { } else if (std::is_same<T, std::string>::value) {
return nd4j::DataType::UTF8; return nd4j::DataType::UTF8;
} else if (std::is_same<T, std::u16string>::value) {
return nd4j::DataType::UTF16;
} else if (std::is_same<T, std::u32string>::value) {
return nd4j::DataType::UTF32;
} else if (std::is_same<T, float>::value) { } else if (std::is_same<T, float>::value) {
return nd4j::DataType::FLOAT32; return nd4j::DataType::FLOAT32;
} else if (std::is_same<T, float16>::value) { } else if (std::is_same<T, float16>::value) {
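A small sketch of the widened string-type handling; isS and asString are taken from the hunks above, while the include path is an assumption:

#include <cassert>
#include <array/DataTypeUtils.h>   // include path assumed

int main() {
    // isS() now reports all three string kinds as "string".
    assert(nd4j::DataTypeUtils::isS(nd4j::DataType::UTF16));
    assert(nd4j::DataTypeUtils::isS(nd4j::DataType::UTF32));
    // asString() knows the new enum names.
    assert(nd4j::DataTypeUtils::asString(nd4j::DataType::UTF16) == "UTF16");
    return 0;
}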

View File

@ -158,7 +158,7 @@ namespace nd4j {
iargs.push_back(_axis); iargs.push_back(_axis);
auto result = op.execute(inputs, {}, {}, {}); auto result = op.evaluate(inputs);
auto array = new NDArray(result->at(0)->dup()); auto array = new NDArray(result->at(0)->dup());
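The evaluate() form above replaces the four-argument execute call and bundles outputs into a result set addressed with at(); a hedged sketch of the new call shape, with a placeholder op and input that are not taken from this diff:

// Assumed: nd4j::ops::transpose and the create<float> factory seen earlier.
auto input = nd4j::NDArrayFactory::create<float>('c', {2, 3});
nd4j::ops::transpose op;
auto result = op.evaluate({&input});   // outputs arrive in a result set
auto transposed = result->at(0);       // addressed positionally, as above
delete result;                         // caller owns the result set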

View File

@ -197,10 +197,12 @@ namespace nd4j {
void setTArguments(double *arguments, int numberOfArguments); void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments);
void setBArguments(bool *arguments, int numberOfArguments); void setBArguments(bool *arguments, int numberOfArguments);
void setDArguments(nd4j::DataType *arguments, int numberOfArguments);
void setTArguments(const std::vector<double> &tArgs); void setTArguments(const std::vector<double> &tArgs);
void setIArguments(const std::vector<Nd4jLong> &tArgs); void setIArguments(const std::vector<Nd4jLong> &tArgs);
void setBArguments(const std::vector<bool> &tArgs); void setBArguments(const std::vector<bool> &tArgs);
void setDArguments(const std::vector<nd4j::DataType> &dArgs);
void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);
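A minimal sketch of the new vector overload, mirroring the Context/VariableSpace pattern visible in the _calculateOutputShapes helpers elsewhere in this diff; the node id and data types are illustrative:

nd4j::graph::VariableSpace varSpace;
nd4j::graph::Context block(2, &varSpace);
block.setDArguments({nd4j::DataType::FLOAT32, nd4j::DataType::INT32});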

View File

@ -47,6 +47,9 @@ namespace nd4j {
std::vector<int> _iArgs; std::vector<int> _iArgs;
std::vector<bool> _bArgs; std::vector<bool> _bArgs;
std::vector<int> _axis; std::vector<int> _axis;
std::vector<nd4j::DataType> _dArgs;
// TODO: remove this field
nd4j::DataType _dataType = nd4j::DataType::FLOAT32; nd4j::DataType _dataType = nd4j::DataType::FLOAT32;
bool _isInplace; bool _isInplace;
@ -93,6 +96,7 @@ namespace nd4j {
std::vector<double>* getTArguments(); std::vector<double>* getTArguments();
std::vector<int>* getIArguments(); std::vector<int>* getIArguments();
std::vector<bool>* getBArguments(); std::vector<bool>* getBArguments();
std::vector<nd4j::DataType>* getDArguments();
std::vector<int>* getAxis(); std::vector<int>* getAxis();
samediff::Engine engine(); samediff::Engine engine();
@ -100,6 +104,7 @@ namespace nd4j {
size_t numT(); size_t numT();
size_t numI(); size_t numI();
size_t numB(); size_t numB();
size_t numD();
std::pair<int, int>* input(int idx); std::pair<int, int>* input(int idx);

View File

@ -38,7 +38,9 @@ namespace nd4j {
class ND4J_EXPORT Node { class ND4J_EXPORT Node {
protected: protected:
// TODO: this field must be removed
nd4j::DataType _dataType; nd4j::DataType _dataType;
OpType _opType; OpType _opType;
ContextPrototype* _protoContext = nullptr; ContextPrototype* _protoContext = nullptr;
Nd4jLong _opNum; Nd4jLong _opNum;
@ -61,6 +63,7 @@ namespace nd4j {
// optional scalar. used in scalar ops and in summary stats // optional scalar. used in scalar ops and in summary stats
// TODO: this field must be removed
NDArray _scalar; NDArray _scalar;
bool _hasExternalOutputs; bool _hasExternalOutputs;
@ -87,15 +90,15 @@ namespace nd4j {
int _scope_id = 0; int _scope_id = 0;
std::string _scope_name; std::string _scope_name;
// TODO: these 3 fields should be removed
int _rewindNode = -1; int _rewindNode = -1;
std::pair<int, int> _rewindLayer = {-1, -1}; std::pair<int, int> _rewindLayer = {-1, -1};
Nd4jLong _frameId = -1; Nd4jLong _frameId = -1;
public: public:
Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {}); explicit Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list<int> input = {}, std::initializer_list<int> output = {}, std::initializer_list<int> dimensions = {}, float scalar = 0.0f, std::initializer_list<double> tArgs = {}, std::initializer_list<int> iArgs = {});
Node(const nd4j::graph::FlatNode *node); explicit Node(const nd4j::graph::FlatNode *node);
~Node(); ~Node();
bool equals(Node *other); bool equals(Node *other);

View File

@ -60,11 +60,13 @@ enum DType {
DType_QINT16 = 16, DType_QINT16 = 16,
DType_BFLOAT16 = 17, DType_BFLOAT16 = 17,
DType_UTF8 = 50, DType_UTF8 = 50,
DType_UTF16 = 51,
DType_UTF32 = 52,
DType_MIN = DType_INHERIT, DType_MIN = DType_INHERIT,
DType_MAX = DType_UTF8 DType_MAX = DType_UTF32
}; };
inline const DType (&EnumValuesDType())[19] { inline const DType (&EnumValuesDType())[21] {
static const DType values[] = { static const DType values[] = {
DType_INHERIT, DType_INHERIT,
DType_BOOL, DType_BOOL,
@ -84,7 +86,9 @@ inline const DType (&EnumValuesDType())[19] {
DType_QINT8, DType_QINT8,
DType_QINT16, DType_QINT16,
DType_BFLOAT16, DType_BFLOAT16,
DType_UTF8 DType_UTF8,
DType_UTF16,
DType_UTF32
}; };
return values; return values;
} }
@ -142,6 +146,8 @@ inline const char * const *EnumNamesDType() {
"", "",
"", "",
"UTF8", "UTF8",
"UTF16",
"UTF32",
nullptr nullptr
}; };
return names; return names;

View File

@ -42,7 +42,9 @@ nd4j.graph.DType = {
QINT8: 15, QINT8: 15,
QINT16: 16, QINT16: 16,
BFLOAT16: 17, BFLOAT16: 17,
UTF8: 50 UTF8: 50,
UTF16: 51,
UTF32: 52
}; };
/** /**

View File

@ -26,6 +26,8 @@ public enum DType : sbyte
QINT16 = 16, QINT16 = 16,
BFLOAT16 = 17, BFLOAT16 = 17,
UTF8 = 50, UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
}; };

View File

@ -23,8 +23,10 @@ public final class DType {
public static final byte QINT16 = 16; public static final byte QINT16 = 16;
public static final byte BFLOAT16 = 17; public static final byte BFLOAT16 = 17;
public static final byte UTF8 = 50; public static final byte UTF8 = 50;
public static final byte UTF16 = 51;
public static final byte UTF32 = 52;
public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", }; public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };
public static String name(int e) { return names[e]; } public static String name(int e) { return names[e]; }
} }

View File

@ -22,4 +22,6 @@ class DType(object):
QINT16 = 16 QINT16 = 16
BFLOAT16 = 17 BFLOAT16 = 17
UTF8 = 50 UTF8 = 50
UTF16 = 51
UTF32 = 52

View File

@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } } public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } } public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; }
public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } }
#if ENABLE_SPAN_T
public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); }
#else
public ArraySegment<byte>? GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); }
#endif
public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder, public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
int id = 0, int id = 0,
@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>), Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
VectorOffset controlDepsOffset = default(VectorOffset), VectorOffset controlDepsOffset = default(VectorOffset),
VectorOffset varControlDepsOffset = default(VectorOffset), VectorOffset varControlDepsOffset = default(VectorOffset),
-      VectorOffset controlDepForOffset = default(VectorOffset)) {
-    builder.StartObject(22);
+      VectorOffset controlDepForOffset = default(VectorOffset),
+      VectorOffset extraTypesOffset = default(VectorOffset)) {
+    builder.StartObject(23);
     FlatNode.AddOpNum(builder, opNum);
+    FlatNode.AddExtraTypes(builder, extraTypesOffset);
     FlatNode.AddControlDepFor(builder, controlDepForOffset);
     FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
     FlatNode.AddControlDeps(builder, controlDepsOffset);
@@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject
     return FlatNode.EndFlatNode(builder);
   }
-  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
+  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); }
   public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
   public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
   public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
@@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject
   public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
   public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
   public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); }
+  public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); }
+  public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
+  public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
   public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
     int o = builder.EndObject();
     return new Offset<FlatNode>(o);

View File

@@ -72,6 +72,10 @@ public final class FlatNode extends Table {
   public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
   public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
   public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
+  public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; }
+  public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; }
+  public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); }
+  public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); }
   public static int createFlatNode(FlatBufferBuilder builder,
       int id,
@@ -95,9 +99,11 @@ public final class FlatNode extends Table {
       int scalarOffset,
       int controlDepsOffset,
       int varControlDepsOffset,
-      int controlDepForOffset) {
-    builder.startObject(22);
+      int controlDepForOffset,
+      int extraTypesOffset) {
+    builder.startObject(23);
     FlatNode.addOpNum(builder, opNum);
+    FlatNode.addExtraTypes(builder, extraTypesOffset);
     FlatNode.addControlDepFor(builder, controlDepForOffset);
     FlatNode.addVarControlDeps(builder, varControlDepsOffset);
     FlatNode.addControlDeps(builder, controlDepsOffset);
@@ -122,7 +128,7 @@ public final class FlatNode extends Table {
     return FlatNode.endFlatNode(builder);
   }
-  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
+  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); }
   public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@@ -171,6 +177,9 @@ public final class FlatNode extends Table {
   public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
   public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
   public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); }
+  public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
+  public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
   public static int endFlatNode(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;

View File

@@ -339,7 +339,29 @@ class FlatNode(object):
            return self._tab.VectorLen(o)
        return 0

-def FlatNodeStart(builder): builder.StartObject(22)
+    # FlatNode
+    def ExtraTypes(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
+        return 0
+
+    # FlatNode
+    def ExtraTypesAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
+        return 0
+
+    # FlatNode
+    def ExtraTypesLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def FlatNodeStart(builder): builder.StartObject(23)
 def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
 def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
 def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
@@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR
 def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
 def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
 def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0)
+def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
 def FlatNodeEnd(builder): return builder.EndObject()

View File

@@ -26,7 +26,7 @@ public struct UIVariable : IFlatbufferObject
 #endif
   public byte[] GetNameArray() { return __p.__vector_as_array<byte>(6); }
   public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
-  public DataType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } }
+  public DType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } }
   public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; }
   public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? __p.__vector_len(o) : 0; } }
 #if ENABLE_SPAN_T
@@ -70,7 +70,7 @@ public struct UIVariable : IFlatbufferObject
       Offset<IntPair> idOffset = default(Offset<IntPair>),
       StringOffset nameOffset = default(StringOffset),
       VarType type = VarType.VARIABLE,
-      DataType datatype = DataType.INHERIT,
+      DType datatype = DType.INHERIT,
       VectorOffset shapeOffset = default(VectorOffset),
       VectorOffset controlDepsOffset = default(VectorOffset),
       StringOffset outputOfOpOffset = default(StringOffset),
@@ -101,7 +101,7 @@ public struct UIVariable : IFlatbufferObject
   public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
   public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
   public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); }
-  public static void AddDatatype(FlatBufferBuilder builder, DataType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
+  public static void AddDatatype(FlatBufferBuilder builder, DType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); }
   public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); }
   public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); }
   public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); }

View File

@@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
    VT_SCALAR = 40,
    VT_CONTROLDEPS = 42,
    VT_VARCONTROLDEPS = 44,
-    VT_CONTROLDEPFOR = 46
+    VT_CONTROLDEPFOR = 46,
+    VT_EXTRATYPES = 48
  };
  int32_t id() const {
    return GetField<int32_t>(VT_ID, 0);
@@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
  }
+  const flatbuffers::Vector<int8_t> *extraTypes() const {
+    return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES);
+  }
  bool Verify(flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           VerifyField<int32_t>(verifier, VT_ID) &&
@@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
           VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
           verifier.VerifyVector(controlDepFor()) &&
           verifier.VerifyVectorOfStrings(controlDepFor()) &&
+           VerifyOffset(verifier, VT_EXTRATYPES) &&
+           verifier.VerifyVector(extraTypes()) &&
           verifier.EndTable();
  }
};
@@ -226,6 +232,9 @@ struct FlatNodeBuilder {
  void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
    fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
  }
+  void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) {
+    fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes);
+  }
  explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
        : fbb_(_fbb) {
    start_ = fbb_.StartTable();
@@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
    flatbuffers::Offset<FlatArray> scalar = 0,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
-    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) {
  FlatNodeBuilder builder_(_fbb);
  builder_.add_opNum(opNum);
+  builder_.add_extraTypes(extraTypes);
  builder_.add_controlDepFor(controlDepFor);
  builder_.add_varControlDeps(varControlDeps);
  builder_.add_controlDeps(controlDeps);
@@ -311,7 +322,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
    flatbuffers::Offset<FlatArray> scalar = 0,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
-    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr,
+    const std::vector<int8_t> *extraTypes = nullptr) {
  return nd4j::graph::CreateFlatNode(
      _fbb,
      id,
@@ -335,7 +347,8 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
      scalar,
      controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
      varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
-      controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
+      controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0,
+      extraTypes ? _fbb.CreateVector<int8_t>(*extraTypes) : 0);
}
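As a rough illustration of the new field (a hedged sketch, not code from this commit; only the generated names above are taken from it), a C++ caller could attach DArgs to a node like so:

    flatbuffers::FlatBufferBuilder fbb;
    // one sbyte per DType value; DType_INHERIT is the generated enum member seen elsewhere in this diff
    std::vector<int8_t> dtypes = { static_cast<int8_t>(nd4j::graph::DType_INHERIT) };
    auto extraTypes = fbb.CreateVector<int8_t>(dtypes);
    nd4j::graph::FlatNodeBuilder builder_(fbb);
    builder_.add_extraTypes(extraTypes);   // lands in slot VT_EXTRATYPES = 48
    auto node = builder_.Finish();         // Finish() is standard flatbuffers codegen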

inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {

View File

@@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};

+/**
+ * @param {number} index
+ * @returns {nd4j.graph.DType}
+ */
+nd4j.graph.FlatNode.prototype.extraTypes = function(index) {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0);
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatNode.prototype.extraTypesLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @returns {Int8Array}
+ */
+nd4j.graph.FlatNode.prototype.extraTypesArray = function() {
+  var offset = this.bb.__offset(this.bb_pos, 48);
+  return offset ? new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null;
+};
+
/**
 * @param {flatbuffers.Builder} builder
 */
nd4j.graph.FlatNode.startFlatNode = function(builder) {
-  builder.startObject(22);
+  builder.startObject(23);
};

/**
@@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
  builder.startVector(4, numElems, 4);
};

+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} extraTypesOffset
+ */
+nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) {
+  builder.addFieldOffset(22, extraTypesOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<nd4j.graph.DType>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) {
+  builder.startVector(1, data.length, 1);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addInt8(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) {
+  builder.startVector(1, numElems, 1);
+};
+
/**
 * @param {flatbuffers.Builder} builder
 * @returns {flatbuffers.Offset}

View File

@@ -266,8 +266,8 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
  VarType type() const {
    return static_cast<VarType>(GetField<int8_t>(VT_TYPE, 0));
  }
-  DataType datatype() const {
-    return static_cast<DataType>(GetField<int8_t>(VT_DATATYPE, 0));
+  DType datatype() const {
+    return static_cast<DType>(GetField<int8_t>(VT_DATATYPE, 0));
  }
  const flatbuffers::Vector<int64_t> *shape() const {
    return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_SHAPE);
@@ -342,7 +342,7 @@ struct UIVariableBuilder {
  void add_type(VarType type) {
    fbb_.AddElement<int8_t>(UIVariable::VT_TYPE, static_cast<int8_t>(type), 0);
  }
-  void add_datatype(DataType datatype) {
+  void add_datatype(DType datatype) {
    fbb_.AddElement<int8_t>(UIVariable::VT_DATATYPE, static_cast<int8_t>(datatype), 0);
  }
  void add_shape(flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape) {
@@ -389,7 +389,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariable(
    flatbuffers::Offset<IntPair> id = 0,
    flatbuffers::Offset<flatbuffers::String> name = 0,
    VarType type = VarType_VARIABLE,
-    DataType datatype = DataType_INHERIT,
+    DType datatype = DType_INHERIT,
    flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
    flatbuffers::Offset<flatbuffers::String> outputOfOp = 0,
@@ -421,7 +421,7 @@ inline flatbuffers::Offset<UIVariable> CreateUIVariableDirect(
    flatbuffers::Offset<IntPair> id = 0,
    const char *name = nullptr,
    VarType type = VarType_VARIABLE,
-    DataType datatype = DataType_INHERIT,
+    DType datatype = DType_INHERIT,
    const std::vector<int64_t> *shape = nullptr,
    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
    const char *outputOfOp = nullptr,

View File

@@ -503,11 +503,11 @@ nd4j.graph.UIVariable.prototype.type = function() {
};

/**
- * @returns {nd4j.graph.DataType}
+ * @returns {nd4j.graph.DType}
 */
nd4j.graph.UIVariable.prototype.datatype = function() {
  var offset = this.bb.__offset(this.bb_pos, 10);
-  return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT;
+  return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT;
};

/**
@@ -668,10 +668,10 @@ nd4j.graph.UIVariable.addType = function(builder, type) {
/**
 * @param {flatbuffers.Builder} builder
- * @param {nd4j.graph.DataType} datatype
+ * @param {nd4j.graph.DType} datatype
 */
nd4j.graph.UIVariable.addDatatype = function(builder, datatype) {
-  builder.addFieldInt8(3, datatype, nd4j.graph.DataType.INHERIT);
+  builder.addFieldInt8(3, datatype, nd4j.graph.DType.INHERIT);
};

/**

View File

@@ -551,6 +551,18 @@ namespace nd4j {
        bool Context::isInference() {
            return _execMode == samediff::ExecutionMode::MODE_INFERENCE;
        }
+
+        void Context::setDArguments(nd4j::DataType *arguments, int numberOfArguments) {
+            _dArgs.clear();
+            for (int e = 0; e < numberOfArguments; e++)
+                _dArgs.emplace_back(arguments[e]);
+        }
+
+        void Context::setDArguments(const std::vector<nd4j::DataType> &dArgs) {
+            _dArgs.clear();
+            for (auto d : dArgs)
+                _dArgs.emplace_back(d);
+        }
    }
}
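A hedged usage sketch of the new setters (a Context instance named ctx and the nd4j::DataType::FLOAT32 member are assumptions; neither appears in this hunk):

    std::vector<nd4j::DataType> dArgs = { nd4j::DataType::FLOAT32 };
    ctx.setDArguments(dArgs);     // replaces any previous _dArgs contents
    auto n = ctx.numD();          // == 1, via ContextPrototype::numD() below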

View File

@@ -173,5 +173,13 @@ namespace nd4j {
            return clone;
        }
+
+        std::vector<nd4j::DataType> *ContextPrototype::getDArguments() {
+            return &_dArgs;
+        }
+
+        size_t ContextPrototype::numD() {
+            return _dArgs.size();
+        }
    }
}

View File

@@ -49,11 +49,10 @@ namespace nd4j {
                delete[] newShape;
                return NDArrayFactory::empty_(dtype, nullptr);
            }

-            // TODO fix UTF16 and UTF32
            if (dtype == UTF8) {
                bool isBe = BitwiseUtils::isBE();
                bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE);
-                auto order = shape::order(newShape);

                std::vector<std::string> substrings(length);
                std::vector<Nd4jLong> shapeVector(rank);
@@ -88,8 +87,8 @@ namespace nd4j {
                delete[] offsets;
                delete[] newShape;

-                return NDArrayFactory::string_(order, shapeVector, substrings);
+                // string order always 'c'
+                return NDArrayFactory::string_(shapeVector, substrings);
            }

View File

@@ -587,6 +587,12 @@ namespace nd4j {
                    block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
                }

+                if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+                    for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+                        block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+                    }
+                }
+
                this->setContextPrototype(block);
                this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
                block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
@@ -618,6 +624,12 @@ namespace nd4j {
                    block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
                }

+                if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+                    for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+                        block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+                    }
+                }
+
                this->setContextPrototype(block);
                this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
@@ -652,6 +664,12 @@ namespace nd4j {
                    block->getBArguments()->push_back(node->extraBools()->Get(e));
                }

+                if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
+                    for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
+                        block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e));
+                    }
+                }
+
                for (auto v: _dimensions)
                    block->getAxis()->emplace_back(v);

View File

@@ -58,6 +58,8 @@ table FlatNode {
    varControlDeps:[string];
    controlDepFor:[string];
+
+    // DArgs
+    extraTypes:[DType];
}

root_type FlatNode;

View File

@@ -171,7 +171,10 @@ namespace nd4j {
             * @param numStrings
             * @return
             */
-            static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings);
+            static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
+                // we store +1 offset
+                return (numStrings + 1) * sizeof(Nd4jLong);
+            }
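For example, numStrings = 3 yields (3 + 1) * sizeof(Nd4jLong) = 4 * 8 = 32 header bytes: one offset per string plus the trailing offset that records the total number of data bytes used.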
            /*
             * check whether arr1/arr2 is sub-array of arr2/arr1,

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
//
// Created by raver119 on 20/04/18.
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//

#ifndef LIBND4J_STRINGUTILS_H
@@ -27,6 +29,7 @@
#include <sstream>
#include <vector>
#include <NDArray.h>
+#include <unicode.h>

namespace nd4j {
    class ND4J_EXPORT StringUtils {
@@ -85,6 +88,55 @@ namespace nd4j {
         * @return
         */
        static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
+
+        /**
+         * This method converts a u8 string to u16
+         * @param const reference to the input string
+         * @param reference to the output u16string
+         * @return boolean status
+         */
+        static bool u8StringToU16String(const std::string& u8, std::u16string& u16);
+
+        /**
+         * This method converts a u8 string to u32
+         * @param const reference to the input string
+         * @param reference to the output u32string
+         * @return boolean status
+         */
+        static bool u8StringToU32String(const std::string& u8, std::u32string& u32);
+
+        /**
+         * This method converts a u16 string to u32
+         * @param const reference to the input u16string
+         * @param reference to the output u32string
+         * @return boolean status
+         */
+        static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32);
+
+        /**
+         * This method converts a u16 string to a u8 string
+         * @param const reference to the input u16string
+         * @param reference to the output string
+         * @return boolean status
+         */
+        static bool u16StringToU8String(const std::u16string& u16, std::string& u8);
+
+        /**
+         * This method converts a u32 string to a u16 string
+         * @param const reference to the input u32string
+         * @param reference to the output u16string
+         * @return boolean status
+         */
+        static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16);
+
+        /**
+         * This method converts a u32 string to a u8 string
+         * @param const reference to the input u32string
+         * @param reference to the output string
+         * @return boolean status
+         */
+        static bool u32StringToU8String(const std::u32string& u32, std::string& u8);
    };
}
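A hedged usage sketch of these helpers (the input literal is illustrative only):

    std::string u8 = u8"Grüße";   // UTF-8 input
    std::u16string u16;
    if (nd4j::StringUtils::u8StringToU16String(u8, u16)) {
        // u16 now holds the UTF-16 re-encoding; false is returned only for empty input
    }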

View File

@@ -40,7 +40,7 @@ namespace nd4j {
        NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);   //[nHeads*hS, batch*timeSteps]

        nd4j::ops::matmul mmul;
-        mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
+        mmul.execute({&projectionPrep, &inputPrep}, {&projected});

        projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
        projected.permutei({2, 0, 1, 3});   //[minibatch, numHeads, projectedSize, seqLength]
@@ -66,7 +66,7 @@ namespace nd4j {
        nd4j::ops::matmul_bp mmulBp;
        NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
        NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
-        mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
+        mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector<NDArray*>{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});

        dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
        dLdProjectionMatrix->assign(dLdProjectionPrep);

View File

@@ -1019,15 +1019,6 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
    return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
}

-Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
-    // we store +1 offset
-    auto base = numStrings + 1;
-
-    // since we return number of bytes...
-    return base * sizeof(Nd4jLong);
-}
-
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
//
// Created by raver119 on 20/04/18.
+// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//

#include <helpers/StringUtils.h>
@@ -49,13 +51,8 @@ namespace nd4j {
        if (!array.isS())
            throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());

-        uint64_t result = 0;
-
-        // our buffer stores offsets, and the last value is basically number of bytes used
        auto buffer = array.bufferAsT<Nd4jLong>();
-        result = buffer[array.lengthOf()];
-
-        return result;
+        return buffer[array.lengthOf()];
    }

    std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
@@ -73,4 +70,89 @@ namespace nd4j {
        return output;
    }

+    bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) {
+        if (u8.empty())
+            return false;
+
+        u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t));
+        if (u8.size() == u16.size())
+            u16.assign(u8.begin(), u8.end());
+        else
+            return unicode::utf8to16(u8.data(), &u16[0], u8.size());
+
+        return true;
+    }
+
+    bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) {
+        if (u8.empty())
+            return false;
+
+        u32.resize(unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t));
+        if (u8.size() == u32.size())
+            u32.assign(u8.begin(), u8.end());
+        else
+            return unicode::utf8to32(u8.data(), &u32[0], u8.size());
+
+        return true;
+    }
+
+    bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) {
+        if (u16.empty())
+            return false;
+
+        u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t));
+        if (u16.size() == u32.size())
+            u32.assign(u16.begin(), u16.end());
+        else
+            return unicode::utf16to32(u16.data(), &u32[0], u16.size());
+
+        return true;
+    }
+
+    bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) {
+        if (u16.empty())
+            return false;
+
+        u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size()));
+        if (u16.size() == u8.size())
+            u8.assign(u16.begin(), u16.end());
+        else
+            return unicode::utf16to8(u16.data(), &u8[0], u16.size());
+
+        return true;
+    }
+
+    bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) {
+        if (u32.empty())
+            return false;
+
+        u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t));
+        if (u32.size() == u16.size())
+            u16.assign(u32.begin(), u32.end());
+        else
+            return unicode::utf32to16(u32.data(), &u16[0], u32.size());
+
+        return true;
+    }
+
+    bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) {
+        if (u32.empty())
+            return false;
+
+        u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size()));
+        if (u32.size() == u8.size())
+            u8.assign(u32.begin(), u32.end());
+        else
+            return unicode::utf32to8(u32.data(), &u8[0], u32.size());
+
+        return true;
+    }
}
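Note the fast path shared by all six converters: the target is resized to the exact length computed by the matching unicode::offset* helper, and when that length equals the source length every symbol is a single code unit (plain ASCII), so a direct assign suffices and the full transcoding routine is skipped.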

View File

@ -0,0 +1,456 @@
/*******************************************************************************
* Copyright (c) 2015-2020 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
#include <unicode.h>
namespace nd4j {
namespace unicode {
constexpr uint32_t ONEBYTEBOUND = 0x00000080;
constexpr uint32_t TWOBYTEBOUND = 0x00000800;
constexpr uint32_t THREEBYTEBOUND = 0x00010000;
constexpr uint16_t HIGHBYTEMIN = 0xd800u;
constexpr uint16_t HIGHBYTEMAX = 0xdbffu;
constexpr uint16_t TRAILBYTEMIN = 0xdc00u;
constexpr uint16_t TRAILBYTEMAX = 0xdfffu;
constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10);
constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN;
// Maximum valid value for a Unicode code point
constexpr uint32_t CODEPOINTMAX = 0x0010ffffu;
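// Worked example of the surrogate math below: for U+1F600,
// high = (0x1F600 >> 10) + HIGHBYTEOFFSET = 0x7D + 0xD7C0 = 0xD83D,
// low  = (0x1F600 & 0x3FF) + TRAILBYTEMIN = 0x200 + 0xDC00 = 0xDE00,
// and surrogateU32(0xD83D, 0xDE00) = (0xD83D << 10) + 0xDE00 - 0x35FDC00 = 0x1F600 recovers the code point.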
template<typename T>
FORCEINLINE uint8_t castToU8(const T cp) {
return static_cast<uint8_t>(0xff & cp);
}
template<typename T>
FORCEINLINE uint16_t castToU16(const T cp) {
return static_cast<uint16_t>(0xffff & cp);
}
template<typename T>
FORCEINLINE uint32_t castToU32(const T cp) {
return static_cast<uint32_t>(0xffffff & cp);
}
template<typename T>
FORCEINLINE bool isTrail(const T cp) {
return ((castToU8(cp) >> 6) == 0x2);
}
template <typename T>
FORCEINLINE bool isHighSurrogate(const T cp) {
return (cp & 0xfffffc00) == 0xd800;
}
template <typename T>
bool isLowSurrogate(const T cp) {
return (cp & 0xfffffc00) == 0xdc00;
}
template <typename T>
FORCEINLINE bool isLeadSurrogate(const T cp) {
return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX);
}
template <typename T>
FORCEINLINE bool isTrailSurrogate(const T cp) {
return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX);
}
template <typename T>
FORCEINLINE bool isSurrogateU8(const T cp) {
return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX);
}
template <typename T>
FORCEINLINE bool isSurrogateU16(const T cp) {
return ((cp - 0xd800u) < 2048u);
}
template <typename T>
FORCEINLINE bool isSymbolU8Valid(const T cp) {
return (cp <= CODEPOINTMAX && !isSurrogateU8(cp));
}
template <typename T>
FORCEINLINE bool isSymbolValid(const T cp) {
return (cp <= CODEPOINTMAX);
}
template <typename T>
FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) {
return (high << 10) + low - 0x35fdc00;
}
template <typename T>
Nd4jLong symbolLength(const T* it) {
uint8_t lead = castToU8(*it);
if (lead < 0x80)
return 1;
else if ((lead >> 5) == 0x6)
return 2;
else if ((lead >> 4) == 0xe)
return 3;
else if ((lead >> 3) == 0x1e)
return 4;
else
return 0;
}
template <typename T>
Nd4jLong symbolLength32(const T* it) {
auto lead = castToU32(*it);
if (lead < ONEBYTEBOUND)
return 1;
else if (lead < TWOBYTEBOUND)
return 2;
else if (lead < THREEBYTEBOUND)
return 3;
else if (lead <= CODEPOINTMAX)
return 4;
else
return 0;
}
template <typename T>
Nd4jLong symbolLength16(const T* it) {
uint32_t lead = castToU16(*it);
if (!isLeadSurrogate(lead)) {
if (lead < ONEBYTEBOUND)
return 1;
else if (lead < TWOBYTEBOUND)
return 2;
else if (lead < THREEBYTEBOUND)
return 3;
else
return 0;
}
else {
return 4;
}
}
Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
auto length = symbolLength(it);
it += (length > 0) ? (length - 1) : 0;
count += 1;
}
return static_cast<Nd4jLong>(count * sizeof(char32_t));
}
Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
auto length = symbolLength16(it);
it += (4 == length) ? 2 : 1;
count += 1;
}
return static_cast<Nd4jLong>(count*sizeof(char32_t));
}
Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
auto length = symbolLength(it);
auto step = ((length > 0) ? (length - 1) : 0);
it += step;
count += (4 == length) ? 2 : 1;
}
return static_cast<Nd4jLong>(count*sizeof(char16_t));
}
Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
auto length = symbolLength16(it);
it += (4 == length) ? 2 : 1;
count += length;
}
return static_cast<Nd4jLong>(count);
}
Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
auto length = symbolLength32(it);
count += (4 == length) ? 2 : 1;
}
return static_cast<Nd4jLong>(count*sizeof(char16_t));
}
Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
count += symbolLength32(it);
}
return count;
}
bool isStringValidU8(const void* start, const void* stop) {
for (auto it = static_cast<const int8_t*>(start); it != stop; it++) {
if (!isSymbolU8Valid( castToU8(*it) )) {
return false;
}
}
return true;
}
bool isStringValidU16(const void* start, const void* stop) {
for (auto it = static_cast<const uint16_t*>(start); it != stop; it++) {
if (!isSymbolValid( castToU32(*it) )) {
return false;
}
}
return true;
}
bool isStringValidU32(const void* start, const void* stop) {
for (auto it = static_cast<const uint32_t*>(start); it != stop; it++) {
if (!isSymbolValid( castToU32(*it) )) {
return false;
}
}
return true;
}
void* utf16to8Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<int8_t*>(res);
// result has to be pre-allocated
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
uint32_t cp = castToU16(*it++);
if (!isLeadSurrogate(cp)) {
if (cp < 0x80) { // for one byte
*(result++) = static_cast<uint8_t>(cp);
}
else if (cp < 0x800) { // for two bytes
*(result++) = static_cast<uint8_t>((cp >> 6) | 0xc0);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
else{ // for three bytes
*(result++) = static_cast<uint8_t>((cp >> 12) | 0xe0);
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
}
else {
if (it != end) {
uint32_t trail_surrogate = castToU16(*it++);
if (isTrailSurrogate(trail_surrogate))
cp = (cp << 10) + trail_surrogate + BYTEOFFSET;
}
// for four bytes
*(result++) = static_cast<uint8_t>((cp >> 18) | 0xf0);
*(result++) = static_cast<uint8_t>(((cp >> 12) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
}
return result;
}
void* utf8to16Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint16_t*>(res);
// result has to be pre-allocated
for (auto it = static_cast<const int8_t*>(start); it != end;) {
auto nLength = symbolLength(it);
uint32_t cp = castToU8(*it++);
if (4 != nLength) {
if (2 == nLength) {
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
}
else if (3 == nLength) {
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
cp += (*it++) & 0x3f;
}
*(result++) = static_cast<uint16_t>(cp);
}
else {
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
cp += (castToU8(*it++) << 6) & 0xfff;
cp += (*it++) & 0x3f;
//make a surrogate pair
*(result++) = static_cast<uint16_t>((cp >> 10) + HIGHBYTEOFFSET);
*(result++) = static_cast<uint16_t>((cp & 0x3ff) + TRAILBYTEMIN);
}
}
return result;
}
void* utf32to8Ptr( const void* start, const void* end, void* result) {
auto res = static_cast<uint8_t*>(result);
// result has to be pre-allocated
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
if (*it < 0x80) // for one byte
*(res++) = static_cast<uint8_t>(*it);
else if (*it < 0x800) { // for two bytes
*(res++) = static_cast<uint8_t>((*it >> 6) | 0xc0);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
else if (*it < 0x10000) { // for three bytes
*(res++) = static_cast<uint8_t>((*it >> 12) | 0xe0);
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
else { // for four bytes
*(res++) = static_cast<uint8_t>((*it >> 18) | 0xf0);
*(res++) = static_cast<uint8_t>(((*it >> 12) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
}
return result;
}
void* utf8to32Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint32_t*>(res);
// result has to be pre-allocated
for (auto it = static_cast<const int8_t*>(start); it != end;) {
auto nLength = symbolLength(it);
uint32_t cp = castToU8(*it++);
if (2 == nLength) {
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
}
else if (3 == nLength) {
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
cp += (*it++) & 0x3f;
}
else if (4 == nLength) {
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
cp += (castToU8(*it++) << 6) & 0xfff;
cp += (*it++) & 0x3f;
}
(*result++) = cp;
}
return result;
}
void* utf16to32Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint32_t*>(res);
// result has to be pre-allocated
for (auto it = static_cast<const uint16_t*>(start); it != end; it++) {
uint32_t cpHigh = castToU32(*it);
if (!isSurrogateU16(cpHigh)) {
*result++ = cpHigh;
}
else {
it++;
uint32_t cpLow = castToU32(*it);
if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) {
*result++ = surrogateU32(cpHigh, cpLow);
}
}
}
return result;
}
void* utf32to16Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint16_t*>(res);
// result has to be pre-allocated
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
uint32_t cpHigh = castToU32(*it);
// TODO: check whether this is needed given pre-validation; if so, find out how to validate u16 here
if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) {
// Invalid code point. Replace with sentinel, per Unicode standard:
*result++ = u'\uFFFD';
}
else if (cpHigh < 0x10000UL) { // In the BMP.
*result++ = static_cast<char16_t>(cpHigh);
}
else {
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U);
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U);
}
}
return result;
}
Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) {
return offsetUtf8StringInUtf32(input, static_cast<const int8_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) {
return offsetUtf16StringInUtf32(input, static_cast<const uint16_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) {
return offsetUtf8StringInUtf16(input, static_cast<const int8_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) {
return offsetUtf16StringInUtf8(input, static_cast<const uint16_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) {
return offsetUtf32StringInUtf8(input, static_cast<const uint32_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) {
return offsetUtf32StringInUtf16(input, static_cast<const uint32_t*>(input) + nInputSize);
}
bool utf8to16(const void* input, void* output, uint32_t nInputSize) {
return utf8to16Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
}
bool utf8to32(const void* input, void* output, uint32_t nInputSize) {
return utf8to32Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
}
bool utf16to32(const void* input, void* output, uint32_t nInputSize) {
return utf16to32Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
}
bool utf16to8(const void* input, void* output, uint32_t nInputSize) {
return utf16to8Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
}
bool utf32to16(const void* input, void* output, uint32_t nInputSize) {
return utf32to16Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
}
bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) {
return utf32to8Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
}
}
}
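A hedged round-trip sketch using the helpers above (sizing the output via the offset functions follows the "pre-allocated" comments; the literal is illustrative):

    std::string u8 = "abc";
    auto u16bytes = nd4j::unicode::offsetUtf8StringInUtf16(u8.data(), (uint32_t) u8.size());
    std::u16string u16(u16bytes / sizeof(char16_t), u'\0');
    nd4j::unicode::utf8to16(u8.data(), &u16[0], (uint32_t) u8.size());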

View File

@ -0,0 +1,189 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
#ifndef LIBND4J_UNICODE_H
#define LIBND4J_UNICODE_H
#include <NDArray.h>
namespace nd4j {
namespace unicode {
        /**
         * This method calculates the utf16 offset based on a utf8 string
         * @param const pointer to the utf8 string start point
         * @param const pointer to the end of the string
         * @return offset of utf16
         */
        Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end);

        /**
         * This method calculates the utf8 offset based on a utf16 string
         * @param const pointer to the utf16 string start point
         * @param const pointer to the end of the string
         * @return offset of utf8
         */
        Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end);

        /**
         * This method calculates the utf16 offset based on a utf32 string
         * @param const pointer to the utf32 string start point
         * @param const pointer to the end of the string
         * @return offset of utf16
         */
        Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end);

        /**
         * This method calculates the utf8 offset based on a utf32 string
         * @param const pointer to the utf32 string start point
         * @param const pointer to the end of the string
         * @return offset of utf8
         */
        Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end);
        /*
         * This function checks that every character in the u8 string is valid
         */
        bool isStringValidU8(const void* start, const void* stop);

        /*
         * This function checks that every character in the u16 string is valid
         */
        bool isStringValidU16(const void* start, const void* stop);

        /*
         * This function checks that every character in the u32 string is valid
         */
        bool isStringValidU32(const void* start, const void* stop);
        /**
         * This method counts the utf32 offset for a utf8 string
         * @param const pointer to the utf8 string start point
         * @param size of the string
         * @return offset
         */
        Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize);

        /**
         * This method counts the utf32 offset for a utf8 string
         * @param const pointer to the utf8 string start point
         * @param const end pointer of the utf8 string
         * @return offset
         */
        Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop);

        /**
         * This method counts the utf32 offset for a utf16 string
         * @param const pointer to the utf16 string start point
         * @param size of the string
         * @return offset
         */
        Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize);

        /**
         * This method calculates the utf16 offset for a utf8 string
         * @param const pointer to the utf8 string start point
         * @param size of the string
         * @return offset of utf16
         */
        Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize);

        /**
         * This method calculates the utf8 offset for a utf16 string
         * @param const pointer to the utf16 string start point
         * @param size of the string
         * @return offset of utf8
         */
        Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize);

        /**
         * This method calculates the utf8 offset for a utf32 string
         * @param const pointer to the utf32 string start point
         * @param size of the string
         * @return offset of utf8
         */
        Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize);

        /**
         * This method calculates the utf16 offset for a utf32 string
         * @param const pointer to the utf32 string start point
         * @param size of the string
         * @return offset of utf16
         */
        Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize);
        /**
         * This method converts a utf8 string to a utf16 string
         * @param const pointer to the utf8 string start point
         * @param pointer to the start of the utf16 output
         * @param size of the input utf8 string
         * @return status of the conversion
         */
        bool utf8to16(const void* input, void* output, uint32_t nInputSize);

        /**
         * This method converts a utf8 string to a utf32 string
         * @param const pointer to the utf8 string start point
         * @param pointer to the start of the utf32 output
         * @param size of the input utf8 string
         * @return status of the conversion
         */
        bool utf8to32(const void* input, void* output, uint32_t nInputSize);

        /**
         * This method converts a utf16 string to a utf32 string
         * @param const pointer to the utf16 string start point
         * @param pointer to the start of the utf32 output
         * @param size of the input utf16 string
         * @return status of the conversion
         */
        bool utf16to32(const void* input, void* output, uint32_t nInputSize);

        /**
         * This method converts a utf16 string to a utf8 string
         * @param const pointer to the utf16 string start point
         * @param pointer to the start of the utf8 output
         * @param size of the input utf16 string
         * @return status of the conversion
         */
        bool utf16to8(const void* input, void* output, uint32_t nInputSize);

        /**
         * This method converts a utf32 string to a utf16 string
         * @param const pointer to the utf32 string start point
         * @param pointer to the start of the utf16 output
         * @param size of the input utf32 string
         * @return status of the conversion
         */
        bool utf32to16(const void* input, void* output, uint32_t nInputSize);

        /**
         * This method converts a utf32 string to a utf8 string
         * @param const pointer to the utf32 string start point
         * @param pointer to the start of the utf8 output
         * @param size of the input utf32 string
         * @return status of the conversion
         */
        bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize);
}
}
#endif //LIBND4J_UNICODE_H

View File

@ -29,6 +29,7 @@ using namespace randomOps;
namespace functions { namespace functions {
namespace random { namespace random {
template<typename X> template<typename X>
template<typename OpClass> template<typename OpClass>
void RandomFunction<X>::execTransform(Nd4jPointer state, void RandomFunction<X>::execTransform(Nd4jPointer state,
@ -56,6 +57,19 @@ namespace functions {
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 &&
shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo) ){
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i++) {
z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments);
}
};
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
@ -69,6 +83,7 @@ namespace functions {
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK];
@ -169,6 +184,17 @@ namespace functions {
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
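// same fast path for the unary transform variant: contiguous input and output with matching ordering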
if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)){
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i++) {
z[i] = OpClass::op(x[i], i, length, rng, extraArguments);
}
};
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (uint64_t i = start; i < stop; i += increment) {
@ -179,6 +205,7 @@ namespace functions {
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
else {
uint zShapeInfoCast[MAX_RANK];
@ -208,6 +235,19 @@ namespace functions {
auto length = shape::length(zShapeInfo);
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
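// generator variant: an output with element-wise stride 1 can be filled with a flat SIMD loop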
if(shape::elementWiseStride(zShapeInfo) == 1){
auto func = PRAGMA_THREADS_FOR {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i++) {
z[i] = OpClass::op( i, length, rng, extraArguments);
}
};
samediff::Threads::parallel_for(func, 0, length, 1);
}
else{
nd4j::OmpLaunchHelper info(length);
uint zShapeInfoCast[MAX_RANK];
@ -223,6 +263,7 @@ namespace functions {
samediff::Threads::parallel_for(func, 0, length, 1);
}
}
template<typename X>
void RandomFunction<X>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraArguments) {

View File

@ -1516,7 +1516,9 @@
#define INPUT_LIST(INDEX) reinterpret_cast<nd4j::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
#define D_ARG(INDEX) block.getDArguments()->at(INDEX)
#define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
#define I_ARG(INDEX) INT_ARG(INDEX)
#define T_ARG(INDEX) block.getTArguments()->at(INDEX)
#define B_ARG(INDEX) block.getBArguments()->at(INDEX)
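The new D_ARG macro exposes the DataType arguments this PR threads through Context; several shape functions further down (onehot, range, zeros_as, ones_as) use it to honor a caller-supplied output dtype. A minimal sketch of the pattern (the surrounding op body is illustrative, not from this PR):

// inside a CUSTOM_OP_IMPL or DECLARE_SHAPE_FN body: take the output dtype from
// the first DataType argument when one was passed, otherwise fall back to FLOAT32
auto dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;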

View File

@ -36,9 +36,8 @@ namespace nd4j {
public:
BooleanOp(const char *name, int numInputs, bool scalar);
bool verify(const std::vector<nd4j::NDArray*>& args);
bool verify(nd4j::graph::Context& block);
Nd4jStatus execute(Context* block) override;

View File

@ -169,13 +169,22 @@ namespace nd4j {
*/
virtual Nd4jStatus execute(Context* block);
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false);
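The practical effect of this signature cleanup is visible in the call sites further down in this diff: the heavily overloaded execute(inputs, tArgs, iArgs, bArgs, ...) calls become terse evaluate() calls. A minimal sketch of the new convention (variable names are illustrative; the call shape mirrors the gather call site below):

nd4j::ops::gather op;
auto result = op.evaluate({&input, &indices}, {0});   // inputs, then an initializer_list of scalar args
if (result->status() == ND4J_STATUS_OK) {
    output.assign(result->at(0));
}
delete result;                                        // evaluate() returns a heap-allocated ResultSet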

View File

@ -73,7 +73,7 @@ namespace nd4j {
// at first step we build fwd activation
nd4j::ops::crelu op;
auto tmpResult = op.evaluate({input});
if (tmpResult->status() != ND4J_STATUS_OK)
return tmpResult->status();
@ -84,7 +84,7 @@ namespace nd4j {
helpers::reluDerivative(block.launchContext(), actv, epsilonNext);
// now we split updated array into 2 chunks along last dimension
nd4j::ops::concat_bp opc;
auto dec = opc.evaluate({input, input, actv}, {-1});
if (dec->status() != ND4J_STATUS_OK)
return dec->status();

View File

@ -103,7 +103,7 @@ namespace nd4j {
// if (output->isEmpty())
Nd4jLong width = condition->rankOf();
nd4j::ops::Where op;
std::unique_ptr<ResultSet> res(op.evaluate({condition}));
REQUIRE_OK(res->status());
NDArray* whereTrue = res->at(0);
if (whereTrue->isEmpty())

View File

@ -66,7 +66,7 @@ namespace nd4j {
auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext);
nd4j::ops::floormod op;
std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));
if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);

View File

@ -118,7 +118,7 @@ namespace ops {
DECLARE_TYPES(Pow_bp) {
getOpDescriptor()
->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
->setAllowedOutputTypes({ ALL_FLOATS });
}
}

View File

@ -81,7 +81,7 @@ namespace nd4j {
}
// now that all strings are in a single vector, it is time to fill the buffer
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly

View File

@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
// ***** calculations ***** //
// notations:
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
// stdInv = 1 / (v + eps)^0.5
// N - batch size (product of spatial dimensions)

View File

@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});      // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
if (status != ND4J_STATUS_OK)
return status;

View File

@ -91,7 +91,7 @@ namespace ops {
}
nd4j::ops::softmax softmax;
softmax.execute({weights}, std::vector<NDArray*>{weights}, {}, {-2}, {}, {}, true);
mmul.execute({values, weights}, {output}, {}, {}, {});
@ -189,7 +189,7 @@ namespace ops {
nd4j::ops::matmul_bp mmul_bp;
NDArray dLdw(weights.getShapeInfo(), block.workspace());
mmul_bp.execute({values, &weights, eps}, std::vector<NDArray*>{dLdv, &dLdw}, {}, {}, {});
NDArray dLds(preSoftmax.shapeInfo(), block.workspace());
nd4j::ops::softmax_bp softmax_bp;
@ -198,7 +198,7 @@ namespace ops {
if(normalization)
dLds /= factor;
mmul_bp.execute({keys, queries, &dLds}, std::vector<NDArray*>{dLdk, dLdq}, {}, {1}, {});
return Status::OK();
}

View File

@ -239,7 +239,7 @@ namespace ops {
auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
nd4j::ops::matmul_bp matmulBp;
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector<NDArray*>{&dLdPreWo, dLdWo}, {}, {}, {});
// dLdAttn
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});

View File

@ -31,31 +31,28 @@ namespace ops {
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
auto input = INPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto output = OUTPUT_VARIABLE(0);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
auto pH = INT_ARG(4);
auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int oH = 0;
int oW = 0;
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
}
return Status::OK();
}
DECLARE_SHAPE_FN(avgpool2d_bp) {

View File

@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;       // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3}));     // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3}));     // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]

View File

@ -32,6 +32,7 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
// maxpool2d corresponds to poolingMode=0
CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());

View File

@ -40,7 +40,7 @@ namespace nd4j {
//nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
nd4j::ops::xw_plus_b op;
std::unique_ptr<ResultSet> result(op.evaluate({x, w, b}));
REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data.");
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;

View File

@ -34,7 +34,7 @@ namespace nd4j {
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res;
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false);
if (tZ != &z0) {
delete tZ;
}

View File

@ -112,7 +112,7 @@ namespace ops {
NDArray originalIndices(*indices);      //->ordering(), indices->shapeInfo(), indices->dataType());
originalIndices.linspace(0);
ops::dynamic_partition op;
auto res = op.evaluate({&originalIndices, indices}, {numPartition});
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2);
@ -121,7 +121,7 @@ namespace ops {
partitions[i + numPartition] = gradOutList[i];
}
auto result = stichOp.evaluate(partitions, {numPartition});
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices);

View File

@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
nd4j::ops::gather op;
std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result->at(0));
@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(embedding_lookup) {
for (int e = 1; e < outRank; e++)
shapeInfo[e] = shape::sizeAt(inShapeInfo, e);
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo);
return SHAPELIST(outShapeInfo);
}

View File

@ -32,6 +32,11 @@ namespace ops {
auto finish = INPUT_VARIABLE(1);
auto numOfElements = INPUT_VARIABLE(2);
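// a single requested element would make the step formula below divide by zero, so return the start value directly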
if (numOfElements->e<Nd4jLong>(0) == 1) {
output->assign(start);
return Status::OK();
}
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
return Status::OK();
}

View File

@ -74,6 +74,8 @@ namespace nd4j {
DECLARE_SHAPE_FN(onehot) {
auto inShape = inputShape->at(0);
nd4j::DataType dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32;
int depth = -1;
Nd4jLong axis = -1;
@ -99,7 +101,7 @@ namespace nd4j {
shape.push_back(shape::shapeOf(inShape)[e]);
shape.insert(shape.begin() + axis, depth);
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data());
return SHAPELIST(newShape);
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
output->assign(1);
@ -33,11 +33,21 @@ namespace nd4j {
return Status::OK();
}
DECLARE_SHAPE_FN(ones_as) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
return SHAPELIST(shape);
}
DECLARE_TYPES(ones_as) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(false);
}
}
}

View File

@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) {
const int numIArgs = block.getIArguments()->size();
Nd4jLong steps = 0;
nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT;
if (numInArrs > 0) {
auto isR = INPUT_VARIABLE(0)->isR();
@ -159,6 +159,8 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
if (!block.numD())
dataType = INPUT_VARIABLE(0)->dataType();
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
@ -187,6 +189,8 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
if (!block.numD())
dataType = INPUT_VARIABLE(0)->dataType();
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
if (!block.numD()) {
if (limit > DataTypeUtils::max<int>())
dataType = nd4j::DataType::INT64;
else
dataType = nd4j::DataType::INT32;
}
steps = (limit - start) / delta;
@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
if (!block.numD()) {
if (Environment::getInstance()->precisionBoostAllowed())
dataType = nd4j::DataType::DOUBLE;
else
dataType = Environment::getInstance()->defaultFloatDataType();
}
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
++steps;

View File

@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by GS <sgazeos@gmail.com> at 01/22/2020
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_solve)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/solve.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
bool useAdjoint = false;
if (block.numB() > 0) {
useAdjoint = B_ARG(0);
}
REQUIRE_TRUE(a->rankOf() >= 2, 0, "solve: The rank of the input left tensor should not be less than 2, but %i is given", a->rankOf());
REQUIRE_TRUE(b->rankOf() >= 2, 0, "solve: The rank of the input right tensor should not be less than 2, but %i is given", b->rankOf());
REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimension of the left part should be equal to the second-to-last of the right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
auto input = a;
if (useAdjoint) {
auto adjointA = a->ulike();
helpers::adjointMatrix(block.launchContext(), a, &adjointA);
input = new NDArray(adjointA); //.detach();
}
auto res = helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z);
if (input != a)
delete input;
return Status::OK();
}
DECLARE_SHAPE_FN(solve) {
auto in0 = inputShape->at(0);
auto in1 = inputShape->at(1);
auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
return SHAPELIST(CONSTANT(luShape));
}
DECLARE_TYPES(solve) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS})
->setSameMode(false);
}
}
}
#endif

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) {
auto out = OUTPUT_VARIABLE(0);
out->assign(0); // output is filled by zero by default
@ -35,11 +35,20 @@ namespace nd4j {
DECLARE_SYN(zeroslike, zeros_as);
DECLARE_SYN(zeros_like, zeros_as);
DECLARE_SHAPE_FN(zeros_as) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in);
auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in);
return SHAPELIST(shape);
}
DECLARE_TYPES(zeros_as) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(false);
}
}
}

View File

@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// forward steps
nd4j::ops::dynamic_rnn dynamicRnn;
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
hFW->assign(resultsFW->at(0));       // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW->at(1));
@ -97,17 +97,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// reverse x
nd4j::ops::reverse_sequence reverse;
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
REQUIRE_TRUE(resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0);
// backward steps
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0);     // [time x bS x numUnitsBW] or [bS x time x numUnitsBW]
hBWFinal->assign(resultsBW->at(1));
// reverse hBWtemp
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0));
delete resultsOut;

View File

@ -48,7 +48,7 @@ namespace ops {
auto conv = ArrayUtils::toLongVector(*block.getIArguments());
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv);
return SHAPELIST(newShape);
}

View File

@ -487,7 +487,7 @@ namespace nd4j {
*
*/
#if NOT_EXCLUDED(OP_zeros_as)
DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0);
#endif
/**
@ -497,7 +497,7 @@ namespace nd4j {
*
*/
#if NOT_EXCLUDED(OP_ones_as)
DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0);
#endif
/**
@ -1076,6 +1076,24 @@ namespace nd4j {
DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
#endif
/**
* solve op. - solve systems of linear equations - general method.
*
* input params:
* 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
* 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
*
* boolean args:
* 0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (conjugate transpose) should be used
*
* return value:
* tensor with dimension (x * y * z * ::: * M * K) with solutions
*
*/
#if NOT_EXCLUDED(OP_solve)
DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0);
#endif
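A hedged usage sketch for the new op (the arrays and expected solution are illustrative; the call style mirrors the evaluate() convention introduced elsewhere in this PR):

// solve A * x = b for a single 2x2 system; the expected solution is x = [1, 2]
auto a = NDArrayFactory::create<float>('c', {2, 2}, {4.f, 3.f, 6.f, 3.f});
auto b = NDArrayFactory::create<float>('c', {2, 1}, {10.f, 12.f});
nd4j::ops::solve op;
auto res = op.evaluate({&a, &b});    // an optional bool arg 0 would request the adjoint of a
if (res->status() == ND4J_STATUS_OK) {
    auto x = res->at(0);             // 2x1 tensor holding the solution
}
delete res;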
/**
* lu op. - make LUP decomposition of given batch of 2D square matrices
*

View File

@ -237,11 +237,46 @@ namespace helpers {
samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
}
template <typename T>
static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
auto input = compound->dup();
compound->nullify();
// Decomposing matrix into Upper and Lower
// triangular matrix
for (auto i = 0; i < rowNum; i++) {
// Upper Triangular
for (auto k = i; k < rowNum; k++) {
// Summation of L(i, j) * U(j, k)
T sum = 0;   // accumulate in T: an int accumulator would truncate the floating-point products
for (int j = 0; j < i; j++)
sum += compound->t<T>(i,j) * compound->t<T>(j,k);
// Evaluating U(i, k)
compound->t<T>(i, k) = input.t<T>(i, k) - sum;
}
// Lower Triangular
for (int k = i + 1; k < rowNum; k++) {
// Summation of L(k, j) * U(j, i)
T sum = 0;   // accumulate in T: an int accumulator would truncate the floating-point products
for (int j = 0; j < i; j++)
sum += compound->t<T>(k,j) * compound->t<T>(j, i);
// Evaluating L(k, i)
compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i,i);
}
}
}
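// worked example (illustrative): for A = [[4, 3], [6, 3]] the loops above yield
//     L = [[1, 0], [1.5, 1]]   and   U = [[4, 3], [0, -1.5]]
// with U stored in the upper triangle of `compound`, the sub-diagonal factors of L
// below it, and the unit diagonal of L kept implicit.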
template <typename T, typename I>
static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
//const int rowNum = compound->rows();
// const int columnNum = output->columns();
if (permutation) { // LUP algorithm
permutation->linspace(0);
auto permutationBuf = permutation->bufferAsT<I>();        //dataBuffer()->primaryAsT<I>();
auto compoundBuf = compound->bufferAsT<T>();
@ -252,12 +287,17 @@ namespace helpers {
if (pivotIndex < 0) {
throw std::runtime_error("helpers::luNN_: input matrix is singular.");
}
math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
                permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
swapRows(compoundBuf, compoundShape, i, pivotIndex);
processColumns(i, rowNum, compoundBuf, compoundShape);
}
}
else { // Doolittle algorithm (LU decomposition without pivoting)
doolitleLU<T>(context, compound, rowNum);
}
}
template <typename T, typename I>
static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) {
@ -265,17 +305,20 @@ namespace helpers {
output->assign(input);   // copy the input into the output tensor, which is then decomposed in place
ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
ResultSet permutations;
if (permutationVectors)
permutations = permutationVectors->allTensorsAlongDimension({-1});
auto loop = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
luNN_<T, I>(context, outputs.at(i), permutationVectors ? permutations.at(i) : nullptr, n);
}
};
samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
}
void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
BUILD_DOUBLE_SELECTOR(input->dataType(), permutation ? permutation->dataType() : DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
}
// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);

View File

@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <execution/Threads.h>
#include <helpers/MmulHelper.h>
#include "../triangular_solve.h"
#include "../lup.h"
#include "../solve.h"
namespace nd4j {
namespace ops {
namespace helpers {
// --------------------------------------------------------------------------------------------------------------------------------------- //
template <typename T>
static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
auto inputPart = input->allTensorsAlongDimension({-2, -1});
auto outputPart = output->allTensorsAlongDimension({-2, -1});
output->assign(input);
auto batchLoop = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
for (auto r = 0; r < input->rows(); r++) {
for (auto c = 0; c < r; c++) {
math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
}
}
}
};
samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
template <typename T>
static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
// stage 1: LU decomposition batched
auto leftOutput = leftInput->ulike();
auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
helpers::lu(context, leftInput, &leftOutput, &permutations);
auto P = leftInput->ulike(); //permutations batched matrix
P.nullify(); // fill the permutation matrices with zeros before setting the ones
auto PPart = P.allTensorsAlongDimension({-2,-1});
auto permutationsPart = permutations.allTensorsAlongDimension({-1});
for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
for (auto row = 0; row < PPart[batch]->rows(); ++row) {
PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
}
}
auto leftLower = leftOutput.dup();
auto rightOutput = rightInput->ulike();
auto rightPermuted = rightOutput.ulike();
MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0);
ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
for (auto i = 0; i < leftLowerPart.size(); i++) {
for (auto r = 0; r < leftLowerPart[i]->rows(); r++)
leftLowerPart[i]->t<T>(r,r) = (T)1.f;
}
// stage 2: triangularSolveFunctor for Lower with given b
helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
// stage 3: triangularSolveFunctor for Upper with output of previous stage
helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
return Status::OK();
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
}
// --------------------------------------------------------------------------------------------------------------------------------------- //
}
}
}

View File

@ -41,13 +41,16 @@ namespace helpers {
template <typename T>
static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows();
auto cols = rightInput->columns();
//output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
for (auto r = 0; r < rows; r++) {
for (auto j = 0; j < cols; j++) {
auto sum = rightInput->t<T>(r, j);
for (auto c = 0; c < r; c++) {
sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
}
output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
}
}
}
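// worked example (illustrative): with L = [[2, 0], [1, 3]] and b = [[4], [5]],
// row 0 gives x0 = 4 / 2 = 2 and row 1 gives x1 = (5 - 1*2) / 3 = 1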
@ -68,13 +71,15 @@ namespace helpers {
template <typename T>
static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
auto rows = leftInput->rows();
auto cols = rightInput->columns();
for (auto r = rows; r > 0; r--) {
for (auto j = 0; j < cols; j++) {
auto sum = rightInput->t<T>(r - 1, j);
for (auto c = r; c < rows; c++) {
sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
}
output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
}
}
}

Some files were not shown because too many files have changed in this diff.